mistralrs_core/attention/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::pipeline::text_models_inputs_processor::FlashParams;

use candle_core::{Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

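/// Repeats each KV head `n_rep` times so that grouped-query attention (GQA/MQA)
/// keys/values match the number of query heads. Returns the input unchanged when
/// `n_rep == 1`. Shapes: (b_sz, n_kv_heads, seq_len, head_dim) ->
/// (b_sz, n_kv_heads * n_rep, seq_len, head_dim).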
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

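/// Parameters for scaled dot-product attention.
///
/// - `n_kv_groups`: number of query heads per KV head (the GQA repeat factor)
/// - `softcap`: optional logit softcapping applied before the softmax
/// - `softmax_scale`: scale applied to QK^T, typically 1/sqrt(head_dim)
/// - `sliding_window`: optional sliding-window attention width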
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

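/// Dispatches scaled dot-product attention to the best available backend
/// (flash attention, Metal SDPA, cuBLASLt, or the naive implementation).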
pub struct Sdpa;

impl Sdpa {
    /// Computes softmax(QK^T / sqrt(d_k))V
    ///
    /// Inputs:
    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
    /// - k: (b_sz, n_kv_heads, kv_len, head_dim)
    /// - v: (b_sz, n_kv_heads, kv_len, head_dim)
    ///
    /// The attention implementation is dispatched as follows:
    /// 1) If flash attention is enabled and the device is CUDA, use a flash attention V2/V3 kernel
    /// 2) If the device is Metal and the head dims/mask are supported, use a fused Metal SDPA kernel
    ///    (the vector kernel when decoding, the full kernel otherwise)
    /// 3) If the device is CUDA and cuBLASLt is initialized, use a cuBLASLt batched-matmul implementation
    /// 4) Otherwise, use the "naive" SDPA implementation (with optimized mask + softmax + scale application)
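    ///
    /// A minimal usage sketch (illustrative only: the shapes, zero-filled tensors, and
    /// parameter values are placeholders; real callers build q/k/v and the mask from the
    /// model's projections):
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let dev = Device::Cpu;
    /// // 1 sequence, 8 query heads, 2 KV heads (n_kv_groups = 4), 16 tokens, head_dim 64.
    /// let q = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    /// let k = Tensor::zeros((1, 2, 16, 64), DType::F32, &dev)?;
    /// let v = Tensor::zeros((1, 2, 16, 64), DType::F32, &dev)?;
    /// let params = SdpaParams {
    ///     n_kv_groups: 4,
    ///     softcap: None,
    ///     softmax_scale: 1.0 / (64f32).sqrt(),
    ///     sliding_window: None,
    /// };
    /// // No mask and no flash params: on CPU this falls through to the naive SDPA path.
    /// let out = Sdpa.run_attention(&q, &k, &v, None, None, &params)?;
    /// assert_eq!(out.dims(), &[1, 8, 16, 64]);
    /// ```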
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
        if crate::using_flash_attn() && q.device().is_cuda() {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

        // We can use Metal SDPA (vector/full) if the mask broadcasts to the expected shape
        // and all head dims match. If a mask is provided, softcapping isn't allowed - fall
        // back to naive SDPA. (Softcapping is implemented for vector SDPA.)
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            // Not sure why the full kernel doesn't like 256.
            // [32, 64, 72, 80, 96, 128, 256]
            &[32, 64, 72, 80, 96, 128]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

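        // Expand the KV heads so that each query head has a matching KV head
        // (grouped-query attention) before the remaining fallback paths.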
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

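        // The cuBLASLt path below only accepts rank-3/rank-4 masks, so 2D masks (and the
        // NCCL/distributed case) are routed to the naive implementation here.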
        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                mask,
                sdpa_params,
            );
        }

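        // On CUDA, prefer the cuBLASLt batched-matmul path when the cuBLASLt controller has
        // been initialized; otherwise (and on non-CUDA devices) fall back to naive SDPA.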
        // TODO: bench?
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                // cuBLASLt batch matmul implementation requires inputs to be dims3
                let k = k.flatten(0, 1)?;
                let q = q.flatten(0, 1)?;
                let v = v.flatten(0, 1)?;
                let attention_bias = match mask {
                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                    }
                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                    Some(mask) if mask.rank() == 4 => {
                        Some(mask.broadcast_as(tgt_mask_shape)?.flatten(0, 1)?)
                    }
                    Some(mask) => {
                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                    }
                    None => None,
                };

                // If attention_bias is set, we fuse the add by giving it as the output matrix
                // and setting beta to 1.0
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Batch matrix multiplication
                // Fuse softmax scale and attention_bias add
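                // When softcapping, the scale passed to the matmul is divided by the softcap;
                // the tanh(..) * softcap step afterwards restores the capped logits.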
                let mut attention_scores = cublaslt.batch_matmul(
                    &k,
                    &q,
                    attention_bias.as_ref(),
                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                    beta,
                    None,
                    None,
                )?;
                if let Some(softcap) = sdpa_params.softcap {
                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                }
                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

                let context_layer = cublaslt.batch_matmul(
                    &v.t()?.contiguous()?,
                    &attention_scores,
                    // We save one allocation
                    Some(&q),
                    None,
                    None,
                    None,
                    None,
                )?;

                // Reshape to dims4
                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}