#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::{NonZeroOp, ShardedVarBuilder};

use crate::amoe::{AnyMoeBaseModelMixin, MlpLayer};
use crate::device_map::DeviceMapper;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;

use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::vision_models::llava::utils::get_anyres_image_grid_shape;
use crate::{layers, AnyMoeConfig, AnyMoeExpertType};

use super::llava_llm::{LLaVALLM, Llama, Mistral};

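/// Model-specific arguments for a LLaVA-Next forward pass: the original image
/// sizes plus the number of image tokens and image samples (patches) per image.
/// Passed through `VisionModel::forward` as `model_specific_args`.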
#[derive(Default)]
pub(crate) struct LLaVANextVisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>,
    pub num_image_tokens: Option<Vec<usize>>,
    pub num_image_samples: Option<Vec<usize>>,
}

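/// Two-layer MLP that projects CLIP vision features into the language model's
/// hidden size (`linear_1` -> activation -> `linear_2`).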
pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
                    "Unsupported projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

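/// CLIP vision encoder wrapper that selects the hidden states of a
/// configurable layer and, depending on the feature-select strategy, drops the
/// CLS token so that only patch features are returned.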
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

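    /// Run the CLIP transformer and pick the hidden states of the layer at
    /// `select_layer`, counted from the end (e.g. -2 selects the penultimate
    /// layer). Unless the strategy is `cls_patch`/`full`, the CLS token at
    /// position 0 is dropped.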
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

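/// LLaVA-Next vision-language model: a CLIP vision tower, a multi-modal
/// projector, a learned `image_newline` separator embedding, and a Llama- or
/// Mistral-based language model backend.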
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    image_newline: Tensor,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;
        let image_newline = vb
            .get(&[config.text_config.hidden_size], "image_newline")?
            .to_device(&device)?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            image_newline,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

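    /// Encode pixel values with the CLIP vision tower and project the patch
    /// features into the language model's hidden size.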
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

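    /// Remove the padding rows or columns introduced when the image was fitted
    /// to the patch grid, based on the original (width, height) aspect ratio.
    /// Expects a 3-D feature tensor laid out as (features, height, width).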
    fn unpad_image(&self, tensor: &Tensor, original_size: (u32, u32)) -> Result<Tensor> {
        assert_eq!(tensor.dims().len(), 3);
        let (original_width, original_height) = original_size;
        let tensor_dims = tensor.dims();
        let current_height = tensor_dims[1];
        let current_width = tensor_dims[2];
        let original_aspect_ratio = (original_width as f32) / (original_height as f32);
        let current_aspect_ratio = (current_width as f32) / (current_height as f32);
        if original_aspect_ratio > current_aspect_ratio {
            let scale_factor = (current_width as f32) / (original_width as f32);
            let new_height = (original_height as f32 * scale_factor).floor() as usize;
            let padding = (current_height - new_height) / 2;
            tensor.i((.., padding..current_height - padding, ..))
        } else {
            let scale_factor = (current_height as f32) / (original_height as f32);
            let new_width = (original_width as f32 * scale_factor).floor() as usize;
            let padding = (current_width - new_width) / 2;
            tensor.i((.., .., padding..current_width - padding))
        }
    }

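    /// Build the input embeddings for a multimodal prompt: embed the text
    /// tokens, encode and project the image patches, reassemble each image's
    /// patch grid, unpad it, append the `image_newline` separator, and splice
    /// the resulting image features into the embedding sequence at the
    /// positions marked by negative placeholder token ids.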
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor,
        images: &Tensor,
        num_image_tokens: Vec<usize>,
        num_image_samples: Vec<usize>,
        image_sizes: &[(u32, u32)],
    ) -> Result<Tensor> {
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?;
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?;
        let mut image_features_vec = Vec::new();
        let mut index = 0;
        for num_image_sample in num_image_samples {
            image_features_vec.push(image_features.i(index..index + num_image_sample)?);
            index += num_image_sample;
        }
        let image_features_vec = image_features_vec
            .iter()
            .enumerate()
            .map(|(image_idx, image_feature)| {
                let base_image_feature = image_feature.get(0).unwrap();
                let patch_image_feature = image_feature.i(1..).unwrap();
                let height = self.clip_vision_tower.num_patches_per_side();
                let width = height;
                assert_eq!(height * width, base_image_feature.dims()[0]);
                let image_size = image_sizes[image_idx];
                let image_grid_pinpoints = self.config.image_grid_pinpoints.clone().unwrap();
                let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
                    image_size,
                    &image_grid_pinpoints,
                    self.clip_vision_tower.config.image_size as u32,
                );
                let mut new_image_feature = patch_image_feature.reshape((
                    num_patch_height as usize,
                    num_patch_width as usize,
                    height,
                    width,
                    (),
                ))?;
                new_image_feature = new_image_feature
                    .permute((4, 0, 2, 1, 3))?
                    .flatten(1, 2)?
                    .flatten(2, 3)?;
                new_image_feature = self.unpad_image(&new_image_feature, image_size)?;
                let new_image_feature_dims = new_image_feature.dims();
                let image_new_line = self
                    .image_newline
                    .reshape((self.config.text_config.hidden_size, 1, 1))?
                    .broadcast_as((new_image_feature_dims[0], new_image_feature_dims[1], 1))?;
                new_image_feature = Tensor::cat(&[new_image_feature, image_new_line], 2)?
                    .flatten(1, 2)?
                    .transpose(0, 1)?;
                new_image_feature =
                    Tensor::cat(&[base_image_feature, new_image_feature], 0)?.unsqueeze(0)?;
                Ok(new_image_feature)
            })
            .collect::<Result<Vec<Tensor>>>()?;
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens[i]),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

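    /// Forward pass. When `pixel_values` are present, the prompt embeddings are
    /// built with `prepare_inputs_labels_for_multimodal` and fed to the LLM via
    /// `forward_input_embed`; otherwise the token ids go straight to the LLM's
    /// regular `forward`.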
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        image_sizes: Option<Vec<(u32, u32)>>,
        num_image_tokens: Option<Vec<usize>>,
        num_image_samples: Option<Vec<usize>>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
                num_image_samples.unwrap(),
                &image_sizes.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.add_tensor("image_newline", self.image_newline.clone());

        uvb.to_safetensors()
    }
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let LLaVANextVisionSpecificArgs {
            image_sizes,
            num_image_tokens,
            num_image_samples,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `LLaVANextVisionSpecificArgs`");
        let image_sizes = image_sizes.map(|image_sizes| {
            image_sizes
                .iter()
                .map(|(w, h)| (*w as u32, *h as u32))
                .collect::<Vec<_>>()
        });
        self.forward_inputs(
            input_ids,
            pixel_values,
            image_sizes,
            num_image_tokens,
            num_image_samples,
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(LLaVANextVisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}