mistralrs_core/vision_models/llama4/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod text;

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
use text::TextModel;
use vision::Llama4VisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::linear_no_bias,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

mod config;
mod inputs_processor;
mod vision;

pub(crate) use config::{Llama4Config, TextConfig};
pub(crate) use inputs_processor::{Llama4ImageProcessor, Llama4Processor, IMAGE_TOKEN};
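
/// Projects flattened vision-encoder features from `vision_output_dim` to the text model's
/// `hidden_size` with a single bias-free linear layer.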
struct Llama4MultiModalProjector {
    linear_1: Linear,
}

impl Llama4MultiModalProjector {
    fn new(cfg: &Llama4Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            linear_1: linear_no_bias(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("linear_1"),
            )?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear_1.forward(xs)
    }
}
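
/// Llama 4 vision-language model: a vision encoder whose projected features are written into the
/// text model's input embeddings at the image-token positions, followed by the text model.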
pub struct Llama4Model {
    language_model: TextModel,
    vision_model: Llama4VisionModel,
    multi_modal_projector: Llama4MultiModalProjector,
    image_token_index: usize,
}

impl Llama4Model {
    pub fn new(
        cfg: &Llama4Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vision_model = Llama4VisionModel::new(
            &cfg.vision_config,
            vb.pp("vision_model"),
            &normal_loading_metadata.real_device,
            &normal_loading_metadata.mapper.get_comm_for(0)?,
            &normal_loading_metadata.multi_progress,
        )?;
        let multi_modal_projector = Llama4MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector")
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let language_model = TextModel::new(
            &cfg.text_config,
            vb.pp("language_model"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;

        Ok(Self {
            language_model,
            vision_model,
            multi_modal_projector,
            image_token_index: cfg.image_token_index,
        })
    }

    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.get_input_embeddings(input_ids)?;

        if let Some(pixel_values) = pixel_values {
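            // Mask the positions occupied by the image placeholder token, broadcast to the
            // embedding width so it can be flattened alongside the embeddings below.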
            let special_image_mask = input_ids
                .eq(self.image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape().clone())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            let indices = mask_flat.nonzero()?.squeeze(1)?;

            let image_features = self.vision_model.forward(&pixel_values)?;

            let vision_flat = image_features.reshape(((), image_features.dim(D::Minus1)?))?;
            let projected_vision_flat = self.multi_modal_projector.forward(&vision_flat)?;

            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = projected_vision_flat.flatten_all()?;
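
            // Emulate an indexed assignment: gather the current values at the image-token
            // positions, then scatter-add the difference so that those positions end up holding
            // the projected vision features.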
            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }

        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
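
// ISQ support: quantizable layers are collected from both the language and vision models, and
// `residual_tensors` gathers the remaining tensors (projector plus per-model residuals) for
// serialization via `UnVarBuilder`.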
impl IsqModel for Llama4Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, device_map) = self.language_model.get_layers();
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|x| (x, None)),
        );
        (layers, device_map)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .pp("linear_1")
            .add(&self.multi_modal_projector.linear_1);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }
}
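
/// Llama 4 needs no extra per-request arguments; this unit struct satisfies the
/// model-specific-args parameter of `VisionModel::forward`.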
pub struct Llama4ModelSpecificArgs;

impl NormalModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        self.forward(
            input_ids,
            None,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
}

impl VisionModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
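        // The args are expected to be the unit `Llama4ModelSpecificArgs`; any other type is a
        // caller bug and panics here.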
        let Llama4ModelSpecificArgs = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Llama4ModelSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Llama4ModelSpecificArgs)
    }
}

impl AnyMoeBaseModelMixin for Llama4Model {}