mistralrs_core/vision_models/llava/llava_next.rs

#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::ShardedVarBuilder;

use crate::amoe::{AnyMoeBaseModelMixin, MlpLayer};
use crate::device_map::DeviceMapper;
use crate::ops::NonZeroOp;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;

use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::vision_models::llava::utils::get_anyres_image_grid_shape;
use crate::{layers, AnyMoeConfig, AnyMoeExpertType};

use super::llava_llm::{LLaVALLM, Llama, Mistral};

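/// Per-request, vision-specific arguments threaded through `VisionModel::forward` as a boxed
/// `Any` and downcast back into this struct. All fields are optional;
/// `default_model_specific_args` returns an empty set for steps that carry no images.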
#[derive(Default)]
pub(crate) struct LLaVANextVisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>, // width, height
    pub num_image_tokens: Option<Vec<usize>>,     // number of image tokens for each image
    pub num_image_samples: Option<Vec<usize>>,    // number of image samples for each image
}

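/// Multimodal projector: a two-layer MLP that maps CLIP vision features
/// (`vision_config.hidden_size`) into the language model's embedding space
/// (`text_config.hidden_size`), with the configured `projector_hidden_act`
/// (only `"gelu"` is supported) in between.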
pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
                    "Unsupported projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

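/// Wrapper around the CLIP vision transformer. `select_layer` is a negative index into the
/// per-layer hidden states (counted from the end); unless the feature-select method is
/// `"cls_patch"` or `"full"`, the leading CLS token is dropped so only patch tokens remain.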
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

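/// The full LLaVA-Next model: a CLIP vision tower feeding the `MMProjector`, a learned
/// `image_newline` separator embedding, and a LLaMA- or Mistral-based language model.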
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    image_newline: Tensor,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;
        let image_newline = vb
            .get(&[config.text_config.hidden_size], "image_newline")?
            .to_device(&device)?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

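    /// Encodes pixel values with the vision tower and projects the resulting patch features
    /// into the language model's hidden size.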
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

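    /// Reverses the aspect-ratio padding applied when the image was resized onto the patch
    /// grid: given the original (width, height), it crops the padded rows or columns
    /// symmetrically from a `[channels, height, width]` feature map.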
    fn unpad_image(&self, tensor: &Tensor, original_size: (u32, u32)) -> Result<Tensor> {
        assert_eq!(tensor.dims().len(), 3);
        let (original_width, original_height) = original_size;
        let tensor_dims = tensor.dims();
        let current_height = tensor_dims[1];
        let current_width = tensor_dims[2];
        let original_aspect_ratio = (original_width as f32) / (original_height as f32);
        let current_aspect_ratio = (current_width as f32) / (current_height as f32);
        if original_aspect_ratio > current_aspect_ratio {
            let scale_factor = (current_width as f32) / (original_width as f32);
            let new_height = (original_height as f32 * scale_factor).floor() as usize;
            let padding = (current_height - new_height) / 2;
            tensor.i((.., padding..current_height - padding, ..))
        } else {
            let scale_factor = (current_height as f32) / (original_height as f32);
            let new_width = (original_width as f32 * scale_factor).floor() as usize;
            let padding = (current_width - new_width) / 2;
            tensor.i((.., .., padding..current_width - padding))
        }
    }

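    /// Builds the multimodal input embeddings. Image placeholder positions are marked with
    /// negative token ids in `input_ids`; text tokens are embedded as usual, each image's
    /// crops are encoded, reassembled onto their `anyres` grid, unpadded, and joined with an
    /// `image_newline` embedding at the end of every row. The base (full-image) features are
    /// then prepended and the result is written over the placeholder span. Finally, the
    /// sequence is truncated to `text_config.max_length`.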
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor, //[1,seq_len]
        images: &Tensor,    //[sum of samples of all images,channel,width,height]
        num_image_tokens: Vec<usize>,
        num_image_samples: Vec<usize>,
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?; //[1, seq_len, hidden_size]
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?; //[sum of samples over all images, num_patches_per_side^2, hidden_size]
        let mut image_features_vec = Vec::new();
        let mut index = 0;
        for num_image_sample in num_image_samples {
            image_features_vec.push(image_features.i(index..index + num_image_sample)?);
            index += num_image_sample;
        }
        let image_features_vec = image_features_vec
            .iter()
            .enumerate()
            .map(|(image_idx, image_feature)| {
                let base_image_feature = image_feature.get(0).unwrap();
                let patch_image_feature = image_feature.i(1..).unwrap();
                let height = self.clip_vision_tower.num_patches_per_side();
                let width = height;
                assert_eq!(height * width, base_image_feature.dims()[0]);
                let image_size = image_sizes[image_idx];
                let image_grid_pinpoints = self.config.image_grid_pinpoints.clone().unwrap();
                let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                    image_size,
                    &image_grid_pinpoints,
                    self.clip_vision_tower.config.image_size as u32,
                );
                let mut new_image_feature = patch_image_feature.reshape((
                    num_patch_height as usize,
                    num_patch_width as usize,
                    height,
                    width,
                    (),
                ))?;
                new_image_feature = new_image_feature
                    .permute((4, 0, 2, 1, 3))?
                    .flatten(1, 2)?
                    .flatten(2, 3)?;
                new_image_feature = self.unpad_image(&new_image_feature, image_size)?;
                let new_image_feature_dims = new_image_feature.dims();
                let image_new_line = self
                    .image_newline
                    .reshape((self.config.text_config.hidden_size, 1, 1))?
                    .broadcast_as((new_image_feature_dims[0], new_image_feature_dims[1], 1))?;
                new_image_feature = Tensor::cat(&[new_image_feature, image_new_line], 2)?
                    .flatten(1, 2)?
                    .transpose(0, 1)?;
                new_image_feature =
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?.unsqueeze(0)?;
                Ok(new_image_feature)
            })
            .collect::<Result<Vec<Tensor>>>()?;
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens[i]),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
        //truncate
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

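    /// Runs one forward step. If `pixel_values` is present (the prompt step for a request
    /// with images), multimodal embeddings are built and fed to the LLM via
    /// `forward_input_embed`; otherwise this is a plain text forward pass.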
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        image_sizes: Option<Vec<(u32, u32)>>,
        num_image_tokens: Option<Vec<usize>>,
        num_image_samples: Option<Vec<usize>>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            // We assume (as it should be) that only prompt requests contain images.
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
                num_image_samples.unwrap(),
                &image_sizes.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

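// ISQ: quantizable layers live in the language model; the projector, vision tower, and
// `image_newline` are reported as residual (non-quantized) tensors.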
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        // MM projectors
        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        // Vision tower
        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.add_tensor("image_newline", self.image_newline.clone());

        uvb.to_safetensors()
    }
}

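// VisionModel glue: vision-specific arguments arrive as a boxed `Any` and are downcast back
// into `LLaVANextVisionSpecificArgs` before dispatching to `forward_inputs`.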
impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let LLaVANextVisionSpecificArgs {
            image_sizes,
            num_image_tokens,
            num_image_samples,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `LLaVANextVisionSpecificArgs`");
        let image_sizes = image_sizes.map(|image_sizes| {
            image_sizes
                .iter()
                .map(|(w, h)| (*w as u32, *h as u32))
                .collect::<Vec<_>>()
        });
        self.forward_inputs(
            input_ids,
            pixel_values,
            image_sizes,
            num_image_tokens,
            num_image_samples,
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(LLaVANextVisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}