1use std::{
2 fmt::Debug,
3 sync::{Arc, LazyLock},
4};
5
6use candle_core::{shape::ShapeWithOneHole, DType, Device, IndexOp, Result, Shape, Tensor, D};
7use candle_nn::Module;
8use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
9
10use crate::{
11 layers::{AvgPool2d, ReflectionPad2d},
12 ops::NonZeroOp,
13 utils::unvarbuilder::UnVarBuilder,
14 vision_models::{
15 phi4::config::Phi4MMImgProcessorConfig,
16 siglip::{SiglipVisionConfig, SiglipVisionTransformer},
17 },
18};
19
20use super::{config::Phi4MMImageEmbedConfig, Phi4MMConfig};
21
22pub(super) const IMAGE_SPECIAL_TOKEN_ID: f64 = 200010.;
23
24trait ModuleWithMetadata: Module + Debug + Send + Sync {
25 fn device(&self) -> Device;
26 fn dtype(&self) -> DType;
27}
28
29#[derive(Debug)]
30struct QuantMethodWrapper(Arc<dyn QuantMethod>);
31
32impl Module for QuantMethodWrapper {
33 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
34 self.0.forward(xs)
35 }
36}
37
38impl ModuleWithMetadata for QuantMethodWrapper {
39 fn device(&self) -> Device {
40 self.0.unquant_weight_bias().unwrap().0.device().clone()
41 }
42 fn dtype(&self) -> DType {
43 self.0.unquant_weight_bias().unwrap().0.dtype()
44 }
45}
46
47impl ModuleWithMetadata for candle_nn::Activation {
48 fn device(&self) -> Device {
49 unreachable!()
50 }
51 fn dtype(&self) -> DType {
52 unreachable!()
53 }
54}
55
56#[derive(Debug)]
57struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
58
59fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
60 if prod_d == 0 {
61 candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
62 }
63 if el_count % prod_d != 0 {
64 candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
65 }
66 Ok(el_count / prod_d)
67}
68
69impl ShapeWithOneHole for BigShapeWithOneHole {
70 fn into_shape(self, el_count: usize) -> Result<Shape> {
71 let (d1, d2, d3, d4, d5, ()) = self.0;
72 let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
73 Ok((d1, d2, d3, d4, d5, d).into())
74 }
75}
76
77#[derive(Debug)]
78struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
79
80impl Module for EmbeddingLayers {
81 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
82 let mut xs = xs.clone();
83 for layer in &self.0 {
84 xs = layer.forward(&xs)?;
85 }
86 Ok(xs)
87 }
88}
89
90pub(crate) static PHI4_MM_VISION_CFG: LazyLock<SiglipVisionConfig> =
91 LazyLock::new(|| SiglipVisionConfig {
92 hidden_size: 1152,
93 image_size: 448,
94 intermediate_size: 4304,
95 num_attention_heads: 16,
96 num_hidden_layers: 27,
97 patch_size: 14,
98 ..Default::default()
99 });
100
/// Phi-4 multimodal image embedding.
///
/// Runs pixel crops through a SigLIP vision tower, lays the features out with the HD
/// transform (global view plus tiled sub-crops with learnable separators), projects them
/// to the text hidden size and writes them into the token embeddings at the
/// image-placeholder positions.
pub struct ImageEmbedding {
    /// Token embedding table used to embed `input_ids`.
    wte: candle_nn::Embedding,
    /// Feature width of the vision tower output (the position-embedding width).
    image_dim_out: usize,
    /// `(m / 2)^2`, where `m` is the (possibly padded) side of the patch grid.
    num_img_tokens: usize,
    /// Learnable separator placed between the global and sub-crop features (HD transform only).
    glb_gn: Option<Tensor>,
    /// Learnable separator appended at the end of every feature row (HD transform only).
    sub_gn: Option<Tensor>,
    /// Projection stack (linear, or linear-GELU-linear) into the text hidden size.
    layers: EmbeddingLayers,
    /// Which vision features to use: `"patch"` or `"cls_patch"`.
    type_feature: String,
    /// Index of the vision-tower hidden state to use; must be negative (counted from the end).
    layer_idx: isize,
    /// SigLIP vision tower.
    image_processor: SiglipVisionTransformer,
    /// `"glb_sub"` or `"sub_glb"`: whether the global features come before or after the sub-crops.
    hd_transform_order: String,
    /// Whether the HD transform layout is used.
    use_hd_transform: bool,
    /// Unquantized projection weights/biases, re-exported by `residual_tensors`.
    tensors: Vec<(String, Tensor)>,
    /// Reflection padding applied when the patch-grid side is odd.
    img_processor_padding: Option<ReflectionPad2d>,
    /// Size that image sizes are divided by to obtain the crop grid.
    crop_size: usize,
    /// Optional average pooling over the patch grid (`AvgPool2d::new(2, 2)`).
    image_token_compression: Option<AvgPool2d>,
    /// Side of the square neighbourhood merged into one token by the HD transform.
    base_feat_height_reduction: usize,
    /// Expected side length of the per-crop feature grid, checked in `forward` when present.
    base_feat_height_target: Option<usize>,
}
120
121impl ImageEmbedding {
122 pub fn new(
123 cfg: &Phi4MMConfig,
124 img_embd_config: &Phi4MMImageEmbedConfig,
125 wte: candle_nn::Embedding,
126 vb: ShardedVarBuilder,
127 ) -> Result<Self> {
128 let hidden_size = img_embd_config.n_embd.unwrap_or(cfg.hidden_size);
129
130 let siglip_vision_config = &PHI4_MM_VISION_CFG;
131 let image_processor =
132 SiglipVisionTransformer::new(siglip_vision_config, vb.pp("img_processor"))?;
133
134 let pe_weight = image_processor.embeddings.position_embedding.embeddings();
135 let (l, d) = pe_weight.dims2()?;
136 let mut m = (l as f64).sqrt() as usize;
137 assert_eq!(m.pow(2), l);
138 let img_processor_padding = if m % 2 != 0 {
139 m += 1;
140 Some(ReflectionPad2d::new((0, 1, 0, 1)))
141 } else {
142 None
143 };
144 let image_dim_out = d;
145 let num_img_tokens = (m / 2).pow(2);
146 let base_feat_height_target = m;
147
148 let use_hd_transform = img_embd_config.use_hd_transform.unwrap_or(false);
150 let with_learnable_separator = img_embd_config.with_learnable_separator.unwrap_or(false);
151 let hd_transform_order = img_embd_config
152 .hd_transform_order
153 .clone()
154 .unwrap_or("glb_sub".to_string());
155 let crop_size = img_embd_config.crop_size.unwrap_or(336);
156
157 let (image_token_compression, base_feat_height_reduction, base_feat_height_target) =
158 match &img_embd_config.image_token_compression_cls {
159 Some(x) if x == "avg_pool_2d" => (
160 Some(AvgPool2d::new(2, 2)),
161 1_usize,
162 Some(base_feat_height_target / 2),
163 ),
164 None => (None, 2_usize, None),
165 _ => candle_core::bail!("Unexpected image_token_compression_cls"),
166 };
167
168 assert_eq!(use_hd_transform, with_learnable_separator);
169 let (glb_gn, sub_gn) = if with_learnable_separator {
170 let glb_gn = vb.get(
171 (1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
172 "glb_GN",
173 )?;
174 let sub_gn = vb.get(
175 (1, 1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
176 "sub_GN",
177 )?;
178 (Some(glb_gn), Some(sub_gn))
179 } else {
180 (None, None)
181 };
182
183 let projection_cls = img_embd_config
185 .projection_cls
186 .clone()
187 .unwrap_or("linear".to_string());
188
189 let mut tensors = Vec::new();
190 let layers: Vec<Box<dyn ModuleWithMetadata>> =
191 match (projection_cls.as_str(), use_hd_transform) {
192 ("linear", _) => {
193 let a = mistralrs_quant::linear_b(
194 image_dim_out,
195 hidden_size,
196 true,
197 &None,
198 vb.pp("img_projection"),
199 )?;
200 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
201 tensors.push(("img_projection.weight".to_string(), a_w));
202 if let Some(b) = a_b {
203 tensors.push(("img_projection.bias".to_string(), b));
204 }
205 vec![Box::new(QuantMethodWrapper(a))]
206 }
207 ("mlp", true) => {
208 let dim_proj = hidden_size;
209 let a = mistralrs_quant::linear_b(
210 image_dim_out * base_feat_height_reduction.pow(2),
211 dim_proj,
212 true,
213 &None,
214 vb.pp("img_projection.0"),
215 )?;
216 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
217 tensors.push(("img_projection.0.weight".to_string(), a_w));
218 if let Some(b) = a_b {
219 tensors.push(("img_projection.0.bias".to_string(), b));
220 }
221 let b = mistralrs_quant::linear_b(
222 dim_proj,
223 dim_proj,
224 true,
225 &None,
226 vb.pp("img_projection.2"),
227 )?;
228 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
229 tensors.push(("img_projection.2.weight".to_string(), b_w));
230 if let Some(b) = b_b {
231 tensors.push(("img_projection.2.bias".to_string(), b));
232 }
233 vec![
234 Box::new(QuantMethodWrapper(a)),
235 Box::new(candle_nn::Activation::Gelu),
236 Box::new(QuantMethodWrapper(b)),
237 ]
238 }
239 ("mlp", false) => {
240 let dim_proj = hidden_size;
241 let a = mistralrs_quant::linear_b(
242 image_dim_out,
243 dim_proj,
244 true,
245 &None,
246 vb.pp("img_projection.0"),
247 )?;
248 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
249 tensors.push(("img_projection.0.weight".to_string(), a_w));
250 if let Some(b) = a_b {
251 tensors.push(("img_projection.0.bias".to_string(), b));
252 }
253 let b = mistralrs_quant::linear_b(
254 dim_proj,
255 dim_proj,
256 true,
257 &None,
258 vb.pp("img_projection.2"),
259 )?;
260 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
261 tensors.push(("img_projection.2.weight".to_string(), b_w));
262 if let Some(b) = b_b {
263 tensors.push(("img_projection.2.bias".to_string(), b));
264 }
265 vec![
266 Box::new(QuantMethodWrapper(a)),
267 Box::new(candle_nn::Activation::Gelu),
268 Box::new(QuantMethodWrapper(b)),
269 ]
270 }
271 _ => {
272 candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
273 }
274 };
275
276 let (layer_idx, type_feature) = match &cfg.img_processor {
277 Some(Phi4MMImgProcessorConfig {
278 layer_idx,
279 type_feature,
280 }) => (
281 layer_idx.unwrap_or(-2),
282 type_feature.clone().unwrap_or("patch".to_string()),
283 ),
284
285 None => (-2, "patch".to_string()),
286 };
287
288 Ok(Self {
289 wte,
290 image_dim_out,
291 num_img_tokens,
292 glb_gn,
293 sub_gn,
294 layer_idx,
295 type_feature,
296 image_processor,
297 layers: EmbeddingLayers(layers),
298 hd_transform_order,
299 use_hd_transform,
300 tensors,
301 img_processor_padding,
302 crop_size,
303 image_token_compression,
304 base_feat_height_reduction,
305 base_feat_height_target,
306 })
307 }
308
309 fn get_image_features(
310 &self,
311 img_embeds: &Tensor,
312 attention_mask: Option<&Tensor>,
313 ) -> Result<Tensor> {
314 assert!(self.layer_idx < 0);
315 let img_feature = self.image_processor.forward_get_hidden_states(
316 &img_embeds.to_dtype(self.image_processor.dtype())?,
317 attention_mask,
318 None,
319 self.layer_idx,
320 )?;
321
322 if self.type_feature == "patch" {
323 let mut patch_feature = img_feature.clone();
324 if let Some(image_token_compression) = &self.image_token_compression {
325 let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
327 patch_feature =
328 patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
329 patch_feature = patch_feature.permute((0, 3, 1, 2))?;
331 if let Some(img_processor_padding) = &self.img_processor_padding {
332 patch_feature = patch_feature.apply(img_processor_padding)?;
333 }
334 patch_feature = image_token_compression.forward(&patch_feature)?;
335 patch_feature = patch_feature.permute((0, 2, 3, 1))?;
337 patch_feature = patch_feature.reshape((
338 (),
339 patch_feature.dim(1)? * patch_feature.dim(2)?,
340 patch_feature.dim(D::Minus1)?,
341 ))?;
342 } else if let Some(img_processor_padding) = &self.img_processor_padding {
343 let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
345 patch_feature =
346 patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
347 patch_feature = patch_feature.permute((0, 3, 1, 2))?;
349 patch_feature = patch_feature.apply(img_processor_padding)?;
350 patch_feature = patch_feature.permute((0, 2, 3, 1))?;
352 patch_feature = patch_feature.reshape((
353 (),
354 patch_feature.dim(1)? * patch_feature.dim(2)?,
355 patch_feature.dim(D::Minus1)?,
356 ))?;
357 };
358 Ok(patch_feature)
359 } else if self.type_feature == "cls_patch" {
360 let mut img_feature = img_feature.clone();
361 if let Some(image_token_compression) = &self.image_token_compression {
362 let mut patch_feature = img_feature.i((.., 1..))?;
364 let cls_feature = img_feature.i((.., 0))?;
365 let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
366 patch_feature =
367 patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
368 patch_feature = image_token_compression.forward(&patch_feature)?;
369 patch_feature = patch_feature.reshape((
370 (),
371 patch_feature.dim(D::Minus2)? * patch_feature.dim(D::Minus1)?,
372 ))?;
373 img_feature = Tensor::cat(&[cls_feature, patch_feature], 1)?;
374 }
375 Ok(img_feature)
376 } else {
377 candle_core::bail!("Unsupported image feature type {}", self.type_feature)
378 }
379 }
380
381 #[allow(non_snake_case)]
382 pub fn forward(
383 &self,
384 input_ids: &Tensor,
385 input_embeds: &Tensor,
386 image_attention_mask: Option<&Tensor>,
387 image_sizes: Option<Vec<(u32, u32)>>,
388 ) -> Result<Tensor> {
389 let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
390
391 let positions = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?.nonzero()?;
392
393 let target_dev = self.layers.0[0].device();
394 let target_dtype = self.layers.0[0].dtype();
395
396 let mut select = false;
397 let mut image_set_tensor = None;
398 if positions.dim(0)? > 0 {
399 select = true;
400
401 if self.use_hd_transform && image_sizes.is_some() {
402 assert_eq!(input_embeds.dims().len(), 5);
403 let bs = input_embeds.dim(0)?;
404 let img_features = match image_attention_mask {
405 Some(attn_mask) => self.get_image_features(
406 &input_embeds.flatten(0, 1)?,
407 Some(&attn_mask.ne(0.)?.flatten(0, 1)?),
408 )?,
409 None => self.get_image_features(input_embeds, None)?,
410 };
411
412 let base_feat_height_target = self.base_feat_height_target.unwrap();
413 let base_resolution = self.crop_size;
414 let base_feat_height_reduction = self.base_feat_height_reduction;
415
416 let base_feat_height = (img_features.dim(1)? as f64).sqrt() as usize;
417 let base_feat_width = base_feat_height;
418
419 assert_eq!(base_feat_height, base_feat_height_target);
420 assert_eq!(base_feat_width, base_feat_height_target);
421
422 let img_features = img_features.reshape((
423 bs,
424 (),
425 base_feat_height * base_feat_width,
426 self.image_dim_out,
427 ))?;
428 let C = self.image_dim_out;
429 let H = base_feat_height;
430
431 let mut output_imgs = Vec::new();
432 for bs_ in 0..bs {
433 let (h, w) = image_sizes.as_ref().unwrap()[bs_];
434 let h = h as usize / base_resolution;
435 let w = w as usize / base_resolution;
436 let B_ = h * w;
437
438 let global_img_feature = img_features.i((bs_, ..1))?;
440
441 let glb_img = global_img_feature
443 .reshape((1, H, H, C))?
444 .reshape((
445 1,
446 H / base_feat_height_reduction,
447 base_feat_height_reduction,
448 H / base_feat_height_reduction,
449 base_feat_height_reduction,
450 C,
451 ))?
452 .contiguous()?
453 .permute((0, 1, 3, 2, 4, 5))?
454 .reshape((
455 1,
456 H / base_feat_height_reduction,
457 H / base_feat_height_reduction,
458 base_feat_height_reduction * base_feat_height_reduction * C,
459 ))?
460 .contiguous()?;
461 let temp_glbl_gn = self
462 .sub_gn
463 .as_ref()
464 .expect("Need `sub_gn` if `use_hd_transform`")
465 .repeat((1, H / base_feat_height_reduction, 1, 1))?;
466
467 let glb_img = Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((
469 1,
470 (),
471 base_feat_height_reduction * base_feat_height_reduction * C,
472 ))?;
473
474 let mut sub_img = img_features.i((bs_, 1..))?;
476
477 sub_img = sub_img.i(..B_)?;
480
481 sub_img = sub_img
483 .reshape((B_, H, H, C))?
484 .reshape((
485 B_,
486 H / base_feat_height_reduction,
487 base_feat_height_reduction,
488 H / base_feat_height_reduction,
489 base_feat_height_reduction,
490 C,
491 ))?
492 .contiguous()?
493 .permute((0, 1, 3, 2, 4, 5))?
494 .reshape((
495 B_,
496 (),
497 base_feat_height_reduction * base_feat_height_reduction * C,
498 ))?
499 .contiguous()?;
500 sub_img = sub_img
501 .reshape(BigShapeWithOneHole((
502 1usize,
503 h,
504 w,
505 base_feat_height / base_feat_height_reduction,
506 base_feat_width / base_feat_height_reduction,
507 (),
508 )))?
509 .permute((0, 1, 3, 2, 4, 5))?
510 .reshape((
511 1,
512 h * base_feat_height / base_feat_height_reduction,
513 w * base_feat_width / base_feat_height_reduction,
514 base_feat_height_reduction * base_feat_height_reduction * C,
515 ))?;
516
517 let (temp_sub_GN, temp_len) = if let Some(image_attention_mask) =
518 image_attention_mask
519 {
520 let h_indices = Tensor::arange_step(
521 0,
522 image_attention_mask.dim(2)? as u32,
523 2,
524 &target_dev,
525 )?;
526 let w_indices = Tensor::arange_step(
527 0,
528 image_attention_mask.dim(3)? as u32,
529 2,
530 &target_dev,
531 )?;
532
533 let reshaped_image_attention_mask = {
534 let mut selected = image_attention_mask.i((bs_, 1..B_ + 1))?;
535 selected = selected.index_select(&h_indices, 1)?;
536 selected = selected.index_select(&w_indices, 2)?;
537 selected
538 .reshape((
539 1,
540 h,
541 w,
542 base_feat_height / base_feat_height_reduction,
543 base_feat_width / base_feat_height_reduction,
544 ))?
545 .permute((0, 1, 3, 2, 4))?
546 .reshape((
547 1,
548 h * base_feat_height / base_feat_height_reduction,
549 w * base_feat_width / base_feat_height_reduction,
550 ))?
551 };
552
553 let useful_height = reshaped_image_attention_mask
554 .i((0, .., 0))?
555 .sum_all()?
556 .to_scalar::<u32>()?;
557 let useful_width = reshaped_image_attention_mask
558 .i((0, 0, ..))?
559 .sum_all()?
560 .to_scalar::<u32>()?;
561
562 sub_img =
563 sub_img.i((.., ..useful_height as usize, ..useful_width as usize))?;
564
565 let temp_len = {
566 let mut selected = image_attention_mask.i((bs_, ..B_ + 1))?;
567 selected = selected.index_select(&h_indices, 1)?;
568 selected = selected.index_select(&w_indices, 2)?;
569 selected.sum_all()?.to_scalar::<u32>()?
570 };
571 let temp_len = temp_len as usize
572 + useful_height as usize
573 + 1
574 + base_feat_height / base_feat_height_reduction;
575
576 (
577 self.sub_gn
578 .as_ref()
579 .expect("Need `sub_gn` if `use_hd_transform`")
580 .repeat((1, useful_height as usize, 1, 1))?,
581 temp_len,
582 )
583 } else {
584 let temp_len = (h * w + 1) * self.num_img_tokens
585 + 1
586 + (h + 1) * base_feat_height / base_feat_height_reduction;
587
588 (
589 self.sub_gn
590 .as_ref()
591 .expect("Need `sub_gn` if `use_hd_transform`")
592 .repeat((
593 1,
594 h * base_feat_height / base_feat_height_reduction,
595 1,
596 1,
597 ))?,
598 temp_len,
599 )
600 };
601
602 let sub_img = Tensor::cat(&[sub_img, temp_sub_GN], 2)?.reshape((
603 1,
604 (),
605 base_feat_height_reduction * base_feat_height_reduction * C,
606 ))?;
607
608 match self.hd_transform_order.as_str() {
612 "glb_sub" => {
613 output_imgs.push(Tensor::cat(
614 &[
615 glb_img,
616 self.glb_gn
617 .as_ref()
618 .expect("Need `glb_gn` if `use_hd_transform`")
619 .clone(),
620 sub_img,
621 ],
622 1,
623 )?);
624 }
625 "sub_glb" => {
626 output_imgs.push(Tensor::cat(
627 &[
628 sub_img,
629 self.glb_gn
630 .as_ref()
631 .expect("Need `glb_gn` if `use_hd_transform`")
632 .clone(),
633 glb_img,
634 ],
635 1,
636 )?);
637 }
638 other => {
639 candle_core::bail!("Invalid hd_transform_order=`{other}`");
640 }
641 }
642
643 assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
645 }
646
647 let mut image_set_tensor_inner = Vec::new();
648 for img in output_imgs {
649 let layerout = self
650 .layers
651 .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
652 image_set_tensor_inner.push(layerout);
653 }
654 image_set_tensor = Some(image_set_tensor_inner);
655 } else {
656 unreachable!()
657 }
658 }
659
660 let mut hidden_states = self.wte.forward(&input_ids)?;
661 if select && self.use_hd_transform {
662 match image_set_tensor {
663 Some(image_set_tensors) => {
664 let merged_img_set_tensor = Tensor::cat(&image_set_tensors, 1)?.squeeze(0)?;
665
666 let original_shape = hidden_states.shape().clone();
667 let (hs_b, hs_l, hs_d) = hidden_states.dims3()?;
668 let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;
669
670 let positions_transposed = positions.to_dtype(DType::F32)?;
672 let positions_transposed_0 = positions_transposed.i((.., 0))?;
673 let positions_transposed_1 = positions_transposed.i((.., 1))?;
674
675 let mut linear_index = ((positions_transposed_0 * (hs_l * hs_b) as f64)?
676 + positions_transposed_1)?;
677 linear_index = linear_index.to_dtype(DType::U32)?;
678 linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;
679
680 let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
681 let delta = merged_img_set_tensor.broadcast_sub(¤t_vals)?;
682
683 hidden_states_flat =
684 hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;
685
686 hidden_states = hidden_states_flat.reshape(original_shape)?;
687 }
688 _ => unreachable!(),
689 }
690 }
691
692 Ok(hidden_states)
693 }
694
695 pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
696 let uvb = UnVarBuilder::new();
697
698 if let Some(glb_gn) = self.glb_gn.clone() {
699 uvb.add_tensor("glb_GN", glb_gn);
700 }
701 if let Some(sub_gn) = self.sub_gn.clone() {
702 uvb.add_tensor("sub_GN", sub_gn);
703 }
704 uvb.extend(self.tensors.clone());
705 uvb.pp("img_processor.vision_model")
706 .extend(self.image_processor.residual_tensors());
707
708 uvb.to_safetensors()
709 }
710}