1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::{Arc, Mutex};
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{LayerNorm, Linear};
7use mistralrs_quant::{MatMul, ShardedVarBuilder};
8
9use crate::{
10 layers::{self, layer_norm, GetFloatInfo},
11 layers_masker::masked_fill,
12 utils::unvarbuilder::UnVarBuilder,
13};
14
/// Default (height, width) patch-grid extent covered by the cached 2D
/// sin-cos positional embedding; the cache grows past this on demand.
const DEFAULT_MAX_SIZE: (usize, usize) = (70, 70);
16
17fn get_2d_sincos_pos_embed(
18 embed_dim: usize,
19 image_size: (usize, usize),
20 device: &Device,
21 dtype: DType,
22) -> Result<Tensor> {
23 let (grid_h_size, grid_w_size) = image_size;
24 let grid_h = Tensor::arange(0f32, grid_h_size as f32, device)?;
25 let grid_w = Tensor::arange(0f32, grid_w_size as f32, device)?;
26 let grid = Tensor::meshgrid(&[grid_w, grid_h], true)?;
28 let grid = Tensor::stack(&grid, 0)?;
29
30 get_2d_sincos_pos_embed_from_grid(embed_dim, &grid)?.to_dtype(dtype)
31}
32
33fn get_2d_sincos_pos_embed_from_grid(embed_dim: usize, grid: &Tensor) -> Result<Tensor> {
34 assert_eq!(embed_dim % 2, 0);
35
36 let emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(0)?)?;
37 let emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(1)?)?;
38
39 Tensor::cat(&[emb_h, emb_w], D::Minus1)
40}
41
42fn get_1d_sincos_pos_embed_from_grid_new(embed_dim: usize, pos: &Tensor) -> Result<Tensor> {
43 let inv_freq: Vec<_> = (0..embed_dim)
44 .step_by(2)
45 .map(|i| 1f32 / 10_000f32.powf(i as f32 / embed_dim as f32))
46 .collect();
47 let inv_freq_len = inv_freq.len();
48 let omega = Tensor::from_vec(inv_freq, (1, inv_freq_len), pos.device())?;
49
50 let (h, w) = pos.dims2()?;
51
52 let mut out = pos
53 .reshape(((), 1))?
54 .matmul(&omega.reshape((1, ()))?)
55 .unwrap();
56
57 out = out.reshape((h, w, ()))?;
58
59 let emb_sin = out.sin()?;
60 let emb_cos = out.cos()?;
61
62 Tensor::cat(&[emb_sin, emb_cos], D::Minus1)
63}
64
/// Cached 2D sin-cos positional embedding, regenerated lazily whenever a
/// batch needs a larger grid than the cache covers (see `Resampler::forward`).
struct SinCos2dPosEmbed {
    // Precomputed embedding spanning the full `max_size` grid.
    pos_embed: Tensor,
    // (height, width) grid extent that `pos_embed` was built for.
    max_size: (usize, usize),
}
69
/// Perceiver-style resampler: cross-attends a fixed set of learned queries
/// over variable-length image patch features, producing a fixed-length
/// output sequence per image.
pub struct Resampler {
    // Learned query embeddings, shape (num_queries, embed_dim).
    query: Tensor,
    // Projects key/value features to `embed_dim`; only present when the
    // incoming kv width differs from `embed_dim` (see `Resampler::new`).
    kv_proj: Option<Linear>,
    // Final output projection matrix, shape (embed_dim, embed_dim).
    proj: Tensor,
    ln_q: LayerNorm,
    ln_kv: LayerNorm,
    ln_post: LayerNorm,
    attn: MultiheadAttention,
    // Shared, lazily-grown positional-embedding cache (mutated in `forward`).
    sincos_pos_embed: Arc<Mutex<SinCos2dPosEmbed>>,
    embed_dim: usize,
}
81
82impl Resampler {
83 pub fn new(
84 num_queries: usize,
85 embed_dim: usize,
86 num_heads: usize,
87 kv_dim: usize,
88 _adaptive: bool,
89 max_size: Option<(usize, usize)>,
90 vb: ShardedVarBuilder,
91 ) -> Result<Self> {
92 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
93
94 let query = vb.get((num_queries, embed_dim), "query")?;
95 let kv_proj = if kv_dim != embed_dim {
96 Some(layers::linear_no_bias(kv_dim, embed_dim, vb.pp("kv_proj"))?)
97 } else {
98 None
99 };
100 let ln_q = layer_norm(embed_dim, 1e-6, vb.pp("ln_q"))?;
101 let ln_kv = layer_norm(embed_dim, 1e-6, vb.pp("ln_kv"))?;
102 let ln_post = layer_norm(embed_dim, 1e-6, vb.pp("ln_post"))?;
103 let proj = vb.get((embed_dim, embed_dim), "proj")?;
104 let attn = MultiheadAttention::new(embed_dim, num_heads, vb.pp("attn"))?;
105
106 let pos_embed = Arc::new(Mutex::new(SinCos2dPosEmbed {
107 pos_embed: get_2d_sincos_pos_embed(embed_dim, max_size, vb.device(), vb.dtype())?,
108 max_size,
109 }));
110
111 Ok(Self {
112 query,
113 kv_proj,
114 proj,
115 ln_q,
116 ln_kv,
117 ln_post,
118 attn,
119 sincos_pos_embed: pos_embed,
120 embed_dim,
121 })
122 }
123
124 pub fn forward(&self, x: &Tensor, tgt_sizes_vec: &[Vec<u32>]) -> Result<Tensor> {
125 let mut pos_embed_cache = self.sincos_pos_embed.lock().unwrap();
126
127 let bs = x.dim(0)?;
128 let device = x.device();
129
130 assert_eq!(bs, tgt_sizes_vec.len());
131
132 let tgt_sizes_vec_0 = tgt_sizes_vec.iter().map(|x| x[0]).collect::<Vec<_>>();
133 let tgt_sizes_vec_1 = tgt_sizes_vec.iter().map(|x| x[1]).collect::<Vec<_>>();
134 let patch_len = tgt_sizes_vec_0
135 .iter()
136 .zip(&tgt_sizes_vec_1)
137 .map(|(x, y)| x * y)
138 .collect::<Vec<_>>();
139
140 {
142 let max_h = *tgt_sizes_vec_0.iter().max().unwrap() as usize;
143 let max_w = *tgt_sizes_vec_1.iter().max().unwrap() as usize;
144
145 if max_h > pos_embed_cache.max_size.0 || max_w > pos_embed_cache.max_size.1 {
146 pos_embed_cache.max_size = (
147 max_h.max(pos_embed_cache.max_size.0),
148 max_w.max(pos_embed_cache.max_size.1),
149 );
150 pos_embed_cache.pos_embed = get_2d_sincos_pos_embed(
151 self.embed_dim,
152 pos_embed_cache.max_size,
153 device,
154 x.dtype(),
155 )?;
156 }
157 }
158
159 let max_patch_len = *patch_len.iter().max().unwrap() as usize;
160
161 let mut key_padding_mask = Tensor::zeros((bs, max_patch_len), DType::U8, device)?;
162
163 let mut pos_embed = Vec::new();
164 for (i, tgt_sizes_vec_i) in tgt_sizes_vec.iter().enumerate().take(bs) {
165 let (tgt_h, tgt_w) = (tgt_sizes_vec_i[0] as usize, tgt_sizes_vec_i[1] as usize);
166 pos_embed.push(
167 pos_embed_cache
168 .pos_embed
169 .i((..tgt_h, ..tgt_w, ..))?
170 .reshape((tgt_h * tgt_w, ()))?,
171 );
172
173 let n = patch_len[i] as usize;
174 if n != max_patch_len {
175 key_padding_mask = key_padding_mask.slice_assign(
176 &[&i, &(n..)],
177 &Tensor::ones((1, max_patch_len - n), DType::U8, device)?,
178 )?;
179 }
180 }
181
182 let lens = pos_embed
183 .iter()
184 .map(|emb| emb.dim(0))
185 .collect::<Result<Vec<_>>>()?;
186 let max_len = lens.into_iter().max().expect("No pixe values somehow?");
187 pos_embed = pos_embed
188 .into_iter()
189 .map(|emb| emb.pad_with_zeros(0, 0, max_len - emb.dim(0)?))
190 .collect::<Result<Vec<_>>>()?;
191 let pos_embed = Tensor::stack(&pos_embed, 0)?;
192
193 let mut x = if let Some(kv_proj) = &self.kv_proj {
194 x.apply(kv_proj)?
195 } else {
196 x.clone()
197 };
198 x = x.apply(&self.ln_kv)?;
199
200 let q = self.query.apply(&self.ln_q)?;
201
202 let mut out = self.attn.forward(
203 &self.repeat_q_bs(&q, bs)?,
204 &(&x + &pos_embed)?,
205 &x,
206 Some(key_padding_mask),
207 None,
208 )?;
209
210 out = out.apply(&self.ln_post)?;
211 out.broadcast_matmul(&self.proj)
212 }
213
214 fn repeat_q_bs(&self, q: &Tensor, n: usize) -> Result<Tensor> {
215 q.unsqueeze(0)?.repeat((n, 1, 1))
216 }
217
218 pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
219 let uvb = UnVarBuilder::new();
220
221 let uvb_attn = uvb.pp("attn");
222 uvb_attn.pp("out_proj").add(&self.attn.out_proj);
223 uvb_attn.add_tensor("in_proj_weight", self.attn.in_proj_weight.clone());
224 uvb_attn.add_tensor("in_proj_bias", self.attn.in_proj_bias.clone());
225
226 uvb.pp("ln_kv").add(&self.ln_kv);
227 uvb.pp("ln_post").add(&self.ln_post);
228 uvb.pp("ln_q").add(&self.ln_q);
229 uvb.add_tensor("proj", self.proj.clone());
230 uvb.add_tensor("query", self.query.clone());
231
232 uvb.to_safetensors()
233 }
234}
235
// Multi-head attention parameters. The packed `in_proj_weight` /
// `in_proj_bias` tensors are kept alongside the sliced per-projection
// `Linear`s so they can be re-serialized under their original names
// (see `Resampler::residual_tensors`).
struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    // embed_dim / num_heads (integer division; see `MultiheadAttention::new`).
    head_dim: usize,
    // Packed (3 * embed_dim, embed_dim) weight the projections are views of.
    in_proj_weight: Tensor,
    // Packed (3 * embed_dim,) bias matching `in_proj_weight`'s row layout.
    in_proj_bias: Tensor,
}
246
// Multi-head attention whose Q/K/V projections are sliced out of a single
// packed in-projection, matching the PyTorch `nn.MultiheadAttention`
// checkpoint layout.
impl MultiheadAttention {
    // Loads the packed in-projection and splits it into Q/K/V linears:
    // rows [0, d) are Q, [d, 2d) are K, [2d, 3d) are V.
    fn new(embed_dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let in_proj_bias = vb.get(embed_dim * 3, "in_proj_bias")?;
        let in_proj_weight = vb.get((embed_dim * 3, embed_dim), "in_proj_weight")?;
        let q_proj = Linear::new(
            in_proj_weight.i(0..embed_dim)?,
            Some(in_proj_bias.i(0..embed_dim)?),
        );
        let k_proj = Linear::new(
            in_proj_weight.i(embed_dim..embed_dim * 2)?,
            Some(in_proj_bias.i(embed_dim..embed_dim * 2)?),
        );
        let v_proj = Linear::new(
            in_proj_weight.i(embed_dim * 2..embed_dim * 3)?,
            Some(in_proj_bias.i(embed_dim * 2..embed_dim * 3)?),
        );
        let out_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim: embed_dim / num_heads,
            in_proj_weight,
            in_proj_bias,
        })
    }

    // Scaled dot-product attention of `q` (bs, q_seq, d) over `k`/`v`
    // (bs, kv_seq, d).
    //
    // `key_padding_mask` has shape (bs, kv_seq); nonzero entries mark key
    // positions to exclude. It is expanded per head and merged into
    // `attn_mask`, and the combined mask is applied via `masked_fill` with
    // the dtype's minimum before the softmax.
    fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        key_padding_mask: Option<Tensor>,
        mut attn_mask: Option<Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_seq, _) = q.dims3()?;
        let (_, kv_seq, _) = k.dims3()?;

        let mut q = q.apply(&self.q_proj)?;
        let mut k = k.apply(&self.k_proj)?;
        let mut v = v.apply(&self.v_proj)?;

        if let Some(mut key_padding_mask) = key_padding_mask {
            // Broadcast the per-batch padding mask across heads:
            // (bs, kv_seq) -> (bs * num_heads, 1, kv_seq).
            key_padding_mask = key_padding_mask
                .reshape((bs, 1, 1, kv_seq))?
                .repeat((1, self.num_heads, 1, 1))?
                .reshape((bs * self.num_heads, 1, kv_seq))?;
            if let Some(attn_mask) = attn_mask.as_mut() {
                // Merge with any caller-provided mask; nonzero stays masked.
                *attn_mask = attn_mask.broadcast_add(&key_padding_mask)?;
            } else {
                attn_mask = Some(key_padding_mask);
            }
        }

        // Split the hidden dimension into heads: (bs, heads, seq, head_dim).
        q = q
            .reshape((bs, q_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((bs, kv_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((bs, kv_seq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let mut y = {
            // Attention scores scaled by 1 / sqrt(head_dim) (note:
            // (1 / d).sqrt() == 1 / sqrt(d)).
            let mut att =
                MatMul.matmul_affine_mul(&q, &k.t()?, (1. / self.head_dim as f64).sqrt())?;

            att = match attn_mask {
                Some(mask) => {
                    // Re-expand to (bs, heads, q_seq, kv_seq) and push masked
                    // scores to the dtype minimum so softmax zeroes them.
                    let mask = mask.reshape((bs, self.num_heads, (), kv_seq))?;
                    masked_fill(&att, &mask, att.dtype().finfo()?.min as f32)?
                }
                None => att,
            };
            candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
            MatMul.matmul(&att, &v)?
        };

        // Merge heads back: (bs, q_seq, embed_dim), then project out.
        y = y.transpose(1, 2)?.reshape((bs, q_seq, ()))?;
        y.apply(&self.out_proj)
    }
}