mistralrs_core/
attention.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
#![allow(clippy::cast_precision_loss)]

#[cfg(feature = "metal")]
use std::sync::atomic::AtomicUsize;

use crate::{
    cublaslt::CUBLASLT_HANDLE,
    layers::{get_use_matmul_via_f16, MatMul},
    pipeline::text_models_inputs_processor::FlashParams,
};

use candle_core::{Device, Result, Tensor};

#[cfg(feature = "metal")]
/// Process-wide cache of the Metal language version (`__METAL_VERSION__`) used by
/// `naive_sdpa` to decide whether the fused attention softmax may be used.
///
/// Initial, sentinel value is usize::MAX, meaning "not probed yet"; the first successful
/// probe stores the parsed version so later calls skip spawning external processes.
static METAL_VERSION_CACHE: AtomicUsize = AtomicUsize::new(usize::MAX);

/// Runs the flash attention kernels from `candle_flash_attn`.
///
/// Inputs are expected in the flash-attn layout `(b_sz, seq_len, n_heads, head_dim)`;
/// `Sdpa::run_attention` transposes its `(b_sz, n_heads, seq_len, head_dim)` tensors before
/// calling this function and transposes the result back.
///
/// - When `flash_params` is `Some`, the batch and sequence dimensions are flattened and the
///   variable-length kernel is used with the cumulative sequence lengths it carries;
///   `sdpa_params.sliding_window` is forwarded as the left window size. The result is
///   reshaped back to the shape of `q`.
/// - Otherwise the dense kernel is used.
///
/// Causal masking is requested whenever more than one query token is processed.
///
/// # Errors
/// Propagates any error from the tensor shape operations or from the flash attention kernels.
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    flash_params: Option<&crate::pipeline::text_models_inputs_processor::FlashParams>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    use crate::pipeline::text_models_inputs_processor::FlashParams;

    // `q` arrives as (b_sz, seq_len, n_heads, head_dim): the sequence length is dim 1.
    // Binding dim 2 here would test the head count instead, making the check wrong for
    // single-head models.
    let (_b_sz, seq_len, _n_attn_heads, _head_dim) = q.dims4()?;
    // A single query token has nothing "in the future" to mask out.
    let causal = seq_len > 1;

    if let Some(FlashParams {
        max_q,
        max_k,
        cumulative_seqlens_q,
        cumulative_seqlens_k,
    }) = flash_params
    {
        // Remember the 4D shape so the varlen output can be restored afterwards.
        let qshape = q.shape();
        // Varlen kernels take (total_tokens, n_heads, head_dim).
        let q = q.flatten_to(1)?;
        let k = k.flatten_to(1)?;
        let v = v.flatten_to(1)?;

        let window_size_left = sdpa_params.sliding_window;
        // A right window of 0 means "no attention to later positions", i.e. causal.
        let window_size_right = if causal { Some(0) } else { None };

        candle_flash_attn::flash_attn_varlen_windowed_softcap(
            &q,
            &k,
            &v,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            *max_q as usize,
            *max_k as usize,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            window_size_left,
            window_size_right,
        )?
        .reshape(qshape)
    } else {
        candle_flash_attn::flash_attn_softcap(
            q,
            k,
            v,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            causal,
        )
    }
}

/// Stand-in for the flash attention entry point in builds without the `flash-attn` feature.
///
/// All arguments are ignored.
///
/// # Panics
/// Always panics, telling the user to rebuild with `--features flash-attn`.
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(
    _q: &Tensor,
    _k: &Tensor,
    _v: &Tensor,
    _flash_params: Option<&crate::pipeline::text_models_inputs_processor::FlashParams>,
    _sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    unimplemented!("Compile with '--features flash-attn'")
}

/// Expands a `(b_sz, n_kv_head, seq_len, head_dim)` tensor to
/// `(b_sz, n_kv_head * n_rep, seq_len, head_dim)` by repeating its contents `n_rep` times.
///
/// With `n_rep == 1` the input is returned as-is.
///
/// # Errors
/// Fails if `x` is not rank 4, or if the concatenation or reshape fails.
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    // Nothing to expand: hand the tensor back untouched.
    if n_rep == 1 {
        return Ok(x);
    }

    let (batch, kv_heads, tokens, dim) = x.dims4()?;

    // Concatenate `n_rep` references along the sequence axis, then fold the extra length
    // into the head axis with a reshape.
    let copies: Vec<&Tensor> = std::iter::repeat(&x).take(n_rep).collect();
    let stacked = Tensor::cat(&copies, 2)?;
    stacked.reshape((batch, kv_heads * n_rep, tokens, dim))
}

/// Computes `softmax(softmax_scale * Q K^T + mask) V` using plain matmuls.
///
/// When `sdpa_params.softcap` is set, the scaled scores are soft-capped as
/// `softcap * tanh(scores / softcap)` before the softmax.
///
/// Two code paths exist:
/// - **Fused path** — taken when a mask of rank 2, or of rank 3 with a leading dimension of 1,
///   is supplied and the platform supports it (always without the `metal` feature; with it,
///   only for Metal language version >= 310). The mask is expanded to
///   `(b_sz, n_attn_heads, ..)` and used as the accumulator of a fused matmul.
/// - **General path** — everything else: scale, optional soft-cap, optional
///   `broadcast_add` of the mask, softmax, matmul with `v`.
///
/// Both paths use `sdpa_params.softmax_scale`, matching the flash attention, Metal SDPA and
/// cuBLASLt paths. `head_dim` is retained for signature compatibility only.
///
/// # Errors
/// Propagates tensor operation errors. With the `metal` feature, also returns an error
/// (instead of panicking) when the Metal version probe via `xcrun` cannot be run or its
/// output cannot be parsed.
fn naive_sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    head_dim: usize,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    // The scale comes from `sdpa_params`; see the doc comment above.
    let _ = head_dim;

    #[cfg(feature = "metal")]
    let supports_attn_softmax = {
        use std::sync::atomic::Ordering;
        let cache = METAL_VERSION_CACHE.load(Ordering::Relaxed);

        let version = if cache != usize::MAX {
            cache
        } else {
            // Equivalent of:
            // echo "__METAL_VERSION__" | xcrun -sdk macosx metal -E -x metal -P -

            use std::process::{Command, Stdio};

            // Create the `echo` command and pipe its output into `xcrun`
            let mut echo = match Command::new("echo")
                .arg("__METAL_VERSION__")
                .stdout(Stdio::piped())
                .spawn()
            {
                Ok(child) => child,
                Err(e) => candle_core::bail!("Failed to start echo command: {e}"),
            };

            echo.wait()?;

            let Some(echo_stdout) = echo.stdout.take() else {
                candle_core::bail!("Failed to capture the output of the echo command")
            };

            // Run the `xcrun` command, taking input from the `echo` command's output
            let output = match Command::new("xcrun")
                .arg("-sdk")
                .arg("macosx")
                .arg("metal")
                .arg("-E")
                .arg("-x")
                .arg("metal")
                .arg("-P")
                .arg("-")
                .stdin(echo_stdout)
                .output()
            {
                Ok(output) => output,
                Err(e) => candle_core::bail!("Failed to run xcrun command: {e}"),
            };

            if !output.status.success() {
                let stderr = String::from_utf8_lossy(&output.stderr);
                candle_core::bail!("Error:\n{}", stderr);
            }

            // The expanded macro value is expected on the second line of the output.
            let stdout = String::from_utf8_lossy(&output.stdout);
            let Some(version_line) = stdout.split('\n').nth(1) else {
                candle_core::bail!("Unexpected output from xcrun: {stdout:?}")
            };
            let version = match version_line.trim().parse::<usize>() {
                Ok(version) => version,
                Err(e) => {
                    candle_core::bail!("Failed to parse Metal version {version_line:?}: {e}")
                }
            };
            METAL_VERSION_CACHE.store(version, Ordering::Relaxed);
            version
        };
        // Attn softmax is only supported for metal >= 310
        version >= 310
    };

    #[cfg(not(feature = "metal"))]
    let supports_attn_softmax = true;

    // Use faster softmax if mask is rank 2 or it's rank 3 and bs 1
    let fused_mask = mask.filter(|mask| {
        supports_attn_softmax
            && (mask.rank() == 2 || (mask.rank() == 3 && mask.dims()[0] == 1))
    });

    if let Some(mask) = fused_mask {
        let n_attn_heads = q.dim(1)?;
        let bs = q.dim(0)?;

        // Expand the mask to one slice per (batch, head); it doubles as the accumulator that
        // the fused matmul below adds its result into.
        let mut att = if mask.rank() == 2 {
            mask.unsqueeze(0)?
                .unsqueeze(0)?
                .repeat((bs, n_attn_heads, 1, 1))?
        } else {
            mask.unsqueeze(0)?.repeat((bs, n_attn_heads, 1, 1))?
        };

        // Dividing the scale by `softcap` here pre-applies the `scores / softcap` step.
        q.matmul_with_alpha_beta(
            &k.t()?,
            &mut att,
            Some((sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)) as f64),
        )?;

        // NOTE(review): assuming the fused matmul accumulates onto the mask, the mask also
        // passes through `tanh` here, so masked logits end up near `-softcap` rather than
        // at the mask's own value — confirm this approximation is intended.
        if let Some(softcap) = sdpa_params.softcap {
            att = (att.tanh()? * softcap as f64)?;
        }

        candle_nn::ops::inplace_softmax_last_dim(&mut att)?;

        return MatMul.matmul(&att, v);
    }

    // General path. `matmul_affine_div` divides by its last argument, so pass the reciprocal
    // of the multiplicative softmax scale.
    let inv_scale = 1.0 / f64::from(sdpa_params.softmax_scale);
    let mut att = MatMul.matmul_affine_div(q, &k.t()?, inv_scale)?;

    if let Some(softcap) = sdpa_params.softcap {
        let softcap = f64::from(softcap);
        att = (att / softcap)?;
        att = att.tanh()?;
        att = (att * softcap)?;
    }

    // The mask is added after soft-capping so masked positions keep their mask value.
    if let Some(mask) = mask {
        att = att.broadcast_add(mask)?;
    }

    candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
    MatMul.matmul(&att, v)
}

/// Parameters controlling how `Sdpa::run_attention` computes attention.
pub struct SdpaParams {
    /// Repetition count passed to `repeat_kv` for the key and value tensors on the
    /// non-flash, non-Metal-SDPA paths.
    pub n_kv_groups: usize,
    /// When `true`, `run_attention` dispatches to the flash attention implementation.
    pub use_flash_attn: bool,
    /// Optional soft-capping value; when set, attention scores are transformed as
    /// `softcap * tanh(scores / softcap)` before the softmax.
    pub softcap: Option<f32>,
    /// Scale applied to the `Q K^T` scores; forwarded to the flash attention, Metal SDPA
    /// and cuBLASLt paths.
    pub softmax_scale: f32,
    /// Optional sliding window size; forwarded as the left window size to the
    /// variable-length flash attention kernel.
    pub sliding_window: Option<usize>,
}

/// Stateless entry point for scaled dot-product attention; see `Sdpa::run_attention`.
pub struct Sdpa;

impl Sdpa {
    /// Computes scaled dot-product attention, `softmax(scale * Q K^T + mask) V`, where
    /// `scale` is `sdpa_params.softmax_scale` on the flash attention, Metal SDPA and
    /// cuBLASLt paths.
    ///
    /// Inputs:
    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
    /// - k: (b_sz, n_kv_heads, q_len, head_dim)
    /// - v: (b_sz, n_kv_heads, q_len, head_dim)
    ///
    /// The attention implementation is dispatched as follows:
    /// 1) If `use_flash_attn == true`, use a flash attention V2 kernel
    /// 2) If the device is Metal and a single query token is processed, use
    ///    `candle_nn::ops::sdpa`.
    /// 3) If using CUDA and the cuBLASLt kernel is initialized, then it will use an optimized
    ///    version (unless matmul-via-f16 is enabled, in which case the naive path is used).
    /// 4) Otherwise, use the "naive" SDPA implementation.
    ///
    /// # Errors
    /// Propagates tensor and kernel errors; on the cuBLASLt path a mask whose rank is neither
    /// 3 nor 4 is rejected.
    ///
    /// # Panics
    /// Panics if the `CUBLASLT_HANDLE` mutex is poisoned.
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        if sdpa_params.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

        if q.device().is_metal() && seq_len == 1 {
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

        // Copy the handle out in its own statement so the mutex guard is dropped right away.
        // Locking inside the `if let` scrutinee would keep the guard alive for the whole
        // attention computation, serializing every concurrent caller on this mutex.
        let cublaslt_handle = *CUBLASLT_HANDLE.lock().unwrap();

        if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), cublaslt_handle) {
            if !get_use_matmul_via_f16() {
                #[cfg(feature = "cuda")]
                {
                    // cuBLASLt batch matmul implementation requires inputs to be dims3
                    let k = k.flatten(0, 1)?;
                    let q = q.flatten(0, 1)?;
                    let v = v.flatten(0, 1)?;
                    let attention_bias = match mask {
                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                            Some(mask.repeat((n_attn_heads, 1, 1))?)
                        }
                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                        Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
                        Some(mask) => {
                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                        }
                        None => None,
                    };

                    // If attention_bias is set, we fuse the add by giving it as the output matrix
                    // and setting beta to 1.0
                    let beta = match attention_bias.is_some() {
                        true => Some(1.0),
                        false => None,
                    };

                    // Batch matrix multiplication
                    // Fuse softmax scale and attention_bias add
                    let mut attention_scores = cublaslt.batch_matmul(
                        &k,
                        &q,
                        attention_bias.as_ref(),
                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                        beta,
                        None,
                        None,
                    )?;
                    // NOTE(review): with a fused bias the mask also goes through `tanh`
                    // here — confirm this approximation is intended.
                    if let Some(softcap) = sdpa_params.softcap {
                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                    }
                    candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

                    let context_layer = cublaslt.batch_matmul(
                        &v.t()?.contiguous()?,
                        &attention_scores,
                        // We save one allocation
                        Some(&q),
                        None,
                        None,
                        None,
                        None,
                    )?;

                    // Reshape to dims4
                    context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
                }
                #[cfg(not(feature = "cuda"))]
                {
                    candle_core::bail!("`cuda` feature is not enabled")
                }
            } else {
                // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
                naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
            }
        } else {
            naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
        }
    }
}