mistralrs_core/vision_models/phi4/
image_embedding.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, LazyLock},
4};
5
6use candle_core::{shape::ShapeWithOneHole, DType, Device, IndexOp, Result, Shape, Tensor, D};
7use candle_nn::Module;
8use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
9
10use crate::{
11    layers::{AvgPool2d, ReflectionPad2d},
12    utils::unvarbuilder::UnVarBuilder,
13    vision_models::{
14        phi4::config::Phi4MMImgProcessorConfig,
15        siglip::{SiglipVisionConfig, SiglipVisionTransformer},
16    },
17};
18
19use super::{config::Phi4MMImageEmbedConfig, Phi4MMConfig};
20
21pub(super) const IMAGE_SPECIAL_TOKEN_ID: f64 = 200010.;
22
23trait ModuleWithMetadata: Module + Debug + Send + Sync {
24    fn device(&self) -> Device;
25    fn dtype(&self) -> DType;
26}
27
28#[derive(Debug)]
29struct QuantMethodWrapper(Arc<dyn QuantMethod>);
30
31impl Module for QuantMethodWrapper {
32    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
33        self.0.forward(xs)
34    }
35}
36
37impl ModuleWithMetadata for QuantMethodWrapper {
38    fn device(&self) -> Device {
39        self.0.unquant_weight_bias().unwrap().0.device().clone()
40    }
41    fn dtype(&self) -> DType {
42        self.0.unquant_weight_bias().unwrap().0.dtype()
43    }
44}
45
46impl ModuleWithMetadata for candle_nn::Activation {
47    fn device(&self) -> Device {
48        unreachable!()
49    }
50    fn dtype(&self) -> DType {
51        unreachable!()
52    }
53}
54
55#[derive(Debug)]
56struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
57
58fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
59    if prod_d == 0 {
60        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
61    }
62    if el_count % prod_d != 0 {
63        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
64    }
65    Ok(el_count / prod_d)
66}
67
68impl ShapeWithOneHole for BigShapeWithOneHole {
69    fn into_shape(self, el_count: usize) -> Result<Shape> {
70        let (d1, d2, d3, d4, d5, ()) = self.0;
71        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
72        Ok((d1, d2, d3, d4, d5, d).into())
73    }
74}
75
76#[derive(Debug)]
77struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
78
79impl Module for EmbeddingLayers {
80    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
81        let mut xs = xs.clone();
82        for layer in &self.0 {
83            xs = layer.forward(&xs)?;
84        }
85        Ok(xs)
86    }
87}
88
89pub(crate) static PHI4_MM_VISION_CFG: LazyLock<SiglipVisionConfig> =
90    LazyLock::new(|| SiglipVisionConfig {
91        hidden_size: 1152,
92        image_size: 448,
93        intermediate_size: 4304,
94        num_attention_heads: 16,
95        num_hidden_layers: 27,
96        patch_size: 14,
97        ..Default::default()
98    });
99
/// Image embedding module: runs image crops through the vision tower,
/// rearranges and projects the resulting features, and writes them into the
/// token embeddings at the image placeholder positions.
pub struct ImageEmbedding {
    /// Token embedding table applied to `input_ids`.
    wte: candle_nn::Embedding,
    /// Channel width of the vision features (second dimension of the vision
    /// tower's position-embedding weight).
    image_dim_out: usize,
    /// `(m / 2)^2`, where `m` is the (possibly padded) side length of the
    /// position-embedding grid.
    num_img_tokens: usize,
    /// Learnable separator loaded from `glb_GN`; present only when learnable
    /// separators are enabled.
    glb_gn: Option<Tensor>,
    /// Learnable separator loaded from `sub_GN`; present only when learnable
    /// separators are enabled.
    sub_gn: Option<Tensor>,
    /// Projection layers applied to the assembled image features.
    layers: EmbeddingLayers,
    /// Which vision features to use: `"patch"` or `"cls_patch"`.
    type_feature: String,
    /// Hidden-state index requested from the vision tower; asserted to be
    /// negative before use.
    layer_idx: isize,
    /// SigLIP vision tower.
    image_processor: SiglipVisionTransformer,
    /// Concatenation order of global and sub-image features: `"glb_sub"` or
    /// `"sub_glb"`.
    hd_transform_order: String,
    /// Whether the HD transform path is enabled.
    use_hd_transform: bool,
    /// Named projection weights/biases collected at construction time and
    /// returned by `residual_tensors`.
    tensors: Vec<(String, Tensor)>,
    /// Padding applied to the feature grid when the position-embedding grid
    /// side is odd.
    img_processor_padding: Option<ReflectionPad2d>,
    /// Pixel size of one crop; image sizes are divided by this to get the
    /// crop grid.
    crop_size: usize,
    /// Optional pooling applied to the patch-feature grid.
    image_token_compression: Option<AvgPool2d>,
    /// Side length of the block of neighbouring features merged into one
    /// token by the HD transform.
    base_feat_height_reduction: usize,
    /// Expected side length of the per-crop feature grid, checked in `forward`.
    base_feat_height_target: Option<usize>,
}
119
120impl ImageEmbedding {
121    pub fn new(
122        cfg: &Phi4MMConfig,
123        img_embd_config: &Phi4MMImageEmbedConfig,
124        wte: candle_nn::Embedding,
125        vb: ShardedVarBuilder,
126    ) -> Result<Self> {
127        let hidden_size = img_embd_config.n_embd.unwrap_or(cfg.hidden_size);
128
129        let siglip_vision_config = &PHI4_MM_VISION_CFG;
130        let image_processor =
131            SiglipVisionTransformer::new(siglip_vision_config, vb.pp("img_processor"))?;
132
133        let pe_weight = image_processor.embeddings.position_embedding.embeddings();
134        let (l, d) = pe_weight.dims2()?;
135        let mut m = (l as f64).sqrt() as usize;
136        assert_eq!(m.pow(2), l);
137        let img_processor_padding = if m % 2 != 0 {
138            m += 1;
139            Some(ReflectionPad2d::new((0, 1, 0, 1)))
140        } else {
141            None
142        };
143        let image_dim_out = d;
144        let num_img_tokens = (m / 2).pow(2);
145        let base_feat_height_target = m;
146
147        // High dim transform
148        let use_hd_transform = img_embd_config.use_hd_transform.unwrap_or(false);
149        let with_learnable_separator = img_embd_config.with_learnable_separator.unwrap_or(false);
150        let hd_transform_order = img_embd_config
151            .hd_transform_order
152            .clone()
153            .unwrap_or("glb_sub".to_string());
154        let crop_size = img_embd_config.crop_size.unwrap_or(336);
155
156        let (image_token_compression, base_feat_height_reduction, base_feat_height_target) =
157            match &img_embd_config.image_token_compression_cls {
158                Some(x) if x == "avg_pool_2d" => (
159                    Some(AvgPool2d::new(2, 2)),
160                    1_usize,
161                    Some(base_feat_height_target / 2),
162                ),
163                None => (None, 2_usize, None),
164                _ => candle_core::bail!("Unexpected image_token_compression_cls"),
165            };
166
167        assert_eq!(use_hd_transform, with_learnable_separator);
168        let (glb_gn, sub_gn) = if with_learnable_separator {
169            let glb_gn = vb.get(
170                (1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
171                "glb_GN",
172            )?;
173            let sub_gn = vb.get(
174                (1, 1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
175                "sub_GN",
176            )?;
177            (Some(glb_gn), Some(sub_gn))
178        } else {
179            (None, None)
180        };
181
182        // Inner projection
183        let projection_cls = img_embd_config
184            .projection_cls
185            .clone()
186            .unwrap_or("linear".to_string());
187
188        let mut tensors = Vec::new();
189        let layers: Vec<Box<dyn ModuleWithMetadata>> =
190            match (projection_cls.as_str(), use_hd_transform) {
191                ("linear", _) => {
192                    let a = mistralrs_quant::linear_b(
193                        image_dim_out,
194                        hidden_size,
195                        true,
196                        &None,
197                        vb.pp("img_projection"),
198                    )?;
199                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
200                    tensors.push(("img_projection.weight".to_string(), a_w));
201                    if let Some(b) = a_b {
202                        tensors.push(("img_projection.bias".to_string(), b));
203                    }
204                    vec![Box::new(QuantMethodWrapper(a))]
205                }
206                ("mlp", true) => {
207                    let dim_proj = hidden_size;
208                    let a = mistralrs_quant::linear_b(
209                        image_dim_out * base_feat_height_reduction.pow(2),
210                        dim_proj,
211                        true,
212                        &None,
213                        vb.pp("img_projection.0"),
214                    )?;
215                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
216                    tensors.push(("img_projection.0.weight".to_string(), a_w));
217                    if let Some(b) = a_b {
218                        tensors.push(("img_projection.0.bias".to_string(), b));
219                    }
220                    let b = mistralrs_quant::linear_b(
221                        dim_proj,
222                        dim_proj,
223                        true,
224                        &None,
225                        vb.pp("img_projection.2"),
226                    )?;
227                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
228                    tensors.push(("img_projection.2.weight".to_string(), b_w));
229                    if let Some(b) = b_b {
230                        tensors.push(("img_projection.2.bias".to_string(), b));
231                    }
232                    vec![
233                        Box::new(QuantMethodWrapper(a)),
234                        Box::new(candle_nn::Activation::Gelu),
235                        Box::new(QuantMethodWrapper(b)),
236                    ]
237                }
238                ("mlp", false) => {
239                    let dim_proj = hidden_size;
240                    let a = mistralrs_quant::linear_b(
241                        image_dim_out,
242                        dim_proj,
243                        true,
244                        &None,
245                        vb.pp("img_projection.0"),
246                    )?;
247                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
248                    tensors.push(("img_projection.0.weight".to_string(), a_w));
249                    if let Some(b) = a_b {
250                        tensors.push(("img_projection.0.bias".to_string(), b));
251                    }
252                    let b = mistralrs_quant::linear_b(
253                        dim_proj,
254                        dim_proj,
255                        true,
256                        &None,
257                        vb.pp("img_projection.2"),
258                    )?;
259                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
260                    tensors.push(("img_projection.2.weight".to_string(), b_w));
261                    if let Some(b) = b_b {
262                        tensors.push(("img_projection.2.bias".to_string(), b));
263                    }
264                    vec![
265                        Box::new(QuantMethodWrapper(a)),
266                        Box::new(candle_nn::Activation::Gelu),
267                        Box::new(QuantMethodWrapper(b)),
268                    ]
269                }
270                _ => {
271                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
272                }
273            };
274
275        let (layer_idx, type_feature) = match &cfg.img_processor {
276            Some(Phi4MMImgProcessorConfig {
277                layer_idx,
278                type_feature,
279            }) => (
280                layer_idx.unwrap_or(-2),
281                type_feature.clone().unwrap_or("patch".to_string()),
282            ),
283
284            None => (-2, "patch".to_string()),
285        };
286
287        Ok(Self {
288            wte,
289            image_dim_out,
290            num_img_tokens,
291            glb_gn,
292            sub_gn,
293            layer_idx,
294            type_feature,
295            image_processor,
296            layers: EmbeddingLayers(layers),
297            hd_transform_order,
298            use_hd_transform,
299            tensors,
300            img_processor_padding,
301            crop_size,
302            image_token_compression,
303            base_feat_height_reduction,
304            base_feat_height_target,
305        })
306    }
307
308    fn get_image_features(
309        &self,
310        img_embeds: &Tensor,
311        attention_mask: Option<&Tensor>,
312    ) -> Result<Tensor> {
313        assert!(self.layer_idx < 0);
314        let img_feature = self.image_processor.forward_get_hidden_states(
315            &img_embeds.to_dtype(self.image_processor.dtype())?,
316            attention_mask,
317            None,
318            self.layer_idx,
319        )?;
320
321        if self.type_feature == "patch" {
322            let mut patch_feature = img_feature.clone();
323            if let Some(image_token_compression) = &self.image_token_compression {
324                // reshape to 2D tensor
325                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
326                patch_feature =
327                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
328                // Convert to NCHW
329                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
330                if let Some(img_processor_padding) = &self.img_processor_padding {
331                    patch_feature = patch_feature.apply(img_processor_padding)?;
332                }
333                patch_feature = image_token_compression.forward(&patch_feature)?;
334                // Convert to NHWC
335                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
336                patch_feature = patch_feature.reshape((
337                    (),
338                    patch_feature.dim(1)? * patch_feature.dim(2)?,
339                    patch_feature.dim(D::Minus1)?,
340                ))?;
341            } else if let Some(img_processor_padding) = &self.img_processor_padding {
342                // reshape to 2D tensor
343                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
344                patch_feature =
345                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
346                // Convert to NCHW
347                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
348                patch_feature = patch_feature.apply(img_processor_padding)?;
349                // Convert to NHWC
350                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
351                patch_feature = patch_feature.reshape((
352                    (),
353                    patch_feature.dim(1)? * patch_feature.dim(2)?,
354                    patch_feature.dim(D::Minus1)?,
355                ))?;
356            };
357            Ok(patch_feature)
358        } else if self.type_feature == "cls_patch" {
359            let mut img_feature = img_feature.clone();
360            if let Some(image_token_compression) = &self.image_token_compression {
361                // reshape to 2D tensor
362                let mut patch_feature = img_feature.i((.., 1..))?;
363                let cls_feature = img_feature.i((.., 0))?;
364                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
365                patch_feature =
366                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
367                patch_feature = image_token_compression.forward(&patch_feature)?;
368                patch_feature = patch_feature.reshape((
369                    (),
370                    patch_feature.dim(D::Minus2)? * patch_feature.dim(D::Minus1)?,
371                ))?;
372                img_feature = Tensor::cat(&[cls_feature, patch_feature], 1)?;
373            }
374            Ok(img_feature)
375        } else {
376            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
377        }
378    }
379
380    #[allow(non_snake_case)]
381    pub fn forward(
382        &self,
383        input_ids: &Tensor,
384        input_embeds: &Tensor,
385        image_attention_mask: Option<&Tensor>,
386        image_sizes: Option<Vec<(u32, u32)>>,
387    ) -> Result<Tensor> {
388        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
389
390        let positions = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?.nonzero()?;
391
392        let target_dev = self.layers.0[0].device();
393        let target_dtype = self.layers.0[0].dtype();
394
395        let mut select = false;
396        let mut image_set_tensor = None;
397        if positions.dim(0)? > 0 {
398            select = true;
399
400            if self.use_hd_transform && image_sizes.is_some() {
401                assert_eq!(input_embeds.dims().len(), 5);
402                let bs = input_embeds.dim(0)?;
403                let img_features = match image_attention_mask {
404                    Some(attn_mask) => self.get_image_features(
405                        &input_embeds.flatten(0, 1)?,
406                        Some(&attn_mask.ne(0.)?.flatten(0, 1)?),
407                    )?,
408                    None => self.get_image_features(input_embeds, None)?,
409                };
410
411                let base_feat_height_target = self.base_feat_height_target.unwrap();
412                let base_resolution = self.crop_size;
413                let base_feat_height_reduction = self.base_feat_height_reduction;
414
415                let base_feat_height = (img_features.dim(1)? as f64).sqrt() as usize;
416                let base_feat_width = base_feat_height;
417
418                assert_eq!(base_feat_height, base_feat_height_target);
419                assert_eq!(base_feat_width, base_feat_height_target);
420
421                let img_features = img_features.reshape((
422                    bs,
423                    (),
424                    base_feat_height * base_feat_width,
425                    self.image_dim_out,
426                ))?;
427                let C = self.image_dim_out;
428                let H = base_feat_height;
429
430                let mut output_imgs = Vec::new();
431                for bs_ in 0..bs {
432                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
433                    let h = h as usize / base_resolution;
434                    let w = w as usize / base_resolution;
435                    let B_ = h * w;
436
437                    // 1 x (24x24) x 1024
438                    let global_img_feature = img_features.i((bs_, ..1))?;
439
440                    // 1 x 12 x 12 x 4096
441                    let glb_img = global_img_feature
442                        .reshape((1, H, H, C))?
443                        .reshape((
444                            1,
445                            H / base_feat_height_reduction,
446                            base_feat_height_reduction,
447                            H / base_feat_height_reduction,
448                            base_feat_height_reduction,
449                            C,
450                        ))?
451                        .contiguous()?
452                        .permute((0, 1, 3, 2, 4, 5))?
453                        .reshape((
454                            1,
455                            H / base_feat_height_reduction,
456                            H / base_feat_height_reduction,
457                            base_feat_height_reduction * base_feat_height_reduction * C,
458                        ))?
459                        .contiguous()?;
460                    let temp_glbl_gn = self
461                        .sub_gn
462                        .as_ref()
463                        .expect("Need `sub_gn` if `use_hd_transform`")
464                        .repeat((1, H / base_feat_height_reduction, 1, 1))?;
465
466                    // 1 x 156 x 4096
467                    let glb_img = Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((
468                        1,
469                        (),
470                        base_feat_height_reduction * base_feat_height_reduction * C,
471                    ))?;
472
473                    // (max_num_crops-1) x (12x12) x C
474                    let mut sub_img = img_features.i((bs_, 1..))?;
475
476                    // 16x574x1024
477                    // Get rid of padding sub_img
478                    sub_img = sub_img.i(..B_)?;
479
480                    // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
481                    sub_img = sub_img
482                        .reshape((B_, H, H, C))?
483                        .reshape((
484                            B_,
485                            H / base_feat_height_reduction,
486                            base_feat_height_reduction,
487                            H / base_feat_height_reduction,
488                            base_feat_height_reduction,
489                            C,
490                        ))?
491                        .contiguous()?
492                        .permute((0, 1, 3, 2, 4, 5))?
493                        .reshape((
494                            B_,
495                            (),
496                            base_feat_height_reduction * base_feat_height_reduction * C,
497                        ))?
498                        .contiguous()?;
499                    sub_img = sub_img
500                        .reshape(BigShapeWithOneHole((
501                            1usize,
502                            h,
503                            w,
504                            base_feat_height / base_feat_height_reduction,
505                            base_feat_width / base_feat_height_reduction,
506                            (),
507                        )))?
508                        .permute((0, 1, 3, 2, 4, 5))?
509                        .reshape((
510                            1,
511                            h * base_feat_height / base_feat_height_reduction,
512                            w * base_feat_width / base_feat_height_reduction,
513                            base_feat_height_reduction * base_feat_height_reduction * C,
514                        ))?;
515
516                    let (temp_sub_GN, temp_len) = if let Some(image_attention_mask) =
517                        image_attention_mask
518                    {
519                        let h_indices = Tensor::arange_step(
520                            0,
521                            image_attention_mask.dim(2)? as u32,
522                            2,
523                            &target_dev,
524                        )?;
525                        let w_indices = Tensor::arange_step(
526                            0,
527                            image_attention_mask.dim(3)? as u32,
528                            2,
529                            &target_dev,
530                        )?;
531
532                        let reshaped_image_attention_mask = {
533                            let mut selected = image_attention_mask.i((bs_, 1..B_ + 1))?;
534                            selected = selected.index_select(&h_indices, 1)?;
535                            selected = selected.index_select(&w_indices, 2)?;
536                            selected
537                                .reshape((
538                                    1,
539                                    h,
540                                    w,
541                                    base_feat_height / base_feat_height_reduction,
542                                    base_feat_width / base_feat_height_reduction,
543                                ))?
544                                .permute((0, 1, 3, 2, 4))?
545                                .reshape((
546                                    1,
547                                    h * base_feat_height / base_feat_height_reduction,
548                                    w * base_feat_width / base_feat_height_reduction,
549                                ))?
550                        };
551
552                        let useful_height = reshaped_image_attention_mask
553                            .i((0, .., 0))?
554                            .sum_all()?
555                            .to_scalar::<u32>()?;
556                        let useful_width = reshaped_image_attention_mask
557                            .i((0, 0, ..))?
558                            .sum_all()?
559                            .to_scalar::<u32>()?;
560
561                        sub_img =
562                            sub_img.i((.., ..useful_height as usize, ..useful_width as usize))?;
563
564                        let temp_len = {
565                            let mut selected = image_attention_mask.i((bs_, ..B_ + 1))?;
566                            selected = selected.index_select(&h_indices, 1)?;
567                            selected = selected.index_select(&w_indices, 2)?;
568                            selected.sum_all()?.to_scalar::<u32>()?
569                        };
570                        let temp_len = temp_len as usize
571                            + useful_height as usize
572                            + 1
573                            + base_feat_height / base_feat_height_reduction;
574
575                        (
576                            self.sub_gn
577                                .as_ref()
578                                .expect("Need `sub_gn` if `use_hd_transform`")
579                                .repeat((1, useful_height as usize, 1, 1))?,
580                            temp_len,
581                        )
582                    } else {
583                        let temp_len = (h * w + 1) * self.num_img_tokens
584                            + 1
585                            + (h + 1) * base_feat_height / base_feat_height_reduction;
586
587                        (
588                            self.sub_gn
589                                .as_ref()
590                                .expect("Need `sub_gn` if `use_hd_transform`")
591                                .repeat((
592                                    1,
593                                    h * base_feat_height / base_feat_height_reduction,
594                                    1,
595                                    1,
596                                ))?,
597                            temp_len,
598                        )
599                    };
600
601                    let sub_img = Tensor::cat(&[sub_img, temp_sub_GN], 2)?.reshape((
602                        1,
603                        (),
604                        base_feat_height_reduction * base_feat_height_reduction * C,
605                    ))?;
606
607                    // (1, num_img_tokens, 1024*4)
608
609                    // glb + sub
610                    match self.hd_transform_order.as_str() {
611                        "glb_sub" => {
612                            output_imgs.push(Tensor::cat(
613                                &[
614                                    glb_img,
615                                    self.glb_gn
616                                        .as_ref()
617                                        .expect("Need `glb_gn` if `use_hd_transform`")
618                                        .clone(),
619                                    sub_img,
620                                ],
621                                1,
622                            )?);
623                        }
624                        "sub_glb" => {
625                            output_imgs.push(Tensor::cat(
626                                &[
627                                    sub_img,
628                                    self.glb_gn
629                                        .as_ref()
630                                        .expect("Need `glb_gn` if `use_hd_transform`")
631                                        .clone(),
632                                    glb_img,
633                                ],
634                                1,
635                            )?);
636                        }
637                        other => {
638                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
639                        }
640                    }
641
642                    // (h*w+1)*144 + 1 + (h+1)*12
643                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
644                }
645
646                let mut image_set_tensor_inner = Vec::new();
647                for img in output_imgs {
648                    let layerout = self
649                        .layers
650                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
651                    image_set_tensor_inner.push(layerout);
652                }
653                image_set_tensor = Some(image_set_tensor_inner);
654            } else {
655                unreachable!()
656            }
657        }
658
659        let mut hidden_states = self.wte.forward(&input_ids)?;
660        if select && self.use_hd_transform {
661            match image_set_tensor {
662                Some(image_set_tensors) => {
663                    let merged_img_set_tensor = Tensor::cat(&image_set_tensors, 1)?.squeeze(0)?;
664
665                    let original_shape = hidden_states.shape().clone();
666                    let (hs_b, hs_l, hs_d) = hidden_states.dims3()?;
667                    let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;
668
669                    // Get the equiv 0th and 1th rows of the positions_tuple
670                    let positions_transposed = positions.to_dtype(DType::F32)?;
671                    let positions_transposed_0 = positions_transposed.i((.., 0))?;
672                    let positions_transposed_1 = positions_transposed.i((.., 1))?;
673
674                    let mut linear_index = ((positions_transposed_0 * (hs_l * hs_b) as f64)?
675                        + positions_transposed_1)?;
676                    linear_index = linear_index.to_dtype(DType::U32)?;
677                    linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;
678
679                    let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
680                    let delta = merged_img_set_tensor.broadcast_sub(&current_vals)?;
681
682                    hidden_states_flat =
683                        hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;
684
685                    hidden_states = hidden_states_flat.reshape(original_shape)?;
686                }
687                _ => unreachable!(),
688            }
689        }
690
691        Ok(hidden_states)
692    }
693
694    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
695        let uvb = UnVarBuilder::new();
696
697        if let Some(glb_gn) = self.glb_gn.clone() {
698            uvb.add_tensor("glb_GN", glb_gn);
699        }
700        if let Some(sub_gn) = self.sub_gn.clone() {
701            uvb.add_tensor("sub_GN", sub_gn);
702        }
703        uvb.extend(self.tensors.clone());
704        uvb.pp("img_processor.vision_model")
705            .extend(self.image_processor.residual_tensors());
706
707        uvb.to_safetensors()
708    }
709}