mistralrs_core/
layers_masker.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType, D};

use crate::pipeline::KvCache;

// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
pub struct CausalMasker;

// https://github.com/mokeyish/candle-ext/blob/main/src/masked_fill.rs
/// Fills with `value` where `mask` is true (1); keeps values from `xs` where it is false (0).
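/// Illustrative example (a sketch, not part of the original comment): with
/// xs = [1.0, 2.0, 3.0], mask = [1, 0, 1], and value = f32::NEG_INFINITY,
/// the result is [-inf, 2.0, -inf].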
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    let on_false = xs;
    let res = mask
        .broadcast_as(xs.shape())?
        .where_cond(&on_true, on_false)?;
    Ok(res)
}

pub struct NotACache;

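/// Anything that can report how many key/value positions are already stored in the cache.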
pub trait PastKvLenCache {
    fn get_past_kv_len(&self) -> Result<usize>;
}

impl PastKvLenCache for NotACache {
    fn get_past_kv_len(&self) -> Result<usize> {
        Ok(0)
    }
}

impl PastKvLenCache for Vec<KvCache> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        Ok(kv_cache_1.current_seq_len())
    }
}

impl PastKvLenCache for &[usize] {
    fn get_past_kv_len(&self) -> Result<usize> {
        if self.windows(2).all(|w| w[0] == w[1]) {
            Ok(self[0])
        } else {
            Ok(0)
        }
    }
}

impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        if kv_cache_1.is_none() {
            return Ok(0);
        }
        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
        Ok(k_cache_1.dims()[2])
    }
}

impl CausalMasker {
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
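        // Illustrative example (not in the original): with tgt_len = 3 and past_kv_len = 2,
        // offset = 5 and the resulting (3, 5) mask is
        //   0 0 0 1 1
        //   0 0 0 0 1
        //   0 0 0 0 0
        // where 1 marks a masked-out future position and 0 an attended position.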
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
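        // Illustrative example (not in the original): with tgt_len = 4, past_kv_len = 0 and
        // chunk_size = 2, each token attends causally only within its own chunk:
        //   0 1 1 1
        //   0 0 1 1
        //   1 1 0 1
        //   1 1 0 0
        // (1 = masked out, 0 = attend).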
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Past key-value positions are always attended (0 = not masked)
                    if j < past_kv_len {
                        return 0;
                    }

                    // Adjust j to account for past_kv_len
                    let j_adj = j - past_kv_len;

                    // Calculate block position (equivalent to block_pos)
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();

                    // Calculate token position (equivalent to token_pos)
                    let token_pos = j_adj as isize - i as isize;

                    // Apply mask conditions: same block and causal
                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();

        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }

    fn make_swa_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
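        // Illustrative example (not in the original): with tgt_len = 4, seqlen_offset = 0 and
        // sliding_window = 2, each token attends to itself and to at most the 2 previous tokens:
        //      0 -inf -inf -inf
        //      0    0 -inf -inf
        //      0    0    0 -inf
        //   -inf    0    0    0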
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.to_dtype(dtype)
    }

    /// Expands a mask from (bs, seq_len) to (bs, 1, tgt_len, seq_len)
    /// If tgt_len is None, use seq_len
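    /// Illustrative example (not in the original comment): a padding mask of [[1, 1, 0]]
    /// becomes a (1, 1, tgt_len, 3) tensor whose rows are [0.0, 0.0, f32::MIN], pushing
    /// padded positions to a very large negative value.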
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;

        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;

        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }

    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        let kv_cache_1 = &cache[0];
        if kv_cache_1.is_none() {
            return Ok(0);
        }
        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
        Ok(k_cache_1.dims()[2])
    }

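    /// Builds a `(tgt_len, tgt_len + past_kv_len)` causal mask in `dtype`, with 0.0 at attended
    /// positions and -inf at future positions. Returns `None` when `tgt_len == 1`, and a dummy
    /// (1, 1) zero tensor when flash attention is used on a CUDA device.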
    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        // Avoid materializing the mask when flash attention is used on CUDA; return a dummy mask instead.
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        let mut causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // In the u8 mask, 1 means mask out (becomes -inf), 0 means attend (keeps 0.0)
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let mut causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // In the u8 mask, 1 means mask out (becomes -inf), 0 means attend (keeps 0.0)
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        if sliding_window.is_none() {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        }
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        let sliding_window = sliding_window.unwrap();

        // Avoid materializing large sliding-window masks when flash attention is used on CUDA; return a dummy mask instead.
        if tgt_len > 1 && crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        // Clamp the past KV len against the sliding window: with no prefix cache it is 0; otherwise it is
        // the number of extra past columns needed so the mask matches the k/v sequence length (usually the sliding window).
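        // For example (illustrative, not in the original): with sliding_window = 8, tgt_len = 3
        // and a 20-token cached prefix, only 8 - 3 = 5 extra past columns are kept, so the final
        // mask is (3, 8) wide.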
        let past_kv_len = cache
            .get_past_kv_len()?
            .min(sliding_window.saturating_sub(tgt_len));
        if tgt_len == 1 {
            return Ok(None);
        }

        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }

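    /// Applies a 1/0 mask to attention scores: positions where the mask is non-zero (1) are
    /// replaced with `neg_inf`, while positions where it is 0 keep their original score.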
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}

pub struct BidirectionalMasker;

impl BidirectionalMasker {
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
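        // Illustrative example (not in the original): with tgt_len = 4 and sliding_window = 2,
        // tokens attend bidirectionally to neighbors at absolute distance < 2:
        //      0    0 -inf -inf
        //      0    0    0 -inf
        //   -inf    0    0    0
        //   -inf -inf    0    0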
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    // https://github.com/huggingface/transformers/blob/a0bf5a82eebf88ee9f52145be427f6f1541329f6/src/transformers/models/gemma3/modeling_gemma3.py#L478
                    // A token can attend to any other token if their absolute distance is within the (exclusive) sliding window size (distance < sliding_window).
                    if (i as isize - j as isize).unsigned_abs() >= sliding_window {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        mask.to_dtype(dtype)
    }

    pub fn make_mask(&self, input_ids: &Tensor, dtype: DType) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        // Avoid materializing the mask when flash attention is used on CUDA; return a dummy mask instead.
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }

        // Fully bidirectional: do not mask anything (no -inf entries)
        let mask = Tensor::zeros((tgt_len, tgt_len), dtype, input_ids.device())?;

        Ok(mask)
    }

    pub fn make_sliding_mask(
        &self,
        input_ids: &Tensor,
        dtype: DType,
        sliding_window: usize,
    ) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;

        // Avoid materializing large sliding-window masks when flash attention is used on CUDA; return a dummy mask instead.
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }

        let mask = self.make_swa_mask(tgt_len, sliding_window, input_ids.device(), dtype)?;

        Ok(mask)
    }
}
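
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative checks added for documentation purposes (a sketch, not part of the
    // original file). They exercise the small worked examples in the comments above
    // on the CPU device.

    #[test]
    fn causal_mask_small_example() -> Result<()> {
        // tgt_len = 3 with a 2-token cached prefix: each row attends to the prefix and
        // to query positions up to its own index.
        let mask = CausalMasker.make_mask(3, 2, &Device::Cpu)?;
        assert_eq!(
            mask.to_vec2::<u8>()?,
            vec![
                vec![0, 0, 0, 1, 1],
                vec![0, 0, 0, 0, 1],
                vec![0, 0, 0, 0, 0]
            ]
        );
        Ok(())
    }

    #[test]
    fn masked_fill_small_example() -> Result<()> {
        // Where the mask is 1 the fill value is used; where it is 0 the input is kept.
        let xs = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu)?;
        let mask = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
        let out = masked_fill(&xs, &mask, f32::NEG_INFINITY)?.to_vec1::<f32>()?;
        assert!(out[0].is_infinite() && out[0] < 0.0);
        assert_eq!(out[1], 2.0);
        assert!(out[2].is_infinite() && out[2] < 0.0);
        Ok(())
    }
}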