mistralrs_core/vision_models/llama4/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod text;

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use text::TextModel;
use vision::Llama4VisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::linear_no_bias,
    ops::NonZeroOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

mod config;
mod inputs_processor;
mod vision;

pub(crate) use config::{Llama4Config, TextConfig};
pub(crate) use inputs_processor::{Llama4ImageProcessor, Llama4Processor, IMAGE_TOKEN};

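/// Projects flattened vision features (`vision_output_dim`) into the language
/// model's `hidden_size` with a single bias-free linear layer.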
struct Llama4MultiModalProjector {
    linear_1: Linear,
}

impl Llama4MultiModalProjector {
    fn new(cfg: &Llama4Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            linear_1: linear_no_bias(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("linear_1"),
            )?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear_1.forward(xs)
    }
}

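/// Llama 4 multimodal model: a vision encoder whose projected outputs are spliced
/// into the language model's input embeddings at image-token positions.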
pub struct Llama4Model {
    language_model: TextModel,
    vision_model: Llama4VisionModel,
    multi_modal_projector: Llama4MultiModalProjector,
    image_token_index: usize,
}

impl Llama4Model {
    pub fn new(
        cfg: &Llama4Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vision_model = Llama4VisionModel::new(
            &cfg.vision_config,
            vb.pp("vision_model"),
            &normal_loading_metadata.real_device,
            &normal_loading_metadata.mapper.get_comm_for(0)?,
            &normal_loading_metadata.multi_progress,
        )?;
        let multi_modal_projector = Llama4MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector")
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let language_model = TextModel::new(
            &cfg.text_config,
            vb.pp("language_model"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;

        Ok(Self {
            language_model,
            vision_model,
            multi_modal_projector,
            image_token_index: cfg.image_token_index,
        })
    }

    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.get_input_embeddings(input_ids)?;

        if let Some(pixel_values) = pixel_values {
            // Encode the images and project the flattened patch features into the
            // language model's hidden size.
            let image_features = self.vision_model.forward(&pixel_values)?;

            let vision_flat = image_features.reshape(((), image_features.dim(D::Minus1)?))?;
            let projected_vision_flat = self.multi_modal_projector.forward(&vision_flat)?;

            // Mask of embedding positions that correspond to image placeholder tokens,
            // broadcast to the full embedding shape.
            let special_image_mask = input_ids
                .eq(self.image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape().clone())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = projected_vision_flat.flatten_all()?;

            // Masked scatter: overwrite the placeholder embeddings with the projected
            // vision features by adding (src - current) at the masked indices.
            let indices = mask_flat.nonzero()?.squeeze(1)?;
            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }

        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

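// In-situ quantization: quantizable layers come from the language model plus the
// vision tower; everything else is exposed as residual (non-quantized) tensors.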
impl IsqModel for Llama4Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, device_map) = self.language_model.get_layers();
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|x| (x, None)),
        );
        (layers, device_map)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .pp("linear_1")
            .add(&self.multi_modal_projector.linear_1);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }
}

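/// Llama 4 needs no extra model-specific arguments; this unit struct exists to
/// satisfy the `VisionModel` interface.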
pub struct Llama4ModelSpecificArgs;

impl NormalModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        self.forward(
            input_ids,
            None,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
}

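// Vision path: `VisionModel::forward` threads `pixel_values` through, while the
// `NormalModel::forward` above runs the text-only path with `pixel_values = None`.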
impl VisionModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let Llama4ModelSpecificArgs = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Llama4ModelSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn has_conv2d(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Llama4ModelSpecificArgs)
    }
}

impl AnyMoeBaseModelMixin for Llama4Model {}