// mistralrs_core/layers_masker.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::ops::Add;

use candle_core::{DType, Device, Result, Tensor, WithDType, D};

use crate::pipeline::KvCache;

// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
pub struct CausalMasker;

// https://github.com/mokeyish/candle-ext/blob/main/src/masked_fill.rs
/// Fill `value` where `mask` is true (1); keep `xs` where it is false (0).
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    let on_false = xs;
    let res = mask
        .broadcast_as(xs.shape())?
        .where_cond(&on_true, on_false)?;
    Ok(res)
}
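
// A minimal usage sketch (illustrative, not from the original source; the
// tensor names are hypothetical):
//
//     let xs = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
//     let mask = Tensor::from_slice(&[0u8, 1, 0, 1], (2, 2), &Device::Cpu)?;
//     // Positions where `mask` is 1 become NEG_INFINITY; the rest keep `xs`.
//     let filled = masked_fill(&xs, &mask, f32::NEG_INFINITY)?;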

/// Abstracts over KV-cache representations that can report how many positions
/// are already cached.
pub trait PastKvLenCache {
    fn get_past_kv_len(&self) -> Result<usize>;
}

impl PastKvLenCache for Vec<KvCache> {
    fn get_past_kv_len(&self) -> Result<usize> {
        let kv_cache_1 = &self[0];
        Ok(kv_cache_1.current_seq_len())
    }
}

impl PastKvLenCache for &[usize] {
    fn get_past_kv_len(&self) -> Result<usize> {
        // Only a shared per-sequence offset is usable; mixed or empty offsets
        // fall back to 0.
        if !self.is_empty() && self.windows(2).all(|w| w[0] == w[1]) {
            Ok(self[0])
        } else {
            Ok(0)
        }
    }
}
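
// Illustrative note (not from the original source): uniform offsets such as
// &[3, 3, 3] yield Ok(3), while mixed offsets such as &[3, 5] yield Ok(0).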

impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    fn get_past_kv_len(&self) -> Result<usize> {
        match &self[0] {
            // Dimension 2 of the cached k tensor is the sequence length.
            Some((k_cache_1, _)) => Ok(k_cache_1.dims()[2]),
            None => Ok(0),
        }
    }
}

impl CausalMasker {
    /// Build a (tgt_len, tgt_len + past_kv_len) u8 matrix in which 1 marks a
    /// masked (future) position and 0 an attendable one.
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
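
    // Illustrative sketch (not from the original source): with tgt_len = 3 and
    // past_kv_len = 2, `make_mask` yields the (3, 5) matrix below; 1 marks a
    // masked future key, 0 an attendable one:
    //
    //     0 0 0 1 1
    //     0 0 0 0 1
    //     0 0 0 0 0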

    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Past key-value positions are always attendable
                    if j < past_kv_len {
                        return 0;
                    }

                    // Adjust j to account for past_kv_len
                    let j_adj = j - past_kv_len;

                    // Calculate block position (equivalent to block_pos)
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();

                    // Calculate token position (equivalent to token_pos)
                    let token_pos = j_adj as isize - i as isize;

                    // Attend only within the same chunk, and only causally
                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();

        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
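
    // Illustrative sketch (not from the original source): with tgt_len = 4,
    // past_kv_len = 0 and chunk_size = 2, each query attends causally and only
    // within its own chunk:
    //
    //     0 1 1 1
    //     0 0 1 1
    //     1 1 0 1
    //     1 1 0 0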

    /// Build an additive f32 sliding-window mask: 0 where a query may attend,
    /// NEG_INFINITY elsewhere, left-padded with zeros over `seqlen_offset`
    /// cached positions and cast to `dtype`.
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.to_dtype(dtype)
    }
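
    // Illustrative sketch (not from the original source): with tgt_len = 4,
    // seqlen_offset = 0 and sliding_window = 2, each query attends to itself
    // and to at most 2 previous tokens (-inf marks masked positions):
    //
    //        0 -inf -inf -inf
    //        0    0 -inf -inf
    //        0    0    0 -inf
    //     -inf    0    0    0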

    /// Expands a padding mask from (bs, src_len) to (bs, 1, tgt_len, src_len),
    /// inverting it so that padded positions (0 in the input) become f32::MIN.
    /// If tgt_len is None, src_len is used.
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;

        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;

        // Invert: 1 (keep) -> 0, 0 (pad) -> 1, then fill the 1s with f32::MIN.
        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }
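
    // Illustrative sketch (not from the original source): a (bs = 1, src_len = 3)
    // padding mask [[1, 1, 0]] expands to shape (1, 1, 3, 3); every column whose
    // input entry was 0 becomes f32::MIN, suppressing padded keys once the mask
    // is added to the attention scores.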

    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        match &cache[0] {
            // Dimension 2 of the cached k tensor is the sequence length.
            Some((k_cache_1, _)) => Ok(k_cache_1.dims()[2]),
            None => Ok(0),
        }
    }

    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;

        // Mask semantics: 1 means mask out (set to -inf), 0 means attend (add 0.0)
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        let causal_mask = masked_fill(
            &zero.to_dtype(dtype)?.broadcast_as(causal_mask.shape())?,
            &causal_mask,
            f32::NEG_INFINITY,
        )?;

        Ok(Some(causal_mask))
    }
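
    // Illustrative usage sketch (assumed call site, not from this file): the
    // returned mask, when present, is broadcast-added to the attention scores
    // before softmax.
    //
    //     let mask = CausalMasker.make_causal_mask_matrix(&input_ids, &cache, dtype, n_heads)?;
    //     let att = match &mask {
    //         Some(m) => att.broadcast_add(m)?,
    //         None => att,
    //     };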

    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if tgt_len == 1 {
            return Ok(None);
        }

        let causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;

        // Mask semantics: 1 means mask out (set to -inf), 0 means attend (add 0.0)
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        let causal_mask = masked_fill(
            &zero.to_dtype(dtype)?.broadcast_as(causal_mask.shape())?,
            &causal_mask,
            f32::NEG_INFINITY,
        )?;

        Ok(Some(causal_mask))
    }

    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let Some(sliding_window) = sliding_window else {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        };
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        // Cap the past KV len by the room the sliding window leaves after the new
        // tokens: 0 with no prefix cache, otherwise just enough for the mask to
        // match the k/v seqlen (usually the sliding window size).
        let past_kv_len = cache
            .get_past_kv_len()?
            .min(sliding_window.saturating_sub(tgt_len));
        if tgt_len == 1 {
            return Ok(None);
        }

        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }
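
    // Illustrative note (not from the original source): with sliding_window = 4,
    // tgt_len = 3 and 10 cached tokens, past_kv_len is capped at
    // min(10, 4 - 3) = 1, so the resulting (3, 4) mask matches a k/v sequence
    // trimmed to the sliding window.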

    /// Apply a 1/0 mask to attention scores: where the mask is 1 the score is
    /// replaced by `neg_inf`; where it is 0 the original score is kept.
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}
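
// A minimal test sketch (not from the original source), exercising the mask
// builders on CPU under the semantics described above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn masked_fill_selects_value_on_true() -> Result<()> {
        let device = Device::Cpu;
        let xs = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), &device)?;
        let mask = Tensor::from_slice(&[0u8, 1, 0, 1], (2, 2), &device)?;
        // Where the mask is 1 the fill value is taken; elsewhere `xs` is kept.
        let out = masked_fill(&xs, &mask, f32::NEG_INFINITY)?.to_vec2::<f32>()?;
        assert_eq!(out[0], [1.0, f32::NEG_INFINITY]);
        assert_eq!(out[1], [3.0, f32::NEG_INFINITY]);
        Ok(())
    }

    #[test]
    fn make_mask_is_causal_with_offset() -> Result<()> {
        // With 2 cached positions, every query may also attend to the prefix.
        let mask = CausalMasker.make_mask(3, 2, &Device::Cpu)?;
        assert_eq!(
            mask.to_vec2::<u8>()?,
            [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]]
        );
        Ok(())
    }

    #[test]
    fn causal_mask_matrix_skips_single_token() -> Result<()> {
        let device = Device::Cpu;
        let cache: &[usize] = &[0];
        // Single-token decode steps need no mask at all.
        let single = Tensor::zeros((1, 1), DType::U32, &device)?;
        assert!(CausalMasker
            .make_causal_mask_matrix(&single, &cache, DType::F32, 8)?
            .is_none());
        // A 3-token prompt with an empty cache gets a (3, 3) additive mask.
        let prompt = Tensor::zeros((1, 3), DType::U32, &device)?;
        let mask = CausalMasker
            .make_causal_mask_matrix(&prompt, &cache, DType::F32, 8)?
            .unwrap();
        assert_eq!(mask.dims2()?, (3, 3));
        Ok(())
    }
}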