mistralrs_core/layers_masker.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType, D};

use crate::pipeline::KvCache;

// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
pub struct CausalMasker;

// https://github.com/mokeyish/candle-ext/blob/main/src/masked_fill.rs
/// `xs` is used where the mask is false (0); `value` is used where it is true (1).
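///
/// A minimal usage sketch (illustrative; assumes a CPU device and `f32` data):
/// ```ignore
/// let xs = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
/// let mask = Tensor::new(&[[0u8, 1u8], [1u8, 0u8]], &Device::Cpu)?;
/// // Entries where `mask` is 1 become `f32::NEG_INFINITY`; the rest keep `xs`.
/// let filled = masked_fill(&xs, &mask, f32::NEG_INFINITY)?;
/// ```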
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    let on_false = xs;
    let res = mask
        .broadcast_as(xs.shape())?
        .where_cond(&on_true, on_false)?;
    Ok(res)
}

pub trait PastKvLenCache {
    fn get_past_kv_len(&self) -> Result<usize>;
}

impl PastKvLenCache for Vec<KvCache> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        Ok(kv_cache_1.current_seq_len())
    }
}

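// Per-sequence past lengths: if they all agree, that shared length is used; otherwise fall back to 0.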
impl PastKvLenCache for &[usize] {
    fn get_past_kv_len(&self) -> Result<usize> {
        if self.windows(2).all(|w| w[0] == w[1]) {
            Ok(self[0])
        } else {
            Ok(0)
        }
    }
}

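// Full (k, v) cache pairs: the past length is read from the key cache's third dimension
// (`dims()[2]`), i.e. its sequence length.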
impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        if kv_cache_1.is_none() {
            return Ok(0);
        }
        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
        Ok(k_cache_1.dims()[2])
    }
}

impl CausalMasker {
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
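    // For example, with tgt_len = 3 and past_kv_len = 1, `make_mask` builds a (3, 4) matrix
    // where 1 marks a masked (future) position:
    //
    //   0 0 1 1
    //   0 0 0 1
    //   0 0 0 0
    //
    // Column 0 is the cached position, which every query row may attend to.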

    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Past key-value positions are always attended (0 = keep).
                    if j < past_kv_len {
                        return 0;
                    }

                    // Adjust j to account for past_kv_len
                    let j_adj = j - past_kv_len;

                    // Chunk (block) index of the query and key positions
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();

                    // Relative position of the key with respect to the query
                    let token_pos = j_adj as isize - i as isize;

                    // Attend (0) only within the same chunk and causally; otherwise mask out (1).
                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();

        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
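    // For example, with tgt_len = 4, past_kv_len = 0 and chunk_size = 2, `make_mask_chunked`
    // yields (1 = masked out, 0 = attended):
    //
    //   0 1 1 1
    //   0 0 1 1
    //   1 1 0 1
    //   1 1 0 0
    //
    // i.e. attention is causal and restricted to positions within the same chunk.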

    fn make_swa_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.to_dtype(dtype)
    }
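    // For example, with tgt_len = 4, sliding_window = 2 and seqlen_offset = 0, `make_swa_mask`
    // produces (`x` = -inf, `.` = 0.0):
    //
    //   . x x x
    //   . . x x
    //   . . . x
    //   x . . .
    //
    // Each query attends to itself and to at most `sliding_window` earlier positions; a nonzero
    // `seqlen_offset` prepends that many all-zero (fully attended) columns.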

    /// Expands a mask from (bs, seq_len) to (bs, 1, tgt_len, seq_len).
    /// If `tgt_len` is `None`, `seq_len` is used.
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;

        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;

        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }
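    // For example, a padding mask [[1, 1, 0]] (1 = real token, 0 = padding) expands to a
    // (1, 1, 3, 3) additive mask whose rows are all [0.0, 0.0, f32::MIN], so the padded key
    // position is suppressed for every query.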

    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        let kv_cache_1 = &cache[0];
        if kv_cache_1.is_none() {
            return Ok(0);
        }
        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
        Ok(k_cache_1.dims()[2])
    }

    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        // Avoid materializing a large mask when flash attention is used on CUDA; a dummy
        // (1, 1) mask is returned instead.
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        let mut causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // Mask semantics: 0 means attended (stays 0.0), 1 means masked out (becomes -inf).
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let mut causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;

        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // Mask semantics: 0 means attended (stays 0.0), 1 means masked out (becomes -inf).
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };

        Ok(Some(causal_mask))
    }

    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        if sliding_window.is_none() {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        }
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        let sliding_window = sliding_window.unwrap();

        // Avoid materializing a large sliding-window mask when flash attention is used on CUDA;
        // a dummy (1, 1) mask is returned instead.
        if tgt_len > 1 && crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }

        // Clamp the past KV length so that past_kv_len + tgt_len never exceeds the sliding window.
        // With no prefix cache this is 0; otherwise the mask width matches the k/v sequence length
        // (usually the sliding window). For example, with sliding_window = 4, tgt_len = 3 and a
        // cached length of 10, past_kv_len is clamped to 1 and the mask is (3, 4).
        let past_kv_len = cache
            .get_past_kv_len()?
            .min(sliding_window.saturating_sub(tgt_len));
        if tgt_len == 1 {
            return Ok(None);
        }

        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }

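    /// Applies a 0/1 mask to attention scores: positions where `mask` is 1 are replaced with
    /// `neg_inf`, and positions where it is 0 keep their original value from `att`.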
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}
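
// Illustrative sanity-check sketch for the masker; it assumes a CPU device and is not meant
// to be exhaustive.
#[cfg(test)]
mod masker_sketch_tests {
    use super::*;

    #[test]
    fn causal_mask_shape() -> Result<()> {
        let device = Device::Cpu;
        // A batch of 2 sequences of length 4 with no cached KV (past lengths are all zero).
        let input_ids = Tensor::zeros((2, 4), DType::U32, &device)?;
        let past: &[usize] = &[0, 0];
        let mask = CausalMasker
            .make_causal_mask_matrix(&input_ids, &past, DType::F32, 1)?
            .expect("a prompt longer than one token should produce a mask");
        // The mask is (tgt_len, tgt_len + past_kv_len) = (4, 4): 0.0 on and below the
        // diagonal, -inf above it.
        assert_eq!(mask.dims(), &[4, 4]);
        Ok(())
    }
}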