mistralrs_core/vision_models/llama4/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

mod text;

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Linear, Module};
use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
use text::TextModel;
use vision::Llama4VisionModel;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    layers::linear_no_bias,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
};

mod config;
mod inputs_processor;
mod vision;

pub(crate) use config::{Llama4Config, TextConfig};
pub(crate) use inputs_processor::{Llama4ImageProcessor, Llama4Processor, IMAGE_TOKEN};

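/// Projects flattened vision-tower features (`vision_output_dim`) into the text
/// model's hidden size so they can be merged into the token embeddings.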
struct Llama4MultiModalProjector {
    linear_1: Linear,
}

impl Llama4MultiModalProjector {
    fn new(cfg: &Llama4Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            linear_1: linear_no_bias(
                cfg.vision_config.vision_output_dim,
                cfg.text_config.hidden_size,
                vb.pp("linear_1"),
            )?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear_1.forward(xs)
    }
}

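/// Llama 4 multimodal model: a vision tower, a linear projector, and the Llama
/// text model that consumes the merged embeddings.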
pub struct Llama4Model {
    language_model: TextModel,
    vision_model: Llama4VisionModel,
    multi_modal_projector: Llama4MultiModalProjector,
    image_token_index: usize,
}

impl Llama4Model {
    pub fn new(
        cfg: &Llama4Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
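        // The vision tower and projector load onto `real_device`; device mapping
        // applies only to the text model below.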
        let vision_model = Llama4VisionModel::new(
            &cfg.vision_config,
            vb.pp("vision_model"),
            &normal_loading_metadata.real_device,
            &normal_loading_metadata.mapper.get_comm_for(0)?,
            &normal_loading_metadata.multi_progress,
        )?;
        let multi_modal_projector = Llama4MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector")
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let language_model = TextModel::new(
            &cfg.text_config,
            vb.pp("language_model"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;

        Ok(Self {
            language_model,
            vision_model,
            multi_modal_projector,
            image_token_index: cfg.image_token_index,
        })
    }

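    /// Embed `input_ids`, splice projected image features into the embeddings at the
    /// image placeholder positions (when pixel values are given), and run the text model.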
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut input_embeds = self.language_model.get_input_embeddings(input_ids)?;

        if let Some(pixel_values) = pixel_values {
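            // Mark every position holding the image placeholder token, broadcast the mask
            // across the embedding dimension, and flatten it to index the flat embeddings.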
            let special_image_mask = input_ids
                .eq(self.image_token_index as f64)?
                .unsqueeze(D::Minus1)?
                .broadcast_as(input_embeds.shape().clone())?
                .to_dtype(DType::U32)?;

            let mask_flat = special_image_mask.flatten_all()?;
            // Nonzero before vision model to allow async processing all the way through logits.
            let indices = mask_flat.nonzero()?.squeeze(1)?;

            let image_features = self.vision_model.forward(&pixel_values)?;

            let vision_flat = image_features.reshape(((), image_features.dim(D::Minus1)?))?;
            let projected_vision_flat = self.multi_modal_projector.forward(&vision_flat)?;

            let mut x_flat = input_embeds.flatten_all()?;
            let src_flat = projected_vision_flat.flatten_all()?;

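            // Overwrite the embeddings at the image-token positions: gather the current
            // values, take the difference to the projected vision features, and
            // scatter-add it back, which amounts to an index assignment at `indices`.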
            let current_vals = x_flat.gather(&indices, 0)?;
            let diff = (src_flat - current_vals)?;
            x_flat = x_flat.scatter_add(&indices, &diff, 0)?;

            input_embeds = x_flat.reshape(input_embeds.shape())?;
        }

        self.language_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl IsqModel for Llama4Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let (mut layers, device_map) = self.language_model.get_layers();
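        // Vision-tower layers are also ISQ-quantizable; they carry no layer index (`None`).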
        layers.extend(
            self.vision_model
                .get_isq_layers()
                .into_iter()
                .map(|x| (x, None)),
        );
        (layers, device_map)
    }

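    /// Tensors that stay unquantized under ISQ (e.g. the projector) collected for serialization.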
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector")
            .pp("linear_1")
            .add(&self.multi_modal_projector.linear_1);
        uvb.pp("language_model")
            .extend(self.language_model.residual_tensors());
        uvb.pp("vision_model")
            .extend(self.vision_model.residual_tensors());

        uvb.to_safetensors()
    }
}

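/// Llama 4 needs no extra per-request arguments; this unit struct satisfies the `VisionModel` interface.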
pub struct Llama4ModelSpecificArgs;

impl NormalModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
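        // Text-only path: no pixel values come through the `NormalModel` interface.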
        self.forward(
            input_ids,
            None,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
}

impl VisionModel for Llama4Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let Llama4ModelSpecificArgs = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Llama4ModelSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        self.language_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.language_model.cache_mut()
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.language_model.config()
    }
    fn device(&self) -> &Device {
        self.language_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.language_model.max_seq_len()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn std::any::Any> {
        Box::new(Llama4ModelSpecificArgs)
    }
}

impl AnyMoeBaseModelMixin for Llama4Model {}