#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType, D};

use crate::pipeline::KvCache;

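/// Builds causal attention masks, including chunked and sliding-window variants.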
pub struct CausalMasker;

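/// Fill the elements of `xs` with `value` wherever `mask` is nonzero, broadcasting
/// `mask` to the shape of `xs`.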
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    let on_false = xs;
    let res = mask
        .broadcast_as(xs.shape())?
        .where_cond(&on_true, on_false)?;
    Ok(res)
}

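/// Placeholder cache for callers without KV state; always reports a past length of 0.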
pub struct NotACache;

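/// A source for the number of tokens already present in the KV cache.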
pub trait PastKvLenCache {
    fn get_past_kv_len(&self) -> Result<usize>;
}

impl PastKvLenCache for NotACache {
    fn get_past_kv_len(&self) -> Result<usize> {
        Ok(0)
    }
}

impl PastKvLenCache for Vec<KvCache> {
    fn get_past_kv_len(&self) -> Result<usize> {
        // Every layer advances in lockstep, so layer 0's length is the past KV length.
        Ok(self[0].current_seq_len())
    }
}

impl PastKvLenCache for &[usize] {
    fn get_past_kv_len(&self) -> Result<usize> {
        // Only meaningful when all sequences share the same past KV length;
        // otherwise (or if the slice is empty) fall back to 0.
        if !self.is_empty() && self.windows(2).all(|w| w[0] == w[1]) {
            Ok(self[0])
        } else {
            Ok(0)
        }
    }
}

impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    fn get_past_kv_len(&self) -> Result<usize> {
        // The K cache is (bs, n_heads, seq_len, head_dim); the past length is dim 2.
        match &self[0] {
            Some((k_cache, _)) => Ok(k_cache.dims()[2]),
            None => Ok(0),
        }
    }
}

impl CausalMasker {
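    /// Build a (tgt_len, tgt_len + past_kv_len) causal mask: entry (i, j) is 1
    /// when key j lies in the future of query i (i.e. j > i + past_kv_len).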
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        // j + tgt_len > i + offset simplifies to j > i + past_kv_len: mask strictly future keys.
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

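    /// Build a chunked causal mask: a query attends causally, and only to keys in
    /// its own `chunk_size`-token chunk; cached positions are always visible.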
    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Positions already in the KV cache are always visible.
                    if j < past_kv_len {
                        return 0;
                    }

                    // Shift j into the coordinates of the new tokens.
                    let j_adj = j - past_kv_len;

                    // Chunk indices of the query and key.
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();

                    // Relative position of the key with respect to the query.
                    let token_pos = j_adj as isize - i as isize;

                    // Attend (0) only within the same chunk and causally; mask (1) otherwise.
                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();

        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

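    /// Build a sliding-window causal mask of shape (tgt_len, tgt_len + seqlen_offset):
    /// query i sees key j only if j <= i and i - j <= sliding_window, with the
    /// `seqlen_offset` cached columns left fully visible.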
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        // Prepend zero columns for the cached positions, which remain visible.
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.to_dtype(dtype)
    }

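    /// Expand a (bs, src_len) padding mask (1 = keep, 0 = mask) to
    /// (bs, 1, tgt_len, src_len), mapping kept positions to 0 and masked
    /// positions to `f32::MIN`.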
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;

        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;

        // Invert (1 -> 0, 0 -> 1) and fill the now-nonzero positions with f32::MIN.
        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }

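    /// Read the past KV length from a legacy `(k, v)`-tuple cache; mirrors the
    /// `PastKvLenCache` impl for `Vec<Option<(Tensor, Tensor)>>`.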
    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        match &cache[0] {
            Some((k_cache, _)) => Ok(k_cache.dims()[2]),
            None => Ok(0),
        }
    }

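    /// Build the additive causal mask for `input_ids`, returning `None` when
    /// `tgt_len == 1` (single-token decode needs no mask).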
    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            // Flash attention applies causal masking in the kernel; return a 1x1 placeholder.
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        let causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;

        // Convert the 0/1 mask to additive form: 0.0 where visible, -inf where masked.
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        let causal_mask = masked_fill(
            &zero.to_dtype(dtype)?.broadcast_as(causal_mask.shape())?,
            &causal_mask,
            f32::NEG_INFINITY,
        )?;

        Ok(Some(causal_mask))
    }

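    /// Like [`Self::make_causal_mask_matrix`], but attention is additionally
    /// restricted to chunks of `chunk_size` tokens.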
    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;

        // Convert the 0/1 mask to additive form: 0.0 where visible, -inf where masked.
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        let causal_mask = masked_fill(
            &zero.to_dtype(dtype)?.broadcast_as(causal_mask.shape())?,
            &causal_mask,
            f32::NEG_INFINITY,
        )?;

        Ok(Some(causal_mask))
    }

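    /// Build an additive sliding-window causal mask, delegating to
    /// [`Self::make_causal_mask_matrix`] when no sliding window is configured.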
    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let Some(sliding_window) = sliding_window else {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        };
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        if tgt_len > 1 && crate::using_flash_attn() && input_ids.device().is_cuda() {
            // Flash attention applies the window in the kernel; return a 1x1 placeholder.
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        // Only positions inside the window can be attended to, so clamp the
        // effective past KV length accordingly.
        let past_kv_len = cache
            .get_past_kv_len()?
            .min(sliding_window.saturating_sub(tgt_len));
        if tgt_len == 1 {
            return Ok(None);
        }

        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }

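    /// Apply a 0/1 mask to attention scores: wherever the mask is nonzero, the
    /// score is replaced by `neg_inf`.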
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}

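/// Builds bidirectional (non-causal) attention masks, optionally with a sliding window.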
pub struct BidirectionalMasker;

impl BidirectionalMasker {
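    /// Build a symmetric sliding-window mask: positions `sliding_window` or more
    /// apart (in either direction) are masked.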
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if (i as isize - j as isize).unsigned_abs() >= sliding_window {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        mask.to_dtype(dtype)
    }

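    /// Build an all-zeros mask: every position may attend to every other.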
    pub fn make_mask(&self, input_ids: &Tensor, dtype: DType) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            // With flash attention on CUDA, no full mask is materialized; return a 1x1 placeholder.
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }

        // Bidirectional attention adds nothing to the scores, so the additive mask is all zeros.
        let mask = Tensor::zeros((tgt_len, tgt_len), dtype, input_ids.device())?;

        Ok(mask)
    }

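    /// Build a bidirectional sliding-window mask for `input_ids`.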
    pub fn make_sliding_mask(
        &self,
        input_ids: &Tensor,
        dtype: DType,
        sliding_window: usize,
    ) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            // With flash attention on CUDA, no full mask is materialized; return a 1x1 placeholder.
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }

        let mask = self.make_swa_mask(tgt_len, sliding_window, input_ids.device(), dtype)?;

        Ok(mask)
    }
}