#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]
use std::any::Any;

use super::llava_llm::{LLaVALLM, Llama, Mistral};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::amoe::MlpLayer;
use crate::device_map::DeviceMapper;
use crate::layers;
use crate::ops::NonZeroOp;
use crate::paged_attention::{AttentionImplementation, ModelConfigMetadata};
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::IsqModel;
use crate::pipeline::NormalLoadingMetadata;
use crate::pipeline::VisionModel;
use crate::utils::unvarbuilder::UnVarBuilder;
use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer};
use crate::vision_models::llava::config::Config;
use crate::AnyMoeConfig;
use crate::AnyMoeExpertType;
use candle_core::{bail, DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Linear};
use mistralrs_quant::ShardedVarBuilder;

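/// Zero-sized marker for LLaVA's model-specific arguments: this model's
/// forward pass needs nothing beyond the pixel values, so the args are a
/// dummy (see `default_model_specific_args`, which returns `Box::new(())`).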
pub(crate) struct LLaVAVisionSpecificArgs;

pub struct MMProjector {
    linear_1: Linear,
    activation: Activation,
    linear_2: Linear,
}

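// The multimodal projector: a two-layer MLP (linear -> GELU -> linear) that
// maps CLIP vision features of width `vision_config.hidden_size` into the
// language model's embedding space of width `text_config.hidden_size`.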
impl MMProjector {
    pub fn new(vb: &ShardedVarBuilder, config: &Config, device: &Device) -> Result<Self> {
        let linear_1 = layers::linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_1")
                .set_device(device.clone()),
        )?;
        let activation = match config.projector_hidden_act.as_str() {
            "gelu" => Activation::Gelu,
            _ => {
                bail!(
49 "Unsupporg projector hidden act: {}",
                    config.projector_hidden_act
                );
            }
        };
        let linear_2 = layers::linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            vb.pp("multi_modal_projector.linear_2")
                .set_device(device.clone()),
        )?;
        Ok(Self {
            linear_1,
            activation,
            linear_2,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.linear_1)?
            .apply(&self.activation)?
            .apply(&self.linear_2)
    }
}

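/// CLIP vision encoder wrapper that exposes the hidden states of a chosen
/// layer (`select_layer`, counted from the end when negative) and applies the
/// configured feature-selection strategy.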
pub struct ClipVisionTower {
    model: ClipVisionTransformer,
    select_layer: isize,
    select_feature_method: String,
    config: ClipConfig,
}

impl ClipVisionTower {
    pub fn new(
        vb: ShardedVarBuilder,
        select_layer: isize,
        select_feature_method: &str,
        config: &ClipConfig,
    ) -> Result<Self> {
        let model = ClipVisionTransformer::new(vb, config)?;
        Ok(Self {
            model,
            select_layer,
            select_feature_method: select_feature_method.to_string(),
            config: config.clone(),
        })
    }

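    // `select_layer` indexes the returned hidden states from the end (e.g. -2
    // selects the penultimate layer). Any strategy other than "cls_patch" or
    // "full" drops the leading CLS token and keeps only the patch features.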
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let result = self.model.forward_get_hidden_states(x)?;
        let index = result.len() as isize + self.select_layer;
        let result = result[index as usize].clone();
        if self.select_feature_method == "cls_patch" || self.select_feature_method == "full" {
            Ok(result)
        } else {
            result.i((.., 1..))
        }
    }

    pub fn num_patches_per_side(&self) -> usize {
        self.config.image_size / self.config.patch_size
    }
}

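/// End-to-end LLaVA: a CLIP vision tower feeding an MM projector, whose
/// output is spliced into the embeddings of a Llama or Mistral language model.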
pub struct Model {
    clip_vision_tower: ClipVisionTower,
    mm_projector: MMProjector,
    llm: Box<dyn LLaVALLM>,
    config: Config,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let device = normal_loading_metadata.real_device.clone();
        let dtype = vb.dtype();
        let clip_config = config.to_clip_config();
        let mm_projector = MMProjector::new(&vb, config, &device)?;
        let clip_vision_tower = ClipVisionTower::new(
            vb.pp("vision_tower.vision_model")
                .set_device(device.clone()),
            config.vision_feature_layer,
            &config.vision_feature_select_strategy,
            &clip_config,
        )?;

        let llm: Box<dyn LLaVALLM> = match config.text_config.model_type.as_str() {
            "llama" => {
                let llama_config = config.to_llama_config();
                let llama = Llama::new(
                    &llama_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(llama)
            }
            "mistral" => {
                let mistral_config = config.to_mistral_config();
                let mistral = Mistral::new(
                    &mistral_config,
                    vb.pp("language_model"),
                    is_gptx,
                    normal_loading_metadata,
                    attention_mechanism,
                )?;
                Box::new(mistral)
            }
            _ => {
                bail!("Unsupported model type: {}", config.text_config.model_type);
            }
        };
        Ok(Self {
            clip_vision_tower,
            mm_projector,
            llm,
            config: config.clone(),
            device,
            dtype,
        })
    }

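    /// Run pixel values through the vision tower and projector, producing
    /// image embeddings in the LLM's hidden size.
    ///
    /// A minimal usage sketch (the `model` and `pixel_values` bindings are
    /// assumptions for illustration; `pixel_values` must already be
    /// CLIP-preprocessed):
    ///
    /// ```ignore
    /// // (batch, channels, height, width) -> (batch, num_patches, text_hidden_size)
    /// let image_features = model.encode_images(&pixel_values)?;
    /// ```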
    pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
        let mut image_features = self.clip_vision_tower.forward(x)?;
        image_features = self.mm_projector.forward(&image_features)?;
        Ok(image_features)
    }

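    // Image positions are pre-marked in `input_ids` with negative token ids by
    // the input processor. This splices the projected image features into the
    // text embeddings at those positions:
    //   1. find the negative placeholders via `nonzero`,
    //   2. clamp ids to >= 0 and embed the text tokens,
    //   3. `slice_assign` each image's features over its placeholder span.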
    pub fn prepare_inputs_labels_for_multimodal(
        &self,
        input_ids: &Tensor,
        images: &Tensor,
        num_image_tokens: usize,
    ) -> Result<Tensor> {
        let image_indexes = input_ids
            .squeeze(0)?
            .lt(0i64)?
            .nonzero()?
            .squeeze(1)?
            .to_vec1::<u32>()?;
        let mut result = input_ids.clamp(0i64, i64::MAX)?.to_dtype(DType::U32)?;
        result = self.llm.embed(&result)?;
        let image_features = self.encode_images(&images.to_dtype(self.dtype)?)?;
        let num_of_images = image_features.shape().dims()[0];
        let mut image_features_vec = Vec::new();
        for i in 0..num_of_images {
            image_features_vec.push(image_features.get(i)?.unsqueeze(0)?);
        }
        for (i, image_index) in image_indexes.iter().enumerate() {
            result = result.slice_assign(
                &[
                    &(0usize..1usize),
                    &(*image_index as usize..*image_index as usize + num_image_tokens),
                    &(..),
                ],
                &image_features_vec[i],
            )?;
        }
        // Truncate to the model's maximum sequence length.
        let (_, seq_len) = input_ids.shape().dims2()?;
        if seq_len > self.config.text_config.max_length {
            result = result.i((.., ..self.config.text_config.max_length, ..))?
        }
        Ok(result)
    }

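    // Prompt steps that carry pixel values go through embedding splicing;
    // text-only steps (e.g. during decoding) are forwarded to the LLM directly.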
    pub fn forward_inputs(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        num_image_tokens: Option<usize>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        if let Some(ref pixel_values) = pixel_values {
            let input_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                pixel_values,
                num_image_tokens.unwrap(),
            )?;
            self.llm.forward_input_embed(
                input_ids,
                input_embeds,
                seqlen_offsets,
                context_lens,
                metadata,
                flash_params,
            )
        } else {
            self.llm.forward(
                input_ids,
                seqlen_offsets,
                context_lens,
                position_ids,
                metadata,
                flash_params,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.llm.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("multi_modal_projector.linear_1")
            .add(&self.mm_projector.linear_1);
        uvb.pp("multi_modal_projector.linear_2")
            .add(&self.mm_projector.linear_2);

        {
            let uvb_vt = uvb.pp("vision_tower.vision_model");
            uvb_vt.extend(self.clip_vision_tower.model.residual_tensors());
        }

        uvb.to_safetensors()
    }
}

impl VisionModel for Model {
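    // The number of image tokens is the full patch grid (`side * side`),
    // matching the patch-only features produced by the default
    // feature-selection strategy.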
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        _model_specific_args: Box<dyn std::any::Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        self.forward_inputs(
            input_ids,
            pixel_values,
            Some(
                self.clip_vision_tower.num_patches_per_side()
                    * self.clip_vision_tower.num_patches_per_side(),
            ),
            seqlen_offsets,
            context_lens,
            position_ids,
            metadata,
            flash_params,
        )
    }

    fn device(&self) -> &Device {
        &self.device
    }

    fn cache(&self) -> &crate::pipeline::EitherCache {
        self.llm.cache()
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        self.llm.cache_mut()
    }

    fn max_seq_len(&self) -> usize {
        self.config.text_config.max_length
    }

    fn has_conv2d(&self) -> bool {
        true
    }

    fn config(&self) -> &ModelConfigMetadata {
        self.llm.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.llm.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.llm.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.llm.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}