mistralrs_core/vision_models/phi4/image_embedding.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, LazyLock},
4};
5
6use candle_core::{shape::ShapeWithOneHole, DType, Device, IndexOp, Result, Shape, Tensor, D};
7use candle_nn::Module;
8use mistralrs_quant::{NonZeroOp, QuantMethod, ShardedVarBuilder};
9
10use crate::{
11    layers::{AvgPool2d, ReflectionPad2d},
12    utils::unvarbuilder::UnVarBuilder,
13    vision_models::{
14        phi4::config::Phi4MMImgProcessorConfig,
15        siglip::{SiglipVisionConfig, SiglipVisionTransformer},
16    },
17};
18
19use super::{config::Phi4MMImageEmbedConfig, Phi4MMConfig};
20
21pub(super) const IMAGE_SPECIAL_TOKEN_ID: f64 = 200010.;
22
23trait ModuleWithMetadata: Module + Debug + Send + Sync {
24    fn device(&self) -> Device;
25    fn dtype(&self) -> DType;
26}
27
28#[derive(Debug)]
29struct QuantMethodWrapper(Arc<dyn QuantMethod>);
30
31impl Module for QuantMethodWrapper {
32    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
33        self.0.forward(xs)
34    }
35}
36
37impl ModuleWithMetadata for QuantMethodWrapper {
38    fn device(&self) -> Device {
39        self.0.unquant_weight_bias().unwrap().0.device().clone()
40    }
41    fn dtype(&self) -> DType {
42        self.0.unquant_weight_bias().unwrap().0.dtype()
43    }
44}
45
46impl ModuleWithMetadata for candle_nn::Activation {
47    fn device(&self) -> Device {
48        unreachable!()
49    }
50    fn dtype(&self) -> DType {
51        unreachable!()
52    }
53}
54
55#[derive(Debug)]
56struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
57
58fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
59    if prod_d == 0 {
60        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
61    }
62    if el_count % prod_d != 0 {
63        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
64    }
65    Ok(el_count / prod_d)
66}
67
68impl ShapeWithOneHole for BigShapeWithOneHole {
69    fn into_shape(self, el_count: usize) -> Result<Shape> {
70        let (d1, d2, d3, d4, d5, ()) = self.0;
71        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
72        Ok((d1, d2, d3, d4, d5, d).into())
73    }
74}
75
76#[derive(Debug)]
77struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
78
79impl Module for EmbeddingLayers {
80    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
81        let mut xs = xs.clone();
82        for layer in &self.0 {
83            xs = layer.forward(&xs)?;
84        }
85        Ok(xs)
86    }
87}
88
89pub(crate) static PHI4_MM_VISION_CFG: LazyLock<SiglipVisionConfig> =
90    LazyLock::new(|| SiglipVisionConfig {
91        hidden_size: 1152,
92        image_size: 448,
93        intermediate_size: 4304,
94        num_attention_heads: 16,
95        num_hidden_layers: 27,
96        patch_size: 14,
97        ..Default::default()
98    });
99
100pub struct ImageEmbedding {
101    wte: candle_nn::Embedding,
102    image_dim_out: usize,
103    num_img_tokens: usize,
104    glb_gn: Option<Tensor>,
105    sub_gn: Option<Tensor>,
106    layers: EmbeddingLayers,
107    type_feature: String,
108    layer_idx: isize,
109    image_processor: SiglipVisionTransformer,
110    hd_transform_order: String,
111    use_hd_transform: bool,
112    tensors: Vec<(String, Tensor)>,
113    img_processor_padding: Option<ReflectionPad2d>,
114    crop_size: usize,
115    image_token_compression: Option<AvgPool2d>,
116    base_feat_height_reduction: usize,
117    base_feat_height_target: Option<usize>,
118}
119
120impl ImageEmbedding {
121    pub fn new(
122        cfg: &Phi4MMConfig,
123        img_embd_config: &Phi4MMImageEmbedConfig,
124        wte: candle_nn::Embedding,
125        vb: ShardedVarBuilder,
126    ) -> Result<Self> {
127        let hidden_size = img_embd_config.n_embd.unwrap_or(cfg.hidden_size);
128
129        let siglip_vision_config = &PHI4_MM_VISION_CFG;
130        let image_processor =
131            SiglipVisionTransformer::new(siglip_vision_config, vb.pp("img_processor"))?;
132
133        let pe_weight = image_processor.embeddings.position_embedding.embeddings();
134        let (l, d) = pe_weight.dims2()?;
135        let mut m = (l as f64).sqrt() as usize;
136        assert_eq!(m.pow(2), l);
137        let img_processor_padding = if m % 2 != 0 {
138            m += 1;
139            Some(ReflectionPad2d::new((0, 1, 0, 1)))
140        } else {
141            None
142        };
143        let image_dim_out = d;
144        let num_img_tokens = (m / 2).pow(2);
145        let base_feat_height_target = m;
146
147        // High dim transform
148        let use_hd_transform = img_embd_config.use_hd_transform.unwrap_or(false);
149        let with_learnable_separator = img_embd_config.with_learnable_separator.unwrap_or(false);
150        let hd_transform_order = img_embd_config
151            .hd_transform_order
152            .clone()
153            .unwrap_or("glb_sub".to_string());
154        let crop_size = img_embd_config.crop_size.unwrap_or(336);
155
156        let (image_token_compression, base_feat_height_reduction, base_feat_height_target) =
157            match &img_embd_config.image_token_compression_cls {
158                Some(x) if x == "avg_pool_2d" => (
159                    Some(AvgPool2d::new(2, 2)),
160                    1_usize,
161                    Some(base_feat_height_target / 2),
162                ),
163                None => (None, 2_usize, None),
164                _ => candle_core::bail!("Unexpected image_token_compression_cls"),
165            };
166
167        assert_eq!(use_hd_transform, with_learnable_separator);
168        let (glb_gn, sub_gn) = if with_learnable_separator {
169            let glb_gn = vb.get(
170                (1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
171                "glb_GN",
172            )?;
173            let sub_gn = vb.get(
174                (1, 1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
175                "sub_GN",
176            )?;
177            (Some(glb_gn), Some(sub_gn))
178        } else {
179            (None, None)
180        };
181
182        // Inner projection
183        let projection_cls = img_embd_config
184            .projection_cls
185            .clone()
186            .unwrap_or("linear".to_string());
187
188        let mut tensors = Vec::new();
189        let layers: Vec<Box<dyn ModuleWithMetadata>> =
190            match (projection_cls.as_str(), use_hd_transform) {
191                ("linear", _) => {
192                    let a = mistralrs_quant::linear_b(
193                        image_dim_out,
194                        hidden_size,
195                        true,
196                        &None,
197                        vb.pp("img_projection"),
198                    )?;
199                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
200                    tensors.push(("img_projection.weight".to_string(), a_w));
201                    if let Some(b) = a_b {
202                        tensors.push(("img_projection.bias".to_string(), b));
203                    }
204                    vec![Box::new(QuantMethodWrapper(a))]
205                }
206                ("mlp", true) => {
207                    let dim_proj = hidden_size;
208                    let a = mistralrs_quant::linear_b(
209                        image_dim_out * base_feat_height_reduction.pow(2),
210                        dim_proj,
211                        true,
212                        &None,
213                        vb.pp("img_projection.0"),
214                    )?;
215                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
216                    tensors.push(("img_projection.0.weight".to_string(), a_w));
217                    if let Some(b) = a_b {
218                        tensors.push(("img_projection.0.bias".to_string(), b));
219                    }
220                    let b = mistralrs_quant::linear_b(
221                        dim_proj,
222                        dim_proj,
223                        true,
224                        &None,
225                        vb.pp("img_projection.2"),
226                    )?;
227                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
228                    tensors.push(("img_projection.2.weight".to_string(), b_w));
229                    if let Some(b) = b_b {
230                        tensors.push(("img_projection.2.bias".to_string(), b));
231                    }
232                    vec![
233                        Box::new(QuantMethodWrapper(a)),
234                        Box::new(candle_nn::Activation::Gelu),
235                        Box::new(QuantMethodWrapper(b)),
236                    ]
237                }
238                ("mlp", false) => {
239                    let dim_proj = hidden_size;
240                    let a = mistralrs_quant::linear_b(
241                        image_dim_out,
242                        dim_proj,
243                        true,
244                        &None,
245                        vb.pp("img_projection.0"),
246                    )?;
247                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
248                    tensors.push(("img_projection.0.weight".to_string(), a_w));
249                    if let Some(b) = a_b {
250                        tensors.push(("img_projection.0.bias".to_string(), b));
251                    }
252                    let b = mistralrs_quant::linear_b(
253                        dim_proj,
254                        dim_proj,
255                        true,
256                        &None,
257                        vb.pp("img_projection.2"),
258                    )?;
259                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
260                    tensors.push(("img_projection.2.weight".to_string(), b_w));
261                    if let Some(b) = b_b {
262                        tensors.push(("img_projection.2.bias".to_string(), b));
263                    }
264                    vec![
265                        Box::new(QuantMethodWrapper(a)),
266                        Box::new(candle_nn::Activation::Gelu),
267                        Box::new(QuantMethodWrapper(b)),
268                    ]
269                }
270                _ => {
271                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
272                }
273            };
274
275        let (layer_idx, type_feature) = match &cfg.img_processor {
276            Some(Phi4MMImgProcessorConfig {
277                layer_idx,
278                type_feature,
279            }) => (
280                layer_idx.unwrap_or(-2),
281                type_feature.clone().unwrap_or("patch".to_string()),
282            ),
283
284            None => (-2, "patch".to_string()),
285        };
286
287        Ok(Self {
288            wte,
289            image_dim_out,
290            num_img_tokens,
291            glb_gn,
292            sub_gn,
293            layer_idx,
294            type_feature,
295            image_processor,
296            layers: EmbeddingLayers(layers),
297            hd_transform_order,
298            use_hd_transform,
299            tensors,
300            img_processor_padding,
301            crop_size,
302            image_token_compression,
303            base_feat_height_reduction,
304            base_feat_height_target,
305        })
306    }
307
308    fn get_image_features(
309        &self,
310        img_embeds: &Tensor,
311        attention_mask: Option<&Tensor>,
312    ) -> Result<Tensor> {
313        assert!(self.layer_idx < 0);
314        let img_feature = self.image_processor.forward_get_hidden_states(
315            &img_embeds.to_dtype(self.image_processor.dtype())?,
316            attention_mask,
317            None,
318            self.layer_idx,
319        )?;
320
321        if self.type_feature == "patch" {
322            let mut patch_feature = img_feature.clone();
323            if let Some(image_token_compression) = &self.image_token_compression {
324                // reshape to 2D tensor
325                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
326                patch_feature =
327                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
328                // Convert to NCHW
329                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
330                if let Some(img_processor_padding) = &self.img_processor_padding {
331                    patch_feature = patch_feature.apply(img_processor_padding)?;
332                }
333                patch_feature = image_token_compression.forward(&patch_feature)?;
334                // Convert to NHWC
335                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
336                patch_feature = patch_feature.reshape((
337                    (),
338                    patch_feature.dim(1)? * patch_feature.dim(2)?,
339                    patch_feature.dim(D::Minus1)?,
340                ))?;
341            } else if let Some(img_processor_padding) = &self.img_processor_padding {
342                // reshape to 2D tensor
343                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
344                patch_feature =
345                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
346                // Convert to NCHW
347                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
348                patch_feature = patch_feature.apply(img_processor_padding)?;
349                // Convert to NHWC
350                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
351                patch_feature = patch_feature.reshape((
352                    (),
353                    patch_feature.dim(1)? * patch_feature.dim(2)?,
354                    patch_feature.dim(D::Minus1)?,
355                ))?;
356            };
357            Ok(patch_feature)
358        } else if self.type_feature == "cls_patch" {
359            let mut img_feature = img_feature.clone();
360            if let Some(image_token_compression) = &self.image_token_compression {
361                // reshape to 2D tensor
362                let mut patch_feature = img_feature.i((.., 1..))?;
363                let cls_feature = img_feature.i((.., 0))?;
364                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
365                patch_feature =
366                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
367                patch_feature = image_token_compression.forward(&patch_feature)?;
368                patch_feature = patch_feature.reshape((
369                    (),
370                    patch_feature.dim(D::Minus2)? * patch_feature.dim(D::Minus1)?,
371                ))?;
372                img_feature = Tensor::cat(&[cls_feature, patch_feature], 1)?;
373            }
374            Ok(img_feature)
375        } else {
376            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
377        }
378    }
379
380    #[allow(non_snake_case)]
381    pub fn forward(
382        &self,
383        input_ids: &Tensor,
384        input_embeds: &Tensor,
385        image_attention_mask: Option<&Tensor>,
386        image_sizes: Option<Vec<(u32, u32)>>,
387    ) -> Result<Tensor> {
388        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
389
390        let positions = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?.nonzero()?;
391
392        let target_dev = self.layers.0[0].device();
393        let target_dtype = self.layers.0[0].dtype();
394
395        let mut select = false;
396        let mut image_set_tensor = None;
397        if positions.dim(0)? > 0 {
398            select = true;
399
400            if self.use_hd_transform {
401                if let Some(image_sizes_ref) = image_sizes.as_ref() {
402                    assert_eq!(input_embeds.dims().len(), 5);
403                    let bs = input_embeds.dim(0)?;
404                    let img_features = match image_attention_mask {
405                        Some(attn_mask) => self.get_image_features(
406                            &input_embeds.flatten(0, 1)?,
407                            Some(&attn_mask.ne(0.)?.flatten(0, 1)?),
408                        )?,
409                        None => self.get_image_features(input_embeds, None)?,
410                    };
411
412                    let base_feat_height_target = self.base_feat_height_target.unwrap();
413                    let base_resolution = self.crop_size;
414                    let base_feat_height_reduction = self.base_feat_height_reduction;
415
416                    let base_feat_height = (img_features.dim(1)? as f64).sqrt() as usize;
417                    let base_feat_width = base_feat_height;
418
419                    assert_eq!(base_feat_height, base_feat_height_target);
420                    assert_eq!(base_feat_width, base_feat_height_target);
421
422                    let img_features = img_features.reshape((
423                        bs,
424                        (),
425                        base_feat_height * base_feat_width,
426                        self.image_dim_out,
427                    ))?;
428                    let C = self.image_dim_out;
429                    let H = base_feat_height;
430
431                    let mut output_imgs = Vec::new();
432                    for (bs_, &(h, w)) in image_sizes_ref.iter().enumerate().take(bs) {
433                        let h = h as usize / base_resolution;
434                        let w = w as usize / base_resolution;
435                        let B_ = h * w;
436
437                        // 1 x (24x24) x 1024
438                        let global_img_feature = img_features.i((bs_, ..1))?;
439
440                        // 1 x 12 x 12 x 4096
441                        let glb_img = global_img_feature
442                            .reshape((1, H, H, C))?
443                            .reshape((
444                                1,
445                                H / base_feat_height_reduction,
446                                base_feat_height_reduction,
447                                H / base_feat_height_reduction,
448                                base_feat_height_reduction,
449                                C,
450                            ))?
451                            .contiguous()?
452                            .permute((0, 1, 3, 2, 4, 5))?
453                            .reshape((
454                                1,
455                                H / base_feat_height_reduction,
456                                H / base_feat_height_reduction,
457                                base_feat_height_reduction * base_feat_height_reduction * C,
458                            ))?
459                            .contiguous()?;
460                        let temp_glbl_gn = self
461                            .sub_gn
462                            .as_ref()
463                            .expect("Need `sub_gn` if `use_hd_transform`")
464                            .repeat((1, H / base_feat_height_reduction, 1, 1))?;
465
466                        // 1 x 156 x 4096
467                        let glb_img = Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((
468                            1,
469                            (),
470                            base_feat_height_reduction * base_feat_height_reduction * C,
471                        ))?;
472
473                        // (max_num_crops-1) x (12x12) x C
474                        let mut sub_img = img_features.i((bs_, 1..))?;
475
476                        // 16x574x1024
477                        // Get rid of padding sub_img
478                        sub_img = sub_img.i(..B_)?;
479
480                        // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
481                        sub_img = sub_img
482                            .reshape((B_, H, H, C))?
483                            .reshape((
484                                B_,
485                                H / base_feat_height_reduction,
486                                base_feat_height_reduction,
487                                H / base_feat_height_reduction,
488                                base_feat_height_reduction,
489                                C,
490                            ))?
491                            .contiguous()?
492                            .permute((0, 1, 3, 2, 4, 5))?
493                            .reshape((
494                                B_,
495                                (),
496                                base_feat_height_reduction * base_feat_height_reduction * C,
497                            ))?
498                            .contiguous()?;
499                        sub_img = sub_img
500                            .reshape(BigShapeWithOneHole((
501                                1usize,
502                                h,
503                                w,
504                                base_feat_height / base_feat_height_reduction,
505                                base_feat_width / base_feat_height_reduction,
506                                (),
507                            )))?
508                            .permute((0, 1, 3, 2, 4, 5))?
509                            .reshape((
510                                1,
511                                h * base_feat_height / base_feat_height_reduction,
512                                w * base_feat_width / base_feat_height_reduction,
513                                base_feat_height_reduction * base_feat_height_reduction * C,
514                            ))?;
515
516                        let (temp_sub_GN, temp_len) =
517                            if let Some(image_attention_mask) = image_attention_mask {
518                                let h_indices = Tensor::arange_step(
519                                    0,
520                                    image_attention_mask.dim(2)? as u32,
521                                    2,
522                                    &target_dev,
523                                )?;
524                                let w_indices = Tensor::arange_step(
525                                    0,
526                                    image_attention_mask.dim(3)? as u32,
527                                    2,
528                                    &target_dev,
529                                )?;
530
531                                let reshaped_image_attention_mask = {
532                                    let mut selected = image_attention_mask.i((bs_, 1..B_ + 1))?;
533                                    selected = selected.index_select(&h_indices, 1)?;
534                                    selected = selected.index_select(&w_indices, 2)?;
535                                    selected
536                                        .reshape((
537                                            1,
538                                            h,
539                                            w,
540                                            base_feat_height / base_feat_height_reduction,
541                                            base_feat_width / base_feat_height_reduction,
542                                        ))?
543                                        .permute((0, 1, 3, 2, 4))?
544                                        .reshape((
545                                            1,
546                                            h * base_feat_height / base_feat_height_reduction,
547                                            w * base_feat_width / base_feat_height_reduction,
548                                        ))?
549                                };
550
551                                let useful_height = reshaped_image_attention_mask
552                                    .i((0, .., 0))?
553                                    .sum_all()?
554                                    .to_scalar::<u32>()?;
555                                let useful_width = reshaped_image_attention_mask
556                                    .i((0, 0, ..))?
557                                    .sum_all()?
558                                    .to_scalar::<u32>()?;
559
560                                sub_img = sub_img.i((
561                                    ..,
562                                    ..useful_height as usize,
563                                    ..useful_width as usize,
564                                ))?;
565
566                                let temp_len = {
567                                    let mut selected = image_attention_mask.i((bs_, ..B_ + 1))?;
568                                    selected = selected.index_select(&h_indices, 1)?;
569                                    selected = selected.index_select(&w_indices, 2)?;
570                                    selected.sum_all()?.to_scalar::<u32>()?
571                                };
572                                let temp_len = temp_len as usize
573                                    + useful_height as usize
574                                    + 1
575                                    + base_feat_height / base_feat_height_reduction;
576
577                                (
578                                    self.sub_gn
579                                        .as_ref()
580                                        .expect("Need `sub_gn` if `use_hd_transform`")
581                                        .repeat((1, useful_height as usize, 1, 1))?,
582                                    temp_len,
583                                )
584                            } else {
585                                let temp_len = (h * w + 1) * self.num_img_tokens
586                                    + 1
587                                    + (h + 1) * base_feat_height / base_feat_height_reduction;
588
589                                (
590                                    self.sub_gn
591                                        .as_ref()
592                                        .expect("Need `sub_gn` if `use_hd_transform`")
593                                        .repeat((
594                                            1,
595                                            h * base_feat_height / base_feat_height_reduction,
596                                            1,
597                                            1,
598                                        ))?,
599                                    temp_len,
600                                )
601                            };
602
603                        let sub_img = Tensor::cat(&[sub_img, temp_sub_GN], 2)?.reshape((
604                            1,
605                            (),
606                            base_feat_height_reduction * base_feat_height_reduction * C,
607                        ))?;
608
609                        // (1, num_img_tokens, 1024*4)
610
611                        // glb + sub
612                        match self.hd_transform_order.as_str() {
613                            "glb_sub" => {
614                                output_imgs.push(Tensor::cat(
615                                    &[
616                                        glb_img,
617                                        self.glb_gn
618                                            .as_ref()
619                                            .expect("Need `glb_gn` if `use_hd_transform`")
620                                            .clone(),
621                                        sub_img,
622                                    ],
623                                    1,
624                                )?);
625                            }
626                            "sub_glb" => {
627                                output_imgs.push(Tensor::cat(
628                                    &[
629                                        sub_img,
630                                        self.glb_gn
631                                            .as_ref()
632                                            .expect("Need `glb_gn` if `use_hd_transform`")
633                                            .clone(),
634                                        glb_img,
635                                    ],
636                                    1,
637                                )?);
638                            }
639                            other => {
640                                candle_core::bail!("Invalid hd_transform_order=`{other}`");
641                            }
642                        }
643
644                        // (h*w+1)*144 + 1 + (h+1)*12
645                        assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
646                    }
647
648                    let mut image_set_tensor_inner = Vec::new();
649                    for img in output_imgs {
650                        let layerout = self
651                            .layers
652                            .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
653                        image_set_tensor_inner.push(layerout);
654                    }
655                    image_set_tensor = Some(image_set_tensor_inner);
656                }
657            } else {
658                unreachable!()
659            }
660        }
661
662        let mut hidden_states = self.wte.forward(&input_ids)?;
663        if select && self.use_hd_transform {
664            match image_set_tensor {
665                Some(image_set_tensors) => {
666                    let merged_img_set_tensor = Tensor::cat(&image_set_tensors, 1)?.squeeze(0)?;
667
668                    let original_shape = hidden_states.shape().clone();
669                    let (hs_b, hs_l, hs_d) = hidden_states.dims3()?;
670                    let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;
671
672                    // Get the equiv 0th and 1th rows of the positions_tuple
673                    let positions_transposed = positions.to_dtype(DType::F32)?;
674                    let positions_transposed_0 = positions_transposed.i((.., 0))?;
675                    let positions_transposed_1 = positions_transposed.i((.., 1))?;
676
677                    let mut linear_index = ((positions_transposed_0 * (hs_l * hs_b) as f64)?
678                        + positions_transposed_1)?;
679                    linear_index = linear_index.to_dtype(DType::U32)?;
680                    linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;
681
682                    let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
683                    let delta = merged_img_set_tensor.broadcast_sub(&current_vals)?;
684
685                    hidden_states_flat =
686                        hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;
687
688                    hidden_states = hidden_states_flat.reshape(original_shape)?;
689                }
690                _ => unreachable!(),
691            }
692        }
693
694        Ok(hidden_states)
695    }
696
697    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
698        let uvb = UnVarBuilder::new();
699
700        if let Some(glb_gn) = self.glb_gn.clone() {
701            uvb.add_tensor("glb_GN", glb_gn);
702        }
703        if let Some(sub_gn) = self.sub_gn.clone() {
704            uvb.add_tensor("sub_GN", sub_gn);
705        }
706        uvb.extend(self.tensors.clone());
707        uvb.pp("img_processor.vision_model")
708            .extend(self.image_processor.residual_tensors());
709
710        uvb.to_safetensors()
711    }
712}