#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::ShardedVarBuilder;

use crate::amoe::{AnyMoeBaseModelMixin, MlpLayer};
use crate::device_map::DeviceMapper;
use crate::ops::NonZeroOp;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;

use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::vision_models::llava::utils::get_anyres_image_grid_shape;
use crate::{layers, AnyMoeConfig, AnyMoeExpertType};

use super::llava_llm::{LLaVALLM, Llama, Mistral};

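/// Extra per-request arguments threaded through the object-safe `VisionModel`
/// interface as `model_specific_args` (downcast in `forward` below).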
29#[derive(Default)]
30pub(crate) struct LLaVANextVisionSpecificArgs {
31 pub image_sizes: Option<Vec<(usize, usize)>>, pub num_image_tokens: Option<Vec<usize>>, pub num_image_samples: Option<Vec<usize>>, }
35
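/// The multimodal projector: a two-layer MLP mapping CLIP vision features
/// (`vision_config.hidden_size`) into the LLM embedding space
/// (`text_config.hidden_size`).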
pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
                    "Unsupported projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

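/// Wrapper around the CLIP vision transformer that selects which hidden layer,
/// and which subset of its features, to expose as image features.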
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
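        // `select_layer` is expected to be negative (e.g. -2), so this indexes
        // the returned hidden states from the end.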
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

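    /// Number of patches along one side of the square input image,
    /// i.e. `image_size / patch_size`.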
    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

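/// LLaVA-Next model: a CLIP vision tower and MLP projector whose output is
/// spliced into the embedding sequence of a Llama or Mistral language model.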
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    image_newline: Tensor,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;
        let image_newline = vb
            .get(&[config.text_config.hidden_size], "image_newline")?
            .to_device(&device)?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

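    /// Runs pixel values through the vision tower, then projects the selected
    /// features into the LLM's hidden size.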
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

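    /// Crops away the symmetric padding that preprocessing added to fit the
    /// image into a square, so the feature map matches the original
    /// `(width, height)` aspect ratio.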
    fn unpad_image(&self, tensor: &Tensor, original_size: (u32, u32)) -> Result<Tensor> {
        assert_eq!(tensor.dims().len(), 3);
        let (original_width, original_height) = original_size;
        let tensor_dims = tensor.dims();
        let current_height = tensor_dims[1];
        let current_width = tensor_dims[2];
        let original_aspect_ratio = (original_width as f32) / (original_height as f32);
        let current_aspect_ratio = (current_width as f32) / (current_height as f32);
        if original_aspect_ratio > current_aspect_ratio {
            let scale_factor = (current_width as f32) / (original_width as f32);
            let new_height = (original_height as f32 * scale_factor).floor() as usize;
            let padding = (current_height - new_height) / 2;
            tensor.i((.., padding..current_height - padding, ..))
        } else {
            let scale_factor = (current_height as f32) / (original_height as f32);
            let new_width = (original_width as f32 * scale_factor).floor() as usize;
            let padding = (current_width - new_width) / 2;
            tensor.i((.., .., padding..current_width - padding))
        }
    }

    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor,
        images: &Tensor,
        num_image_tokens: Vec<usize>,
        num_image_samples: Vec<usize>,
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
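        // Image placeholder positions are marked with negative token ids;
        // locate them before clamping the ids for the embedding lookup.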
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?;
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?;
        let mut image_features_vec = Vec::new();
        let mut index = 0;
        for num_image_sample in num_image_samples {
            image_features_vec.push(image_features.i(index..index + num_image_sample)?);
            index += num_image_sample;
        }
        let image_features_vec = image_features_vec
            .iter()
            .enumerate()
            .map(|(image_idx, image_feature)| {
                let base_image_feature = image_feature.get(0).unwrap();
                let patch_image_feature = image_feature.i(1..).unwrap();
                let height = self.clip_vision_tower.num_patches_per_side();
                let width = height;
                assert_eq!(height * width, base_image_feature.dims()[0]);
                let image_size = image_sizes[image_idx];
                let image_grid_pinpoints = self.config.image_grid_pinpoints.clone().unwrap();
                let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                    image_size,
                    &image_grid_pinpoints,
                    self.clip_vision_tower.config.image_size as u32,
                );
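                // Lay the per-patch features out on the anyres
                // (num_patch_height, num_patch_width) grid so that spatially
                // adjacent patches are adjacent after flattening.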
                let mut new_image_feature = patch_image_feature.reshape((
                    num_patch_height as usize,
                    num_patch_width as usize,
                    height,
                    width,
                    (),
                ))?;
                new_image_feature = new_image_feature
                    .permute((4, 0, 2, 1, 3))?
                    .flatten(1, 2)?
                    .flatten(2, 3)?;
                new_image_feature = self.unpad_image(&new_image_feature, image_size)?;
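                // Append the learned `image_newline` embedding at the end of
                // each feature row, marking row boundaries for the LLM.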
                let new_image_feature_dims = new_image_feature.dims();
                let image_new_line = self
                    .image_newline
                    .reshape((self.config.text_config.hidden_size, 1, 1))?
                    .broadcast_as((new_image_feature_dims[0], new_image_feature_dims[1], 1))?;
                new_image_feature = Tensor::cat(&[new_image_feature, image_new_line], 2)?
                    .flatten(1, 2)?
                    .transpose(0, 1)?;
                new_image_feature =
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?.unsqueeze(0)?;
                Ok(new_image_feature)
            })
            .collect::<Result<Vec<Tensor>>>()?;
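        // Splice each image's features into the embedding sequence at its
        // placeholder position.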
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens[i]),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
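        // The expanded sequence may overrun the LLM's context window; truncate
        // the embeddings to `max_length` if so.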
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

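    /// If `pixel_values` are provided, builds multimodal embeddings and feeds
    /// them to the LLM via `forward_input_embed`; otherwise this is a plain
    /// text forward pass.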
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        image_sizes: Option<Vec<(u32, u32)>>,
        num_image_tokens: Option<Vec<usize>>,
        num_image_samples: Option<Vec<usize>>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
                num_image_samples.unwrap(),
                &image_sizes.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.add_tensor("image_newline", self.image_newline.clone());

        uvb.to_safetensors()
    }
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let LLaVANextVisionSpecificArgs {
            image_sizes,
            num_image_tokens,
            num_image_samples,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `LLaVANextVisionSpecificArgs`");
        let image_sizes = image_sizes.map(|image_sizes| {
            image_sizes
                .iter()
                .map(|(w, h)| (*w as u32, *h as u32))
                .collect::<Vec<_>>()
        });
        self.forward_inputs(
            input_ids,
            pixel_values,
            image_sizes,
            num_image_tokens,
            num_image_samples,
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(LLaVANextVisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}