mistralrs_core/vision_models/llava/
llava15.rs

#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use super::llava_llm::{LLaVALLM, Llama, Mistral};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::amoe::MlpLayer;
use crate::device_map::DeviceMapper;
use crate::layers;
use crate::ops::NonZeroOp;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;
use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::AnyMoeConfig;
use crate::AnyMoeExpertType;
use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::ShardedVarBuilder;

pub(crate) struct LLaVAVisionSpecificArgs; // A marker struct; LLaVA 1.5 needs no model-specific args, but the `VisionModel` trait requires one.

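/// Two-layer MLP that projects CLIP vision features (`vision_config.hidden_size`)
/// into the language model's embedding space (`text_config.hidden_size`),
/// loaded from the `multi_modal_projector` weights.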
pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
                    "Unsupported projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

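/// Wraps the CLIP vision transformer and selects which hidden-state layer and
/// which tokens to use as image features, per the LLaVA feature-selection
/// convention (`select_layer` indexes backwards from the last hidden state).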
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
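        // `select_layer` is negative (e.g. -2 selects the second-to-last
        // layer), so convert it to an index from the end of the list.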
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
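        // "cls_patch"/"full" keep every token; otherwise drop the CLS token
        // at position 0 and keep only the patch tokens.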
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

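    /// Number of patches along one side of the image: `image_size / patch_size`
    /// (24 for the CLIP ViT-L/14 tower at 336px used by LLaVA 1.5).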
    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

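/// LLaVA 1.5: a CLIP vision tower whose selected hidden states are projected
/// by `MMProjector` and spliced into the token embeddings of the language
/// model (Llama or Mistral).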
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

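    /// Runs the vision tower on a batch of images and projects the resulting
    /// features into the language model's embedding space.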
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

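    /// Splices projected image features into the text embeddings. Image
    /// positions are marked by negative placeholder ids in `input_ids`; each
    /// marked position `image_index` receives that image's features over the
    /// span `[image_index, image_index + num_image_tokens)`.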
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor, // [1, seq_len]
        images: &Tensor,    // [total number of images, channels, width, height]
        num_image_tokens: usize,
    ) -> Result<Tensor> {
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
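        // Clamp the negative placeholder ids to valid token ids so they can be
        // embedded; the embeddings at those positions are overwritten below.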
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?; // [1, seq_len, hidden_size]
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?; // [num_images, num_patches, hidden_size]
        let num_of_images = image_features.shape().dims()[0];
        let mut image_features_vec = Vec::new();
        for i in 0..num_of_images {
            image_features_vec.push(image_features.get(i)?.unsqueeze(0)?);
        }
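        // Overwrite each placeholder span with its image's projected features.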
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
        // Truncate to the model's maximum sequence length if necessary.
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

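    /// Full forward pass. When `pixel_values` are present (the prompt step),
    /// image features are spliced into the input embeddings; on subsequent
    /// decode steps the LLM runs on token ids alone.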
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        num_image_tokens: Option<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            // We assume (as should be the case) that only prompt requests contain images.
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        // MM projectors
        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        // Vision tower
        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.to_safetensors()
    }
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        _model_specific_args: Box<dyn std::any::Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
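        // Each image contributes one token per CLIP patch:
        // (image_size / patch_size)^2, i.e. 24 * 24 = 576 for LLaVA 1.5.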
        self.forward_inputs(
            input_ids,
            pixel_values,
            Some(
                self.clip_vision_tower.num_patches_per_side()
                    * self.clip_vision_tower.num_patches_per_side(),
            ),
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}