#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType, D};

use crate::pipeline::KvCache;

pub struct CausalMasker;

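/// Elementwise select: wherever `mask` is nonzero, take `value`; elsewhere keep
/// the corresponding element of `xs`. `mask` is broadcast to the shape of `xs`.
///
/// A minimal usage sketch (illustrative only, assuming a CPU device):
///
/// ```ignore
/// let xs = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
/// let mask = Tensor::new(&[[1u8, 0u8], [0u8, 1u8]], &Device::Cpu)?;
/// // The two positions where `mask` is 1 become NEG_INFINITY; the rest stay 0.
/// let filled = masked_fill(&xs, &mask, f32::NEG_INFINITY)?;
/// ```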
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    let on_false = xs;
    let res = mask
        .broadcast_as(xs.shape())?
        .where_cond(&on_true, on_false)?;
    Ok(res)
}

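/// Abstraction over the supported KV cache layouts that exposes just the
/// number of tokens already cached, so masks can be offset accordingly.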
pub trait PastKvLenCache {
    fn get_past_kv_len(&self) -> Result<usize>;
}

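// Assumes every layer's cache holds the same number of tokens, so layer 0 is
// representative.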
impl PastKvLenCache for Vec<KvCache> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        Ok(kv_cache_1.current_seq_len())
    }
}

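// Per-sequence start offsets passed directly. If they all agree, the shared
// offset is the cached length; otherwise fall back to 0 so the full mask is
// rebuilt.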
impl PastKvLenCache for &[usize] {
    fn get_past_kv_len(&self) -> Result<usize> {
        if self.windows(2).all(|w| w[0] == w[1]) {
            Ok(self[0])
        } else {
            Ok(0)
        }
    }
}

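// Legacy (K, V) tuple cache. The cached length is read from dim 2 of the K
// tensor, which is assumed to be laid out as (batch, n_heads, seq_len, head_dim).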
impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    fn get_past_kv_len(&self) -> Result<usize> {
        match &self[0] {
            Some((k_cache, _)) => Ok(k_cache.dims()[2]),
            None => Ok(0),
        }
    }
}

impl CausalMasker {
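    /// Build the standard causal mask of shape `(tgt_len, tgt_len + past_kv_len)`,
    /// with 1 marking a masked position. Since `offset = tgt_len + past_kv_len`,
    /// the predicate `j + tgt_len > i + offset` reduces to `j > i + past_kv_len`:
    /// each new token sees the whole cache plus itself and earlier new tokens.
    ///
    /// For example, `tgt_len = 2` with `past_kv_len = 1` produces:
    ///
    /// ```text
    /// 0 0 1
    /// 0 0 0
    /// ```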
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

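    /// Build a chunked-attention mask: among the new tokens, position `i` may
    /// attend to position `j` only if both fall in the same `chunk_size` block
    /// and `j <= i`; everything already in the KV cache stays visible.
    ///
    /// Illustrative pattern for `tgt_len = 4`, `past_kv_len = 0`,
    /// `chunk_size = 2` (1 = masked):
    ///
    /// ```text
    /// 0 1 1 1
    /// 0 0 1 1
    /// 1 1 0 1
    /// 1 1 0 0
    /// ```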
    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Everything already in the KV cache is visible.
                    if j < past_kv_len {
                        return 0;
                    }

                    let j_adj = j - past_kv_len;

                    // Attend only within the same chunk, and only causally.
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();

                    let token_pos = j_adj as isize - i as isize;

                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();

        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

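    /// Build a sliding-window (SWA) causal mask directly in additive form:
    /// `NEG_INFINITY` for positions in the future (`i < j`) or further back
    /// than `sliding_window` tokens (`j + sliding_window < i`), `0` elsewhere.
    /// The first `seqlen_offset` cached columns are always visible.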
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    // Mask out future positions and positions that have slid
                    // out of the window.
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        let mask = if seqlen_offset > 0 {
            // Cached positions are always visible: prepend zero columns.
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.to_dtype(dtype)
    }

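    /// Expand a `(batch, src_len)` padding mask (1 = keep, 0 = pad) into the
    /// `(batch, 1, tgt_len, src_len)` additive form used by attention: kept
    /// positions become `0`, padded positions become `f32::MIN`.
    ///
    /// A usage sketch (illustrative only):
    ///
    /// ```ignore
    /// // One sequence of length 3 whose last token is padding.
    /// let pad = Tensor::new(&[[1u8, 1u8, 0u8]], &Device::Cpu)?;
    /// let additive = CausalMasker.expand_mask(&pad, DType::F32, None)?;
    /// assert_eq!(additive.dims(), &[1, 1, 3, 3]);
    /// ```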
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;

        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;

        // Invert the mask (1 = keep becomes 0 = keep) and push the masked
        // positions down to f32::MIN.
        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }

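    /// Convenience twin of the `PastKvLenCache` impl for tuple caches: the
    /// cached sequence length, or 0 when the cache is still empty.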
    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        match &cache[0] {
            Some((k_cache, _)) => Ok(k_cache.dims()[2]),
            None => Ok(0),
        }
    }

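    /// Build the additive causal mask for `input_ids`, or `None` when decoding
    /// a single token (nothing needs masking). With flash attention on CUDA a
    /// `(1, 1)` dummy is returned instead, since the kernel enforces causality
    /// itself.
    ///
    /// A usage sketch (illustrative; `input_ids` is a `(batch, seq_len)`
    /// tensor, `cache` any `PastKvLenCache`, and `num_attn_heads` a
    /// placeholder):
    ///
    /// ```ignore
    /// let mask =
    ///     CausalMasker.make_causal_mask_matrix(&input_ids, &cache, DType::F32, num_attn_heads)?;
    /// // `Some` during prefill, `None` during single-token decoding.
    /// ```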
    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        // Flash attention enforces causality inside the kernel; a (1, 1) dummy
        // keeps the downstream plumbing uniform.
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        let mut causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // Convert the 1/0 mask to additive form: NEG_INFINITY where masked.
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

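    /// Chunked-attention counterpart of `make_causal_mask_matrix`: converts
    /// the 1/0 mask from `make_mask_chunked` into additive form, again
    /// returning `None` for single-token decoding.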
    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let mut causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // Convert the 1/0 mask to additive form: NEG_INFINITY where masked.
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

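    /// Sliding-window variant of `make_causal_mask_matrix`. Falls back to the
    /// plain causal mask when `sliding_window` is `None`; otherwise the
    /// effective past length is clamped so cached tokens that have slid out of
    /// the window are excluded.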
    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let Some(sliding_window) = sliding_window else {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        };
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        // Flash attention handles the sliding window inside the kernel.
        if tgt_len > 1 && crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        // Anything older than the window can never be attended to.
        let past_kv_len = cache
            .get_past_kv_len()?
            .min(sliding_window.saturating_sub(tgt_len));
        if tgt_len == 1 {
            return Ok(None);
        }

        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }

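    /// Apply a 1/0 mask to attention scores: positions where `mask` is 1
    /// become `neg_inf` (cast to `att`'s device and dtype), the rest are kept.
    ///
    /// A usage sketch (illustrative only):
    ///
    /// ```ignore
    /// let neg_inf = Tensor::new(f32::NEG_INFINITY, &Device::Cpu)?;
    /// let att = CausalMasker.apply_mask_one_and_zero(&mask, att, &neg_inf)?;
    /// ```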
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}