mistralrs_core/vision_models/minicpmo/
resampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::{Arc, Mutex};
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{LayerNorm, Linear};
7use mistralrs_quant::{MatMul, ShardedVarBuilder};
8
9use crate::{
10    layers::{self, layer_norm, GetFloatInfo},
11    layers_masker::masked_fill,
12    utils::unvarbuilder::UnVarBuilder,
13};
14
/// Default maximum (height, width) extent, in patches, of the cached 2D
/// sin/cos positional embedding; grown on demand in `Resampler::forward`.
const DEFAULT_MAX_SIZE: (usize, usize) = (70, 70);
16
17fn get_2d_sincos_pos_embed(
18    embed_dim: usize,
19    image_size: (usize, usize),
20    device: &Device,
21    dtype: DType,
22) -> Result<Tensor> {
23    let (grid_h_size, grid_w_size) = image_size;
24    let grid_h = Tensor::arange(0f32, grid_h_size as f32, device)?;
25    let grid_w = Tensor::arange(0f32, grid_w_size as f32, device)?;
26    // Original code uses np.meshgrid, xy is default
27    let grid = Tensor::meshgrid(&[grid_w, grid_h], true)?;
28    let grid = Tensor::stack(&grid, 0)?;
29
30    get_2d_sincos_pos_embed_from_grid(embed_dim, &grid)?.to_dtype(dtype)
31}
32
33fn get_2d_sincos_pos_embed_from_grid(embed_dim: usize, grid: &Tensor) -> Result<Tensor> {
34    assert_eq!(embed_dim % 2, 0);
35
36    let emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(0)?)?;
37    let emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(1)?)?;
38
39    Tensor::cat(&[emb_h, emb_w], D::Minus1)
40}
41
42fn get_1d_sincos_pos_embed_from_grid_new(embed_dim: usize, pos: &Tensor) -> Result<Tensor> {
43    let inv_freq: Vec<_> = (0..embed_dim)
44        .step_by(2)
45        .map(|i| 1f32 / 10_000f32.powf(i as f32 / embed_dim as f32))
46        .collect();
47    let inv_freq_len = inv_freq.len();
48    let omega = Tensor::from_vec(inv_freq, (1, inv_freq_len), pos.device())?;
49
50    let (h, w) = pos.dims2()?;
51
52    let mut out = pos
53        .reshape(((), 1))?
54        .matmul(&omega.reshape((1, ()))?)
55        .unwrap();
56
57    out = out.reshape((h, w, ()))?;
58
59    let emb_sin = out.sin()?;
60    let emb_cos = out.cos()?;
61
62    Tensor::cat(&[emb_sin, emb_cos], D::Minus1)
63}
64
/// Lazily grown cache of the 2D sin/cos positional embedding.
struct SinCos2dPosEmbed {
    // Embedding covering the full cached extent; per-image slices are taken
    // from the top-left corner in `Resampler::forward`.
    pos_embed: Tensor,
    // Current (height, width) extent of `pos_embed`, in patches.
    max_size: (usize, usize),
}
69
/// Perceiver-style resampler: cross-attends a fixed set of learned queries
/// over variable-length vision features to produce fixed-length output.
pub struct Resampler {
    // Learned query vectors, shape (num_queries, embed_dim).
    query: Tensor,
    // Projects vision features to embed_dim; None when kv_dim == embed_dim.
    kv_proj: Option<Linear>,
    // Final output projection matrix, shape (embed_dim, embed_dim).
    proj: Tensor,
    ln_q: LayerNorm,
    ln_kv: LayerNorm,
    ln_post: LayerNorm,
    attn: MultiheadAttention,
    // Cached positional embedding, regenerated when an image exceeds its extent;
    // Mutex because `forward` takes `&self` but may mutate the cache.
    sincos_pos_embed: Arc<Mutex<SinCos2dPosEmbed>>,
    embed_dim: usize,
}
81
82impl Resampler {
83    pub fn new(
84        num_queries: usize,
85        embed_dim: usize,
86        num_heads: usize,
87        kv_dim: usize,
88        _adaptive: bool,
89        max_size: Option<(usize, usize)>,
90        vb: ShardedVarBuilder,
91    ) -> Result<Self> {
92        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
93
94        let query = vb.get((num_queries, embed_dim), "query")?;
95        let kv_proj = if kv_dim != embed_dim {
96            Some(layers::linear_no_bias(kv_dim, embed_dim, vb.pp("kv_proj"))?)
97        } else {
98            None
99        };
100        let ln_q = layer_norm(embed_dim, 1e-6, vb.pp("ln_q"))?;
101        let ln_kv = layer_norm(embed_dim, 1e-6, vb.pp("ln_kv"))?;
102        let ln_post = layer_norm(embed_dim, 1e-6, vb.pp("ln_post"))?;
103        let proj = vb.get((embed_dim, embed_dim), "proj")?;
104        let attn = MultiheadAttention::new(embed_dim, num_heads, vb.pp("attn"))?;
105
106        let pos_embed = Arc::new(Mutex::new(SinCos2dPosEmbed {
107            pos_embed: get_2d_sincos_pos_embed(embed_dim, max_size, vb.device(), vb.dtype())?,
108            max_size,
109        }));
110
111        Ok(Self {
112            query,
113            kv_proj,
114            proj,
115            ln_q,
116            ln_kv,
117            ln_post,
118            attn,
119            sincos_pos_embed: pos_embed,
120            embed_dim,
121        })
122    }
123
124    pub fn forward(&self, x: &Tensor, tgt_sizes_vec: &[Vec<u32>]) -> Result<Tensor> {
125        let mut pos_embed_cache = self.sincos_pos_embed.lock().unwrap();
126
127        let bs = x.dim(0)?;
128        let device = x.device();
129
130        assert_eq!(bs, tgt_sizes_vec.len());
131
132        let tgt_sizes_vec_0 = tgt_sizes_vec.iter().map(|x| x[0]).collect::<Vec<_>>();
133        let tgt_sizes_vec_1 = tgt_sizes_vec.iter().map(|x| x[1]).collect::<Vec<_>>();
134        let patch_len = tgt_sizes_vec_0
135            .iter()
136            .zip(&tgt_sizes_vec_1)
137            .map(|(x, y)| x * y)
138            .collect::<Vec<_>>();
139
140        // Adjust/recompute pos embeds
141        {
142            let max_h = *tgt_sizes_vec_0.iter().max().unwrap() as usize;
143            let max_w = *tgt_sizes_vec_1.iter().max().unwrap() as usize;
144
145            if max_h > pos_embed_cache.max_size.0 || max_w > pos_embed_cache.max_size.1 {
146                pos_embed_cache.max_size = (
147                    max_h.max(pos_embed_cache.max_size.0),
148                    max_w.max(pos_embed_cache.max_size.1),
149                );
150                pos_embed_cache.pos_embed = get_2d_sincos_pos_embed(
151                    self.embed_dim,
152                    pos_embed_cache.max_size,
153                    device,
154                    x.dtype(),
155                )?;
156            }
157        }
158
159        let max_patch_len = *patch_len.iter().max().unwrap() as usize;
160
161        let mut key_padding_mask = Tensor::zeros((bs, max_patch_len), DType::U8, device)?;
162
163        let mut pos_embed = Vec::new();
164        for (i, tgt_sizes_vec_i) in tgt_sizes_vec.iter().enumerate().take(bs) {
165            let (tgt_h, tgt_w) = (tgt_sizes_vec_i[0] as usize, tgt_sizes_vec_i[1] as usize);
166            pos_embed.push(
167                pos_embed_cache
168                    .pos_embed
169                    .i((..tgt_h, ..tgt_w, ..))?
170                    .reshape((tgt_h * tgt_w, ()))?,
171            );
172
173            let n = patch_len[i] as usize;
174            if n != max_patch_len {
175                key_padding_mask = key_padding_mask.slice_assign(
176                    &[&i, &(n..)],
177                    &Tensor::ones((1, max_patch_len - n), DType::U8, device)?,
178                )?;
179            }
180        }
181
182        let lens = pos_embed
183            .iter()
184            .map(|emb| emb.dim(0))
185            .collect::<Result<Vec<_>>>()?;
186        let max_len = lens.into_iter().max().expect("No pixe values somehow?");
187        pos_embed = pos_embed
188            .into_iter()
189            .map(|emb| emb.pad_with_zeros(0, 0, max_len - emb.dim(0)?))
190            .collect::<Result<Vec<_>>>()?;
191        let pos_embed = Tensor::stack(&pos_embed, 0)?;
192
193        let mut x = if let Some(kv_proj) = &self.kv_proj {
194            x.apply(kv_proj)?
195        } else {
196            x.clone()
197        };
198        x = x.apply(&self.ln_kv)?;
199
200        let q = self.query.apply(&self.ln_q)?;
201
202        let mut out = self.attn.forward(
203            &self.repeat_q_bs(&q, bs)?,
204            &(&x + &pos_embed)?,
205            &x,
206            Some(key_padding_mask),
207            None,
208        )?;
209
210        out = out.apply(&self.ln_post)?;
211        out.broadcast_matmul(&self.proj)
212    }
213
214    fn repeat_q_bs(&self, q: &Tensor, n: usize) -> Result<Tensor> {
215        q.unsqueeze(0)?.repeat((n, 1, 1))
216    }
217
218    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
219        let uvb = UnVarBuilder::new();
220
221        let uvb_attn = uvb.pp("attn");
222        uvb_attn.pp("out_proj").add(&self.attn.out_proj);
223        uvb_attn.add_tensor("in_proj_weight", self.attn.in_proj_weight.clone());
224        uvb_attn.add_tensor("in_proj_bias", self.attn.in_proj_bias.clone());
225
226        uvb.pp("ln_kv").add(&self.ln_kv);
227        uvb.pp("ln_post").add(&self.ln_post);
228        uvb.pp("ln_q").add(&self.ln_q);
229        uvb.add_tensor("proj", self.proj.clone());
230        uvb.add_tensor("query", self.query.clone());
231
232        uvb.to_safetensors()
233    }
234}
235
/// Multi-head attention matching the PyTorch `nn.MultiheadAttention` weight
/// layout: one packed `in_proj_weight`/`in_proj_bias` split into Q/K/V at load.
struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    // embed_dim / num_heads
    head_dim: usize,
    // Packed originals kept alongside the split projections so
    // `Resampler::residual_tensors` can re-export the checkpoint layout.
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
}
246
impl MultiheadAttention {
    /// Split the packed PyTorch-style `in_proj_weight`/`in_proj_bias`
    /// (rows ordered Q, K, V) into three `Linear` projections, keeping the
    /// packed tensors around for weight re-export.
    fn new(embed_dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let in_proj_bias = vb.get(embed_dim * 3, "in_proj_bias")?;
        let in_proj_weight = vb.get((embed_dim * 3, embed_dim), "in_proj_weight")?;
        // Rows [0, embed_dim): query projection.
        let q_proj = Linear::new(
            in_proj_weight.i(0..embed_dim)?,
            Some(in_proj_bias.i(0..embed_dim)?),
        );
        // Rows [embed_dim, 2*embed_dim): key projection.
        let k_proj = Linear::new(
            in_proj_weight.i(embed_dim..embed_dim * 2)?,
            Some(in_proj_bias.i(embed_dim..embed_dim * 2)?),
        );
        // Rows [2*embed_dim, 3*embed_dim): value projection.
        let v_proj = Linear::new(
            in_proj_weight.i(embed_dim * 2..embed_dim * 3)?,
            Some(in_proj_bias.i(embed_dim * 2..embed_dim * 3)?),
        );
        let out_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim: embed_dim / num_heads,
            in_proj_weight,
            in_proj_bias,
        })
    }

    /// Scaled dot-product attention over `(bs, seq, embed_dim)` inputs.
    ///
    /// * `key_padding_mask` - `(bs, kv_seq)`, nonzero marks padded keys.
    /// * `attn_mask` - optional extra mask, merged with the padding mask via
    ///   `broadcast_add`. NOTE(review): the merge assumes both are 0/1 masks of
    ///   the same dtype — confirm against callers; the only caller in this
    ///   file passes `None` here.
    fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        key_padding_mask: Option<Tensor>,
        mut attn_mask: Option<Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_seq, _) = q.dims3()?;
        let (_, kv_seq, _) = k.dims3()?;

        let mut q = q.apply(&self.q_proj)?;
        let mut k = k.apply(&self.k_proj)?;
        let mut v = v.apply(&self.v_proj)?;

        // Merge key padding and attention masks into a single
        // (bs * num_heads, 1, kv_seq) mask, broadcast over query positions.
        if let Some(mut key_padding_mask) = key_padding_mask {
            key_padding_mask = key_padding_mask
                .reshape((bs, 1, 1, kv_seq))?
                .repeat((1, self.num_heads, 1, 1))?
                .reshape((bs * self.num_heads, 1, kv_seq))?;
            if let Some(attn_mask) = attn_mask.as_mut() {
                *attn_mask = attn_mask.broadcast_add(&key_padding_mask)?;
            } else {
                attn_mask = Some(key_padding_mask);
            }
        }

        // (bs, seq, embed_dim) -> (bs, num_heads, seq, head_dim).
        q = q
            .reshape((bs, q_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((bs, kv_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((bs, kv_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let mut y = {
            // Attention scores scaled by 1/sqrt(head_dim):
            // sqrt(1 / head_dim) == 1 / sqrt(head_dim).
            let mut att =
                MatMul.matmul_affine_mul(&q, &k.t()?, (1. / self.head_dim as f64).sqrt())?;

            // Nonzero mask positions are filled with the dtype's most negative
            // finite value so they vanish after softmax.
            att = match attn_mask {
                Some(mask) => {
                    let mask = mask.reshape((bs, self.num_heads, (), kv_seq))?;
                    masked_fill(&att, &mask, att.dtype().finfo()?.min as f32)?
                }
                None => att,
            };
            candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
            MatMul.matmul(&att, &v)?
        };

        // (bs, num_heads, q_seq, head_dim) -> (bs, q_seq, embed_dim).
        y = y.transpose(1, 2)?.reshape((bs, q_seq, ()))?;
        y.apply(&self.out_proj)
    }
}
335}