mistralrs_core/vision_models/llava/llava_next.rs

#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::{NonZeroOp, ShardedVarBuilder};

use crate::amoe::{AnyMoeBaseModelMixin, MlpLayer};
use crate::device_map::DeviceMapper;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;

use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::vision_models::llava::utils::get_anyres_image_grid_shape;
use crate::{layers, AnyMoeConfig, AnyMoeExpertType};

use super::llava_llm::{LLaVALLM, Llama, Mistral};

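// Per-request, vision-specific arguments passed to `VisionModel::forward` as
// `model_specific_args` and downcast back inside the implementation below.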
#[derive(Default)]
pub(crate) struct LLaVANextVisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>, // width, height
    pub num_image_tokens: Option<Vec<usize>>,     // number of image tokens for each image
    pub num_image_samples: Option<Vec<usize>>,    // number of image samples for each image
}

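// Multimodal projector: a two-layer MLP that maps CLIP vision features
// (vision hidden size) into the language model's hidden size.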
pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
                    "Unsupported projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

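// Wraps the CLIP vision transformer and selects hidden states from a
// configurable layer, optionally keeping or dropping the CLS token.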
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

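    // Returns the hidden states of the layer indexed by `select_layer` (counted
    // from the end when negative); unless the select method is "cls_patch" or
    // "full", the leading CLS token is dropped so only patch tokens remain.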
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

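// Full LLaVA-Next model: CLIP vision tower + multimodal projector + a learned
// `image_newline` separator embedding, wired to a Llama or Mistral language model.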
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    image_newline: Tensor,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;
        let image_newline = vb
            .get(&[config.text_config.hidden_size], "image_newline")?
            .to_device(&device)?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

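    // Runs pixel values through the CLIP tower and projects the selected patch
    // features into the language model's hidden size.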
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

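    // Crops a [channels, height, width] feature map back to the original image's
    // aspect ratio, removing the symmetric padding introduced when the image was
    // resized onto the patch grid.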
    fn unpad_image(&self, tensor: &Tensor, original_size: (u32, u32)) -> Result<Tensor> {
        assert_eq!(tensor.dims().len(), 3);
        let (original_width, original_height) = original_size;
        let tensor_dims = tensor.dims();
        let current_height = tensor_dims[1];
        let current_width = tensor_dims[2];
        let original_aspect_ratio = (original_width as f32) / (original_height as f32);
        let current_aspect_ratio = (current_width as f32) / (current_height as f32);
        if original_aspect_ratio > current_aspect_ratio {
            let scale_factor = (current_width as f32) / (original_width as f32);
            let new_height = (original_height as f32 * scale_factor).floor() as usize;
            let padding = (current_height - new_height) / 2;
            tensor.i((.., padding..current_height - padding, ..))
        } else {
            let scale_factor = (current_height as f32) / (original_height as f32);
            let new_width = (original_width as f32 * scale_factor).floor() as usize;
            let padding = (current_width - new_width) / 2;
            tensor.i((.., .., padding..current_width - padding))
        }
    }

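    // Embeds the text tokens, encodes the images, reassembles each image's
    // anyres patch grid (base feature plus unpadded patches joined by the
    // `image_newline` embedding), and writes the result over the embedding
    // positions marked by negative input ids, truncating to `max_length`.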
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor, //[1,seq_len]
        images: &Tensor,    //[sum of samples of all images,channel,width,height]
        num_image_tokens: Vec<usize>,
        num_image_samples: Vec<usize>,
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?; //[seq_len,hidden_size]
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?; //[sum of samples of all images,patch_size*patch_size,hidden_size]
        let mut image_features_vec = Vec::new();
        let mut index = 0;
        for num_image_sample in num_image_samples {
            image_features_vec.push(image_features.i(index..index + num_image_sample)?);
            index += num_image_sample;
        }
        let image_features_vec = image_features_vec
            .iter()
            .enumerate()
            .map(|(image_idx, image_feature)| {
                let base_image_feature = image_feature.get(0).unwrap();
                let patch_image_feature = image_feature.i(1..).unwrap();
                let height = self.clip_vision_tower.num_patches_per_side();
                let width = height;
                assert_eq!(height * width, base_image_feature.dims()[0]);
                let image_size = image_sizes[image_idx];
                let image_grid_pinpoints = self.config.image_grid_pinpoints.clone().unwrap();
                let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                    image_size,
                    &image_grid_pinpoints,
                    self.clip_vision_tower.config.image_size as u32,
                );
                let mut new_image_feature = patch_image_feature.reshape((
                    num_patch_height as usize,
                    num_patch_width as usize,
                    height,
                    width,
                    (),
                ))?;
                new_image_feature = new_image_feature
                    .permute((4, 0, 2, 1, 3))?
                    .flatten(1, 2)?
                    .flatten(2, 3)?;
                new_image_feature = self.unpad_image(&new_image_feature, image_size)?;
                let new_image_feature_dims = new_image_feature.dims();
                let image_new_line = self
                    .image_newline
                    .reshape((self.config.text_config.hidden_size, 1, 1))?
                    .broadcast_as((new_image_feature_dims[0], new_image_feature_dims[1], 1))?;
                new_image_feature = Tensor::cat(&[new_image_feature, image_new_line], 2)?
                    .flatten(1, 2)?
                    .transpose(0, 1)?;
                new_image_feature =
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?.unsqueeze(0)?;
                Ok(new_image_feature)
            })
            .collect::<Result<Vec<Tensor>>>()?;
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens[i]),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
        // truncate
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

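    // Dispatches a forward pass: prompt steps with pixel values go through the
    // multimodal embedding path, while steps without images fall back to the
    // plain LLM forward over token ids.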
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        image_sizes: Option<Vec<(u32, u32)>>,
        num_image_tokens: Option<Vec<usize>>,
        num_image_samples: Option<Vec<usize>>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            // We assume (as should be the case) that only prompt requests contain images
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
                num_image_samples.unwrap(),
                &image_sizes.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

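// ISQ support: quantizable layers come from the wrapped LLM; the projector,
// vision tower, and `image_newline` stay unquantized and are exposed as
// residual tensors.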
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        // MM projectors
        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        // Vision tower
        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.add_tensor("image_newline", self.image_newline.clone());

        uvb.to_safetensors()
    }
}

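// VisionModel glue: downcasts the per-request `LLaVANextVisionSpecificArgs`
// and forwards everything to `Model::forward_inputs`.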
impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let LLaVANextVisionSpecificArgs {
            image_sizes,
            num_image_tokens,
            num_image_samples,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `LLaVANextVisionSpecificArgs`");
        let image_sizes = image_sizes.map(|image_sizes| {
            image_sizes
                .iter()
                .map(|(w, h)| (*w as u32, *h as u32))
                .collect::<Vec<_>>()
        });
        self.forward_inputs(
            input_ids,
            pixel_values,
            image_sizes,
            num_image_tokens,
            num_image_samples,
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(LLaVANextVisionSpecificArgs::default())
    }
}

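// AnyMoE support is delegated entirely to the wrapped language model.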
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}