mistralrs_core/vision_models/phi4/image_embedding.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, LazyLock},
4};
5
6use candle_core::{shape::ShapeWithOneHole, DType, Device, IndexOp, Result, Shape, Tensor, D};
7use candle_nn::Module;
8use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
9
10use crate::{
11    layers::{AvgPool2d, ReflectionPad2d},
12    ops::NonZeroOp,
13    utils::unvarbuilder::UnVarBuilder,
14    vision_models::{
15        phi4::config::Phi4MMImgProcessorConfig,
16        siglip::{SiglipVisionConfig, SiglipVisionTransformer},
17    },
18};
19
20use super::{config::Phi4MMImageEmbedConfig, Phi4MMConfig};
21
22pub(super) const IMAGE_SPECIAL_TOKEN_ID: f64 = 200010.;
23
24trait ModuleWithMetadata: Module + Debug + Send + Sync {
25    fn device(&self) -> Device;
26    fn dtype(&self) -> DType;
27}
28
29#[derive(Debug)]
30struct QuantMethodWrapper(Arc<dyn QuantMethod>);
31
32impl Module for QuantMethodWrapper {
33    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
34        self.0.forward(xs)
35    }
36}
37
38impl ModuleWithMetadata for QuantMethodWrapper {
39    fn device(&self) -> Device {
40        self.0.unquant_weight_bias().unwrap().0.device().clone()
41    }
42    fn dtype(&self) -> DType {
43        self.0.unquant_weight_bias().unwrap().0.dtype()
44    }
45}
46
47impl ModuleWithMetadata for candle_nn::Activation {
48    fn device(&self) -> Device {
49        unreachable!()
50    }
51    fn dtype(&self) -> DType {
52        unreachable!()
53    }
54}
55
/// A rank-6 shape whose last dimension (the `()` slot) is inferred from the element count.
///
/// NOTE(review): presumably needed because candle's built-in `ShapeWithOneHole` tuple impls do
/// not cover this rank — confirm against the candle version in use.
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
58
59fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
60    if prod_d == 0 {
61        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
62    }
63    if el_count % prod_d != 0 {
64        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
65    }
66    Ok(el_count / prod_d)
67}
68
69impl ShapeWithOneHole for BigShapeWithOneHole {
70    fn into_shape(self, el_count: usize) -> Result<Shape> {
71        let (d1, d2, d3, d4, d5, ()) = self.0;
72        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
73        Ok((d1, d2, d3, d4, d5, d).into())
74    }
75}
76
77#[derive(Debug)]
78struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
79
80impl Module for EmbeddingLayers {
81    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
82        let mut xs = xs.clone();
83        for layer in &self.0 {
84            xs = layer.forward(&xs)?;
85        }
86        Ok(xs)
87    }
88}
89
90pub(crate) static PHI4_MM_VISION_CFG: LazyLock<SiglipVisionConfig> =
91    LazyLock::new(|| SiglipVisionConfig {
92        hidden_size: 1152,
93        image_size: 448,
94        intermediate_size: 4304,
95        num_attention_heads: 16,
96        num_hidden_layers: 27,
97        patch_size: 14,
98        ..Default::default()
99    });
100
101pub struct ImageEmbedding {
102    wte: candle_nn::Embedding,
103    image_dim_out: usize,
104    num_img_tokens: usize,
105    glb_gn: Option<Tensor>,
106    sub_gn: Option<Tensor>,
107    layers: EmbeddingLayers,
108    type_feature: String,
109    layer_idx: isize,
110    image_processor: SiglipVisionTransformer,
111    hd_transform_order: String,
112    use_hd_transform: bool,
113    tensors: Vec<(String, Tensor)>,
114    img_processor_padding: Option<ReflectionPad2d>,
115    crop_size: usize,
116    image_token_compression: Option<AvgPool2d>,
117    base_feat_height_reduction: usize,
118    base_feat_height_target: Option<usize>,
119}
120
121impl ImageEmbedding {
122    pub fn new(
123        cfg: &Phi4MMConfig,
124        img_embd_config: &Phi4MMImageEmbedConfig,
125        wte: candle_nn::Embedding,
126        vb: ShardedVarBuilder,
127    ) -> Result<Self> {
128        let hidden_size = img_embd_config.n_embd.unwrap_or(cfg.hidden_size);
129
130        let siglip_vision_config = &PHI4_MM_VISION_CFG;
131        let image_processor =
132            SiglipVisionTransformer::new(siglip_vision_config, vb.pp("img_processor"))?;
133
134        let pe_weight = image_processor.embeddings.position_embedding.embeddings();
135        let (l, d) = pe_weight.dims2()?;
136        let mut m = (l as f64).sqrt() as usize;
137        assert_eq!(m.pow(2), l);
138        let img_processor_padding = if m % 2 != 0 {
139            m += 1;
140            Some(ReflectionPad2d::new((0, 1, 0, 1)))
141        } else {
142            None
143        };
144        let image_dim_out = d;
145        let num_img_tokens = (m / 2).pow(2);
146        let base_feat_height_target = m;
147
148        // High dim transform
149        let use_hd_transform = img_embd_config.use_hd_transform.unwrap_or(false);
150        let with_learnable_separator = img_embd_config.with_learnable_separator.unwrap_or(false);
151        let hd_transform_order = img_embd_config
152            .hd_transform_order
153            .clone()
154            .unwrap_or("glb_sub".to_string());
155        let crop_size = img_embd_config.crop_size.unwrap_or(336);
156
157        let (image_token_compression, base_feat_height_reduction, base_feat_height_target) =
158            match &img_embd_config.image_token_compression_cls {
159                Some(x) if x == "avg_pool_2d" => (
160                    Some(AvgPool2d::new(2, 2)),
161                    1_usize,
162                    Some(base_feat_height_target / 2),
163                ),
164                None => (None, 2_usize, None),
165                _ => candle_core::bail!("Unexpected image_token_compression_cls"),
166            };
167
168        assert_eq!(use_hd_transform, with_learnable_separator);
169        let (glb_gn, sub_gn) = if with_learnable_separator {
170            let glb_gn = vb.get(
171                (1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
172                "glb_GN",
173            )?;
174            let sub_gn = vb.get(
175                (1, 1, 1, image_dim_out * base_feat_height_reduction.pow(2)),
176                "sub_GN",
177            )?;
178            (Some(glb_gn), Some(sub_gn))
179        } else {
180            (None, None)
181        };
182
183        // Inner projection
184        let projection_cls = img_embd_config
185            .projection_cls
186            .clone()
187            .unwrap_or("linear".to_string());
188
189        let mut tensors = Vec::new();
190        let layers: Vec<Box<dyn ModuleWithMetadata>> =
191            match (projection_cls.as_str(), use_hd_transform) {
192                ("linear", _) => {
193                    let a = mistralrs_quant::linear_b(
194                        image_dim_out,
195                        hidden_size,
196                        true,
197                        &None,
198                        vb.pp("img_projection"),
199                    )?;
200                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
201                    tensors.push(("img_projection.weight".to_string(), a_w));
202                    if let Some(b) = a_b {
203                        tensors.push(("img_projection.bias".to_string(), b));
204                    }
205                    vec![Box::new(QuantMethodWrapper(a))]
206                }
207                ("mlp", true) => {
208                    let dim_proj = hidden_size;
209                    let a = mistralrs_quant::linear_b(
210                        image_dim_out * base_feat_height_reduction.pow(2),
211                        dim_proj,
212                        true,
213                        &None,
214                        vb.pp("img_projection.0"),
215                    )?;
216                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
217                    tensors.push(("img_projection.0.weight".to_string(), a_w));
218                    if let Some(b) = a_b {
219                        tensors.push(("img_projection.0.bias".to_string(), b));
220                    }
221                    let b = mistralrs_quant::linear_b(
222                        dim_proj,
223                        dim_proj,
224                        true,
225                        &None,
226                        vb.pp("img_projection.2"),
227                    )?;
228                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
229                    tensors.push(("img_projection.2.weight".to_string(), b_w));
230                    if let Some(b) = b_b {
231                        tensors.push(("img_projection.2.bias".to_string(), b));
232                    }
233                    vec![
234                        Box::new(QuantMethodWrapper(a)),
235                        Box::new(candle_nn::Activation::Gelu),
236                        Box::new(QuantMethodWrapper(b)),
237                    ]
238                }
239                ("mlp", false) => {
240                    let dim_proj = hidden_size;
241                    let a = mistralrs_quant::linear_b(
242                        image_dim_out,
243                        dim_proj,
244                        true,
245                        &None,
246                        vb.pp("img_projection.0"),
247                    )?;
248                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
249                    tensors.push(("img_projection.0.weight".to_string(), a_w));
250                    if let Some(b) = a_b {
251                        tensors.push(("img_projection.0.bias".to_string(), b));
252                    }
253                    let b = mistralrs_quant::linear_b(
254                        dim_proj,
255                        dim_proj,
256                        true,
257                        &None,
258                        vb.pp("img_projection.2"),
259                    )?;
260                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
261                    tensors.push(("img_projection.2.weight".to_string(), b_w));
262                    if let Some(b) = b_b {
263                        tensors.push(("img_projection.2.bias".to_string(), b));
264                    }
265                    vec![
266                        Box::new(QuantMethodWrapper(a)),
267                        Box::new(candle_nn::Activation::Gelu),
268                        Box::new(QuantMethodWrapper(b)),
269                    ]
270                }
271                _ => {
272                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
273                }
274            };
275
276        let (layer_idx, type_feature) = match &cfg.img_processor {
277            Some(Phi4MMImgProcessorConfig {
278                layer_idx,
279                type_feature,
280            }) => (
281                layer_idx.unwrap_or(-2),
282                type_feature.clone().unwrap_or("patch".to_string()),
283            ),
284
285            None => (-2, "patch".to_string()),
286        };
287
288        Ok(Self {
289            wte,
290            image_dim_out,
291            num_img_tokens,
292            glb_gn,
293            sub_gn,
294            layer_idx,
295            type_feature,
296            image_processor,
297            layers: EmbeddingLayers(layers),
298            hd_transform_order,
299            use_hd_transform,
300            tensors,
301            img_processor_padding,
302            crop_size,
303            image_token_compression,
304            base_feat_height_reduction,
305            base_feat_height_target,
306        })
307    }
308
309    fn get_image_features(
310        &self,
311        img_embeds: &Tensor,
312        attention_mask: Option<&Tensor>,
313    ) -> Result<Tensor> {
314        assert!(self.layer_idx < 0);
315        let img_feature = self.image_processor.forward_get_hidden_states(
316            &img_embeds.to_dtype(self.image_processor.dtype())?,
317            attention_mask,
318            None,
319            self.layer_idx,
320        )?;
321
322        if self.type_feature == "patch" {
323            let mut patch_feature = img_feature.clone();
324            if let Some(image_token_compression) = &self.image_token_compression {
325                // reshape to 2D tensor
326                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
327                patch_feature =
328                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
329                // Convert to NCHW
330                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
331                if let Some(img_processor_padding) = &self.img_processor_padding {
332                    patch_feature = patch_feature.apply(img_processor_padding)?;
333                }
334                patch_feature = image_token_compression.forward(&patch_feature)?;
335                // Convert to NHWC
336                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
337                patch_feature = patch_feature.reshape((
338                    (),
339                    patch_feature.dim(1)? * patch_feature.dim(2)?,
340                    patch_feature.dim(D::Minus1)?,
341                ))?;
342            } else if let Some(img_processor_padding) = &self.img_processor_padding {
343                // reshape to 2D tensor
344                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
345                patch_feature =
346                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
347                // Convert to NCHW
348                patch_feature = patch_feature.permute((0, 3, 1, 2))?;
349                patch_feature = patch_feature.apply(img_processor_padding)?;
350                // Convert to NHWC
351                patch_feature = patch_feature.permute((0, 2, 3, 1))?;
352                patch_feature = patch_feature.reshape((
353                    (),
354                    patch_feature.dim(1)? * patch_feature.dim(2)?,
355                    patch_feature.dim(D::Minus1)?,
356                ))?;
357            };
358            Ok(patch_feature)
359        } else if self.type_feature == "cls_patch" {
360            let mut img_feature = img_feature.clone();
361            if let Some(image_token_compression) = &self.image_token_compression {
362                // reshape to 2D tensor
363                let mut patch_feature = img_feature.i((.., 1..))?;
364                let cls_feature = img_feature.i((.., 0))?;
365                let width = (patch_feature.dim(1)? as f64).sqrt() as usize;
366                patch_feature =
367                    patch_feature.reshape(((), width, width, patch_feature.dim(D::Minus1)?))?;
368                patch_feature = image_token_compression.forward(&patch_feature)?;
369                patch_feature = patch_feature.reshape((
370                    (),
371                    patch_feature.dim(D::Minus2)? * patch_feature.dim(D::Minus1)?,
372                ))?;
373                img_feature = Tensor::cat(&[cls_feature, patch_feature], 1)?;
374            }
375            Ok(img_feature)
376        } else {
377            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
378        }
379    }
380
381    #[allow(non_snake_case)]
382    pub fn forward(
383        &self,
384        input_ids: &Tensor,
385        input_embeds: &Tensor,
386        image_attention_mask: Option<&Tensor>,
387        image_sizes: Option<Vec<(u32, u32)>>,
388    ) -> Result<Tensor> {
389        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
390
391        let positions = input_ids.eq(IMAGE_SPECIAL_TOKEN_ID)?.nonzero()?;
392
393        let target_dev = self.layers.0[0].device();
394        let target_dtype = self.layers.0[0].dtype();
395
396        let mut select = false;
397        let mut image_set_tensor = None;
398        if positions.dim(0)? > 0 {
399            select = true;
400
401            if self.use_hd_transform && image_sizes.is_some() {
402                assert_eq!(input_embeds.dims().len(), 5);
403                let bs = input_embeds.dim(0)?;
404                let img_features = match image_attention_mask {
405                    Some(attn_mask) => self.get_image_features(
406                        &input_embeds.flatten(0, 1)?,
407                        Some(&attn_mask.ne(0.)?.flatten(0, 1)?),
408                    )?,
409                    None => self.get_image_features(input_embeds, None)?,
410                };
411
412                let base_feat_height_target = self.base_feat_height_target.unwrap();
413                let base_resolution = self.crop_size;
414                let base_feat_height_reduction = self.base_feat_height_reduction;
415
416                let base_feat_height = (img_features.dim(1)? as f64).sqrt() as usize;
417                let base_feat_width = base_feat_height;
418
419                assert_eq!(base_feat_height, base_feat_height_target);
420                assert_eq!(base_feat_width, base_feat_height_target);
421
422                let img_features = img_features.reshape((
423                    bs,
424                    (),
425                    base_feat_height * base_feat_width,
426                    self.image_dim_out,
427                ))?;
428                let C = self.image_dim_out;
429                let H = base_feat_height;
430
431                let mut output_imgs = Vec::new();
432                for bs_ in 0..bs {
433                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
434                    let h = h as usize / base_resolution;
435                    let w = w as usize / base_resolution;
436                    let B_ = h * w;
437
438                    // 1 x (24x24) x 1024
439                    let global_img_feature = img_features.i((bs_, ..1))?;
440
441                    // 1 x 12 x 12 x 4096
442                    let glb_img = global_img_feature
443                        .reshape((1, H, H, C))?
444                        .reshape((
445                            1,
446                            H / base_feat_height_reduction,
447                            base_feat_height_reduction,
448                            H / base_feat_height_reduction,
449                            base_feat_height_reduction,
450                            C,
451                        ))?
452                        .contiguous()?
453                        .permute((0, 1, 3, 2, 4, 5))?
454                        .reshape((
455                            1,
456                            H / base_feat_height_reduction,
457                            H / base_feat_height_reduction,
458                            base_feat_height_reduction * base_feat_height_reduction * C,
459                        ))?
460                        .contiguous()?;
461                    let temp_glbl_gn = self
462                        .sub_gn
463                        .as_ref()
464                        .expect("Need `sub_gn` if `use_hd_transform`")
465                        .repeat((1, H / base_feat_height_reduction, 1, 1))?;
466
467                    // 1 x 156 x 4096
468                    let glb_img = Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((
469                        1,
470                        (),
471                        base_feat_height_reduction * base_feat_height_reduction * C,
472                    ))?;
473
474                    // (max_num_crops-1) x (12x12) x C
475                    let mut sub_img = img_features.i((bs_, 1..))?;
476
477                    // 16x574x1024
478                    // Get rid of padding sub_img
479                    sub_img = sub_img.i(..B_)?;
480
481                    // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
482                    sub_img = sub_img
483                        .reshape((B_, H, H, C))?
484                        .reshape((
485                            B_,
486                            H / base_feat_height_reduction,
487                            base_feat_height_reduction,
488                            H / base_feat_height_reduction,
489                            base_feat_height_reduction,
490                            C,
491                        ))?
492                        .contiguous()?
493                        .permute((0, 1, 3, 2, 4, 5))?
494                        .reshape((
495                            B_,
496                            (),
497                            base_feat_height_reduction * base_feat_height_reduction * C,
498                        ))?
499                        .contiguous()?;
500                    sub_img = sub_img
501                        .reshape(BigShapeWithOneHole((
502                            1usize,
503                            h,
504                            w,
505                            base_feat_height / base_feat_height_reduction,
506                            base_feat_width / base_feat_height_reduction,
507                            (),
508                        )))?
509                        .permute((0, 1, 3, 2, 4, 5))?
510                        .reshape((
511                            1,
512                            h * base_feat_height / base_feat_height_reduction,
513                            w * base_feat_width / base_feat_height_reduction,
514                            base_feat_height_reduction * base_feat_height_reduction * C,
515                        ))?;
516
517                    let (temp_sub_GN, temp_len) = if let Some(image_attention_mask) =
518                        image_attention_mask
519                    {
520                        let h_indices = Tensor::arange_step(
521                            0,
522                            image_attention_mask.dim(2)? as u32,
523                            2,
524                            &target_dev,
525                        )?;
526                        let w_indices = Tensor::arange_step(
527                            0,
528                            image_attention_mask.dim(3)? as u32,
529                            2,
530                            &target_dev,
531                        )?;
532
533                        let reshaped_image_attention_mask = {
534                            let mut selected = image_attention_mask.i((bs_, 1..B_ + 1))?;
535                            selected = selected.index_select(&h_indices, 1)?;
536                            selected = selected.index_select(&w_indices, 2)?;
537                            selected
538                                .reshape((
539                                    1,
540                                    h,
541                                    w,
542                                    base_feat_height / base_feat_height_reduction,
543                                    base_feat_width / base_feat_height_reduction,
544                                ))?
545                                .permute((0, 1, 3, 2, 4))?
546                                .reshape((
547                                    1,
548                                    h * base_feat_height / base_feat_height_reduction,
549                                    w * base_feat_width / base_feat_height_reduction,
550                                ))?
551                        };
552
553                        let useful_height = reshaped_image_attention_mask
554                            .i((0, .., 0))?
555                            .sum_all()?
556                            .to_scalar::<u32>()?;
557                        let useful_width = reshaped_image_attention_mask
558                            .i((0, 0, ..))?
559                            .sum_all()?
560                            .to_scalar::<u32>()?;
561
562                        sub_img =
563                            sub_img.i((.., ..useful_height as usize, ..useful_width as usize))?;
564
565                        let temp_len = {
566                            let mut selected = image_attention_mask.i((bs_, ..B_ + 1))?;
567                            selected = selected.index_select(&h_indices, 1)?;
568                            selected = selected.index_select(&w_indices, 2)?;
569                            selected.sum_all()?.to_scalar::<u32>()?
570                        };
571                        let temp_len = temp_len as usize
572                            + useful_height as usize
573                            + 1
574                            + base_feat_height / base_feat_height_reduction;
575
576                        (
577                            self.sub_gn
578                                .as_ref()
579                                .expect("Need `sub_gn` if `use_hd_transform`")
580                                .repeat((1, useful_height as usize, 1, 1))?,
581                            temp_len,
582                        )
583                    } else {
584                        let temp_len = (h * w + 1) * self.num_img_tokens
585                            + 1
586                            + (h + 1) * base_feat_height / base_feat_height_reduction;
587
588                        (
589                            self.sub_gn
590                                .as_ref()
591                                .expect("Need `sub_gn` if `use_hd_transform`")
592                                .repeat((
593                                    1,
594                                    h * base_feat_height / base_feat_height_reduction,
595                                    1,
596                                    1,
597                                ))?,
598                            temp_len,
599                        )
600                    };
601
602                    let sub_img = Tensor::cat(&[sub_img, temp_sub_GN], 2)?.reshape((
603                        1,
604                        (),
605                        base_feat_height_reduction * base_feat_height_reduction * C,
606                    ))?;
607
608                    // (1, num_img_tokens, 1024*4)
609
610                    // glb + sub
611                    match self.hd_transform_order.as_str() {
612                        "glb_sub" => {
613                            output_imgs.push(Tensor::cat(
614                                &[
615                                    glb_img,
616                                    self.glb_gn
617                                        .as_ref()
618                                        .expect("Need `glb_gn` if `use_hd_transform`")
619                                        .clone(),
620                                    sub_img,
621                                ],
622                                1,
623                            )?);
624                        }
625                        "sub_glb" => {
626                            output_imgs.push(Tensor::cat(
627                                &[
628                                    sub_img,
629                                    self.glb_gn
630                                        .as_ref()
631                                        .expect("Need `glb_gn` if `use_hd_transform`")
632                                        .clone(),
633                                    glb_img,
634                                ],
635                                1,
636                            )?);
637                        }
638                        other => {
639                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
640                        }
641                    }
642
643                    // (h*w+1)*144 + 1 + (h+1)*12
644                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
645                }
646
647                let mut image_set_tensor_inner = Vec::new();
648                for img in output_imgs {
649                    let layerout = self
650                        .layers
651                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
652                    image_set_tensor_inner.push(layerout);
653                }
654                image_set_tensor = Some(image_set_tensor_inner);
655            } else {
656                unreachable!()
657            }
658        }
659
660        let mut hidden_states = self.wte.forward(&input_ids)?;
661        if select && self.use_hd_transform {
662            match image_set_tensor {
663                Some(image_set_tensors) => {
664                    let merged_img_set_tensor = Tensor::cat(&image_set_tensors, 1)?.squeeze(0)?;
665
666                    let original_shape = hidden_states.shape().clone();
667                    let (hs_b, hs_l, hs_d) = hidden_states.dims3()?;
668                    let mut hidden_states_flat = hidden_states.reshape(((), hs_d))?;
669
670                    // Get the equiv 0th and 1th rows of the positions_tuple
671                    let positions_transposed = positions.to_dtype(DType::F32)?;
672                    let positions_transposed_0 = positions_transposed.i((.., 0))?;
673                    let positions_transposed_1 = positions_transposed.i((.., 1))?;
674
675                    let mut linear_index = ((positions_transposed_0 * (hs_l * hs_b) as f64)?
676                        + positions_transposed_1)?;
677                    linear_index = linear_index.to_dtype(DType::U32)?;
678                    linear_index = linear_index.unsqueeze(1)?.repeat((1, hs_d))?;
679
680                    let current_vals = hidden_states_flat.gather(&linear_index, 0)?;
681                    let delta = merged_img_set_tensor.broadcast_sub(&current_vals)?;
682
683                    hidden_states_flat =
684                        hidden_states_flat.scatter_add(&linear_index, &delta, 0)?;
685
686                    hidden_states = hidden_states_flat.reshape(original_shape)?;
687                }
688                _ => unreachable!(),
689            }
690        }
691
692        Ok(hidden_states)
693    }
694
695    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
696        let uvb = UnVarBuilder::new();
697
698        if let Some(glb_gn) = self.glb_gn.clone() {
699            uvb.add_tensor("glb_GN", glb_gn);
700        }
701        if let Some(sub_gn) = self.sub_gn.clone() {
702            uvb.add_tensor("sub_GN", sub_gn);
703        }
704        uvb.extend(self.tensors.clone());
705        uvb.pp("img_processor.vision_model")
706            .extend(self.image_processor.residual_tensors());
707
708        uvb.to_safetensors()
709    }
710}