use std::{
    fmt::Debug,
    sync::{Arc, LazyLock},
};

use candle_core::{shape::ShapeWithOneHole, DType, Device, IndexOp, Result, Shape, Tensor, D};
use candle_nn::Module;
use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};

use crate::{
    layers::{AvgPool2d, ReflectionPad2d},
    utils::unvarbuilder::UnVarBuilder,
    vision_models::{
        phi4::config::Phi4MMImgProcessorConfig,
        siglip::{SiglipVisionConfig, SiglipVisionTransformer},
    },
};

use super::{config::Phi4MMImageEmbedConfig, Phi4MMConfig};

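/// Token id used to mark image positions in the input sequence; kept as `f64`
/// so it can be compared against token tensors directly with `Tensor::eq`.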
pub(super) const IMAGE_SPECIAL_TOKEN_ID: f64 = 200010.;

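/// A `Module` that can also report the device and dtype of its parameters, so
/// the projection stack can tell callers where its weights live.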
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}

#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);

impl Module for QuantMethodWrapper {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl ModuleWithMetadata for QuantMethodWrapper {
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}

// Activations are parameter-free; `device`/`dtype` are never queried for them
// because the projection stack below always starts with a linear layer.
impl ModuleWithMetadata for candle_nn::Activation {
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}

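// A six-dimensional shape with a trailing "hole" (the inferred dimension).
// candle-core has no built-in `ShapeWithOneHole` impl for this arity, so one
// is provided here for the HD-transform reshape below.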
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for BigShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}

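/// Sequential stack of projection layers that maps vision features into the
/// language model's embedding space.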
#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);

impl Module for EmbeddingLayers {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in &self.0 {
            xs = layer.forward(&xs)?;
        }
        Ok(xs)
    }
}

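/// Vision tower configuration used by Phi-4 multimodal: a SigLIP ViT
/// (27 layers, hidden size 1152, patch size 14) at 448x448 resolution.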
pub(crate) static PHI4_MM_VISION_CFG: LazyLock<SiglipVisionConfig> =
    LazyLock::new(|| SiglipVisionConfig {
        hidden_size: 1152,
        image_size: 448,
        intermediate_size: 4304,
        num_attention_heads: 16,
        num_hidden_layers: 27,
        patch_size: 14,
        ..Default::default()
    });

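/// Phi-4 multimodal image embedding module. Images are encoded by the SigLIP
/// vision tower, optionally rearranged by the "HD transform" (which merges
/// blocks of patches and inserts learnable separator tokens), and projected
/// into the language model's hidden space, where the result overwrites the
/// embeddings of the image placeholder tokens.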
pub struct ImageEmbedding {
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    layers: EmbeddingLayers,
    type_feature: String,
    layer_idx: isize,
    image_processor: SiglipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    tensors: Vec<(String, Tensor)>,
    img_processor_padding: Option<ReflectionPad2d>,
    crop_size: usize,
    image_token_compression: Option<AvgPool2d>,
    base_feat_height_reduction: usize,
    base_feat_height_target: Option<usize>,
}

impl ImageEmbedding {
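    /// Build the image embedding module from the model config and weights.
    /// The SigLIP patch-grid side length `m` is inferred from the length of
    /// the position embedding table (`l = m * m`); odd grids get one
    /// row/column of reflection padding so that 2x2 pooling divides evenly.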
    pub fn new(
        cfg: &Phi4MMConfig,
        img_embd_config: &Phi4MMImageEmbedConfig,
        wte: candle_nn::Embedding,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let hidden_size = img_embd_config.n_embd.unwrap_or(cfg.hidden_size);

        let siglip_vision_config = &PHI4_MM_VISION_CFG;
        let image_processor =
            SiglipVisionTransformer::new(siglip_vision_config, vb.pp("img_processor"))?;

        // The position embedding table has one entry per patch, so its length
        // determines the side length `m` of the square patch grid.
        let pe_weight = image_processor.embeddings.position_embedding.embeddings();
        let (l, d) = pe_weight.dims2()?;
        let mut m = (l as f64).sqrt() as usize;
        assert_eq!(m.pow(2), l);
        let img_processor_padding = if m % 2 != 0 {
            m += 1;
            Some(ReflectionPad2d::new((0, 1, 0, 1)))
        } else {
            None
        };
        let image_dim_out = d;
        let num_img_tokens = (m / 2).pow(2);
        let base_feat_height_target = m;

        let use_hd_transform = img_embd_config.use_hd_transform.unwrap_or(false);
        let with_learnable_separator = img_embd_config.with_learnable_separator.unwrap_or(false);
        let hd_transform_order = img_embd_config
            .hd_transform_order
            .clone()
            .unwrap_or("glb_sub".to_string());
        let crop_size = img_embd_config.crop_size.unwrap_or(336);

        // Average pooling halves the feature grid, so the target height is
        // halved with it; without compression, 2x2 patch blocks are instead
        // merged during the HD transform (reduction factor 2).
        let (image_token_compression, base_feat_height_reduction, base_feat_height_target) =
            match &img_embd_config.image_token_compression_cls {
                Some(x) if x == "avg_pool_2d" => (
                    Some(AvgPool2d::new(2, 2)),
                    1_usize,
                    Some(base_feat_height_target / 2),
                ),
                None => (None, 2_usize, None),
                _ => candle_core::bail!("Unexpected image_token_compression_cls"),
            };

        // The HD transform and its learnable separator tokens are enabled
        // together or not at all.
        assert_eq!(use_hd_transform, with_learnable_separator);
        let (glb_gn, sub_gn) = if with_learnable_separator {
            let glb_gn = vb.get(
                (1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
                "glb_GN",
            )?;
            let sub_gn = vb.get(
                (1, 1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
                "sub_GN",
            )?;
            (Some(glb_gn), Some(sub_gn))
        } else {
            (None, None)
        };

        let projection_cls = img_embd_config
            .projection_cls
            .clone()
            .unwrap_or("linear".to_string());

        let mut tensors = Vec::new();
        let layers: Vec<Box<dyn ModuleWithMetadata>> =
            match (projection_cls.as_str(), use_hd_transform) {
                // Single linear projection.
                ("linear", _) => {
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        hidden_size,
                        true,
                        &None,
                        vb.pp("img_projection"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.bias".to_string(), b));
                    }
                    vec![Box::new(QuantMethodWrapper(a))]
                }
                // Two-layer MLP over HD-transformed (block-merged) features.
                ("mlp", true) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out * base_feat_height_reduction.pow(2),
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                // Two-layer MLP over raw vision features.
                ("mlp", false) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                _ => {
                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
                }
            };

        let (layer_idx, type_feature) = match &cfg.img_processor {
            Some(Phi4MMImgProcessorConfig {
                layer_idx,
                type_feature,
            }) => (
                layer_idx.unwrap_or(-2),
                type_feature.clone().unwrap_or("patch".to_string()),
            ),
            None => (-2, "patch".to_string()),
        };

        Ok(Self {
            wte,
            image_dim_out,
            num_img_tokens,
            glb_gn,
            sub_gn,
            layer_idx,
            type_feature,
            image_processor,
            layers: EmbeddingLayers(layers),
            hd_transform_order,
            use_hd_transform,
            tensors,
            img_processor_padding,
            crop_size,
            image_token_compression,
            base_feat_height_reduction,
            base_feat_height_target,
        })
    }

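    /// Run the SigLIP tower and take hidden states from `layer_idx` (counted
    /// from the end), then optionally average-pool the patch grid to reduce
    /// the number of image tokens. `type_feature` selects whether the CLS
    /// token is kept ("cls_patch") or dropped ("patch").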
    fn get_image_features(
        &self,
        img_embeds: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        assert!(self.layer_idx < 0);
        let img_feature = self.image_processor.forward_get_hidden_states(
            &img_embeds.to_dtype(self.image_processor.dtype())?,
            attention_mask,
            None,
            self.layer_idx,
        )?;

        if self.type_feature == "patch" {
            let mut patch_feature = img_feature.clone();
            if let Some(image_token_compression) = &self.image_token_compression {
                // Reshape the flat patch sequence back to a 2-D grid (NHWC),
                // convert to NCHW for padding/pooling, then flatten again.
                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
                patch_feature =
                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
                if let Some(img_processor_padding) = &self.img_processor_padding {
                    patch_feature = patch_feature.apply(img_processor_padding)?;
                }
                patch_feature = image_token_compression.forward(&patch_feature)?;
                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
                patch_feature = patch_feature.reshape((
                    (),
                    patch_feature.dim(1)? * patch_feature.dim(2)?,
                    patch_feature.dim(D::Minus1)?,
                ))?;
            } else if let Some(img_processor_padding) = &self.img_processor_padding {
                // Same grid round-trip, but with padding only (no pooling).
                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
                patch_feature =
                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
                patch_feature = patch_feature.apply(img_processor_padding)?;
                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
                patch_feature = patch_feature.reshape((
                    (),
                    patch_feature.dim(1)? * patch_feature.dim(2)?,
                    patch_feature.dim(D::Minus1)?,
                ))?;
            }
            Ok(patch_feature)
        } else if self.type_feature == "cls_patch" {
            let mut img_feature = img_feature.clone();
            if let Some(image_token_compression) = &self.image_token_compression {
                // Split off the CLS token, pool the patch grid, and re-attach.
                let mut patch_feature = img_feature.i((.., 1..))?;
                let cls_feature = img_feature.i((.., 0))?;
                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
                patch_feature =
                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
                patch_feature = image_token_compression.forward(&patch_feature)?;
                patch_feature = patch_feature.reshape((
                    (),
                    patch_feature.dim(D::Minus2)? * patch_feature.dim(D::Minus1)?,
                ))?;
                img_feature = Tensor::cat(&[cls_feature, patch_feature], 1)?;
            }
            Ok(img_feature)
        } else {
            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
        }
    }

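    /// Embed `input_ids` with the word-token embedding table, then overwrite
    /// the positions holding `IMAGE_SPECIAL_TOKEN_ID` with projected image
    /// features. Despite its name, `input_embeds` carries the pixel values of
    /// the image crops (a 5-D batch x crops x channels x H x W tensor when
    /// the HD transform is active), and `image_sizes` gives each image's
    /// (height, width) in pixels.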
    #[allow(non_snake_case)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_attention_mask: Option<&Tensor>,
        image_sizes: Option<Vec<(u32, u32)>>,
    ) -> Result<Tensor> {
        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;

        // Indices of the image placeholder tokens that the projected image
        // features will overwrite.
        let positions = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?.nonzero()?;

        let target_dev = self.layers.0[0].device();
        let target_dtype = self.layers.0[0].dtype();

        let mut select = false;
        let mut image_set_tensor = None;
        if positions.dim(0)? > 0 {
            select = true;

            if self.use_hd_transform && image_sizes.is_some() {
                // (batch, num_crops, channels, height, width): flatten the
                // crops into the batch dimension for the vision tower.
                assert_eq!(input_embeds.dims().len(), 5);
                let bs = input_embeds.dim(0)?;
                let img_features = match image_attention_mask {
                    Some(attn_mask) => self.get_image_features(
                        &input_embeds.flatten(0, 1)?,
                        Some(&attn_mask.ne(0.)?.flatten(0, 1)?),
                    )?,
                    None => self.get_image_features(input_embeds, None)?,
                };

                let base_feat_height_target = self.base_feat_height_target.unwrap();
                let base_resolution = self.crop_size;
                let base_feat_height_reduction = self.base_feat_height_reduction;

                let base_feat_height = (img_features.dim(1)? as f64).sqrt() as usize;
                let base_feat_width = base_feat_height;

                assert_eq!(base_feat_height, base_feat_height_target);
                assert_eq!(base_feat_width, base_feat_height_target);

                let img_features = img_features.reshape((
                    bs,
                    (),
                    base_feat_height * base_feat_width,
                    self.image_dim_out,
                ))?;
                let C = self.image_dim_out;
                let H = base_feat_height;

                let mut output_imgs = Vec::new();
                for bs_ in 0..bs {
                    // Number of crops for this image at the base resolution.
                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
                    let h = h as usize / base_resolution;
                    let w = w as usize / base_resolution;
                    let B_ = h * w;

                    // Crop 0 is the downscaled global view of the image.
                    let global_img_feature = img_features.i((bs_, ..1))?;

                    // HD transform: merge each reduction x reduction block of
                    // patches into one token of reduction^2 * C channels, then
                    // append a separator column (`sub_gn`) to each grid row.
                    let glb_img = global_img_feature
                        .reshape((1, H, H, C))?
                        .reshape((
                            1,
                            H / base_feat_height_reduction,
                            base_feat_height_reduction,
                            H / base_feat_height_reduction,
                            base_feat_height_reduction,
                            C,
                        ))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((
                            1,
                            H / base_feat_height_reduction,
                            H / base_feat_height_reduction,
                            base_feat_height_reduction * base_feat_height_reduction * C,
                        ))?
                        .contiguous()?;
                    let temp_glbl_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, H / base_feat_height_reduction, 1, 1))?;

                    let glb_img = Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((
                        1,
                        (),
                        base_feat_height_reduction * base_feat_height_reduction * C,
                    ))?;

                    // Remaining crops are the high-resolution sub-images;
                    // apply the same block-merging transform, then arrange the
                    // crops back into their h x w grid layout.
                    let mut sub_img = img_features.i((bs_, 1..))?;
                    sub_img = sub_img.i(..B_)?;

                    sub_img = sub_img
                        .reshape((B_, H, H, C))?
                        .reshape((
                            B_,
                            H / base_feat_height_reduction,
                            base_feat_height_reduction,
                            H / base_feat_height_reduction,
                            base_feat_height_reduction,
                            C,
                        ))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((
                            B_,
                            (),
                            base_feat_height_reduction * base_feat_height_reduction * C,
                        ))?
                        .contiguous()?;
                    sub_img = sub_img
                        .reshape(BigShapeWithOneHole((
                            1usize,
                            h,
                            w,
                            base_feat_height / base_feat_height_reduction,
                            base_feat_width / base_feat_height_reduction,
                            (),
                        )))?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((
                            1,
                            h * base_feat_height / base_feat_height_reduction,
                            w * base_feat_width / base_feat_height_reduction,
                            base_feat_height_reduction * base_feat_height_reduction * C,
                        ))?;

                    let (temp_sub_GN, temp_len) = if let Some(image_attention_mask) =
                        image_attention_mask
                    {
                        // Downsample the attention mask to the merged-token
                        // grid and crop away fully padded rows/columns.
                        let h_indices = Tensor::arange_step(
                            0,
                            image_attention_mask.dim(2)? as u32,
                            2,
                            &target_dev,
                        )?;
                        let w_indices = Tensor::arange_step(
                            0,
                            image_attention_mask.dim(3)? as u32,
                            2,
                            &target_dev,
                        )?;

                        let reshaped_image_attention_mask = {
                            let mut selected = image_attention_mask.i((bs_, 1..B_ + 1))?;
                            selected = selected.index_select(&h_indices, 1)?;
                            selected = selected.index_select(&w_indices, 2)?;
                            selected
                                .reshape((
                                    1,
                                    h,
                                    w,
                                    base_feat_height / base_feat_height_reduction,
                                    base_feat_width / base_feat_height_reduction,
                                ))?
                                .permute((0, 1, 3, 2, 4))?
                                .reshape((
                                    1,
                                    h * base_feat_height / base_feat_height_reduction,
                                    w * base_feat_width / base_feat_height_reduction,
                                ))?
                        };

                        let useful_height = reshaped_image_attention_mask
                            .i((0, .., 0))?
                            .sum_all()?
                            .to_scalar::<u32>()?;
                        let useful_width = reshaped_image_attention_mask
                            .i((0, 0, ..))?
                            .sum_all()?
                            .to_scalar::<u32>()?;

                        sub_img =
                            sub_img.i((.., ..useful_height as usize, ..useful_width as usize))?;

                        // Expected token count: unmasked merged tokens, plus
                        // one separator per sub-image row, the global
                        // separator, and the global view's row separators.
                        let temp_len = {
                            let mut selected = image_attention_mask.i((bs_, ..B_ + 1))?;
                            selected = selected.index_select(&h_indices, 1)?;
                            selected = selected.index_select(&w_indices, 2)?;
                            selected.sum_all()?.to_scalar::<u32>()?
                        };
                        let temp_len = temp_len as usize
                            + useful_height as usize
                            + 1
                            + base_feat_height / base_feat_height_reduction;

                        (
                            self.sub_gn
                                .as_ref()
                                .expect("Need `sub_gn` if `use_hd_transform`")
                                .repeat((1, useful_height as usize, 1, 1))?,
                            temp_len,
                        )
                    } else {
                        let temp_len = (h * w + 1) * self.num_img_tokens
                            + 1
                            + (h + 1) * base_feat_height / base_feat_height_reduction;

                        (
                            self.sub_gn
                                .as_ref()
                                .expect("Need `sub_gn` if `use_hd_transform`")
                                .repeat((
                                    1,
                                    h * base_feat_height / base_feat_height_reduction,
                                    1,
                                    1,
                                ))?,
                            temp_len,
                        )
                    };


                    let sub_img = Tensor::cat(&[sub_img, temp_sub_GN], 2)?.reshape((
                        1,
                        (),
                        base_feat_height_reduction * base_feat_height_reduction * C,
                    ))?;

                    // Assemble global and sub-image tokens around the
                    // `glb_gn` separator in the configured order.
                    match self.hd_transform_order.as_str() {
                        "glb_sub" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    glb_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    sub_img,
                                ],
                                1,
                            )?);
                        }
                        "sub_glb" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    sub_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    glb_img,
                                ],
                                1,
                            )?);
                        }
                        other => {
                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
                        }
                    }

                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
                }


                // Project each assembled image sequence into the language
                // model's hidden size.
                let mut image_set_tensor_inner = Vec::new();
                for img in output_imgs {
                    let layerout = self
                        .layers
                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
                    image_set_tensor_inner.push(layerout);
                }
                image_set_tensor = Some(image_set_tensor_inner);
            } else {
                // Only the HD-transform path is implemented.
                unreachable!()
            }
        }


        let mut hidden_states = self.wte.forward(&input_ids)?;
        if select && self.use_hd_transform {
            match image_set_tensor {
                Some(image_set_tensors) => {
                    let merged_img_set_tensor = Tensor::cat(&image_set_tensors, 1)?.squeeze(0)?;

                    let original_shape = hidden_states.shape().clone();
                    let (_hs_b, hs_l, hs_d) = hidden_states.dims3()?;
                    let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;

                    // Convert (batch, seq) positions into flat row indices:
                    // batch * seq_len + seq.
                    let positions_transposed = positions.to_dtype(DType::F32)?;
                    let positions_transposed_0 = positions_transposed.i((.., 0))?;
                    let positions_transposed_1 = positions_transposed.i((.., 1))?;

                    let mut linear_index =
                        ((positions_transposed_0 * hs_l as f64)? + positions_transposed_1)?;
                    linear_index = linear_index.to_dtype(DType::U32)?;
                    linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;

                    // Overwrite the placeholder rows: add (image - current) so
                    // that scatter_add leaves exactly the image features.
                    let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
                    let delta = merged_img_set_tensor.broadcast_sub(&current_vals)?;

                    hidden_states_flat =
                        hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;

                    hidden_states = hidden_states_flat.reshape(original_shape)?;
                }
                _ => unreachable!(),
            }
        }

        Ok(hidden_states)
    }

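    /// Collect the non-quantized tensors owned by this module (separator
    /// tokens, projection weights, and the vision tower) for serialization.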
    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        if let Some(glb_gn) = self.glb_gn.clone() {
            uvb.add_tensor("glb_GN", glb_gn);
        }
        if let Some(sub_gn) = self.sub_gn.clone() {
            uvb.add_tensor("sub_GN", sub_gn);
        }
        uvb.extend(self.tensors.clone());
        uvb.pp("img_processor.vision_model")
            .extend(self.image_processor.residual_tensors());

        uvb.to_safetensors()
    }
}