mistralrs_core/attention/mod.rs
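//! Scaled dot-product attention (SDPA) dispatch: flash attention on CUDA, fused Metal SDPA
//! kernels, a cuBLASLt-backed CUDA path, and a naive fallback.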

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::pipeline::text_models_inputs_processor::FlashParams;

use candle_core::{Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

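/// Repeats each KV head `n_rep` times so grouped-query / multi-query KV tensors match the
/// number of attention heads.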
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

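/// Parameters controlling scaled dot-product attention.
///
/// - `n_kv_groups`: number of attention heads per KV head (the `repeat_kv` factor).
/// - `softcap`: optional attention logit softcapping value.
/// - `softmax_scale`: scale applied to `QK^T` before the softmax, typically `1 / sqrt(head_dim)`.
/// - `sliding_window`: optional sliding-window size, forwarded to the attention backends.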
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

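/// Stateless dispatcher that routes scaled dot-product attention to the most suitable backend
/// for the current device, inputs, and build features.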
pub struct Sdpa;

impl Sdpa {
    /// Computes `softmax(QK^T * softmax_scale)V`, where `softmax_scale` is typically `1/sqrt(d_k)`.
    ///
    /// Inputs:
    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
    /// - k: (b_sz, n_kv_heads, kv_len, head_dim)
    /// - v: (b_sz, n_kv_heads, kv_len, head_dim)
    ///
    /// The attention implementation is dispatched as follows:
    /// 1) If using flash attn (CUDA), use a flash attention V2/V3 kernel.
    /// 2) If using a Metal device with supported head dims and mask, use a fused Metal SDPA kernel.
    /// 3) If using a CUDA device with cuBLASLt available, use a cuBLASLt batched-matmul implementation.
    /// 4) Otherwise, use the "naive" SDPA implementation (with optimized mask + softmax + scale application).
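    ///
    /// Example (a sketch; assumes `q`, `k`, and `v` are laid out as above and `params` is a
    /// populated `SdpaParams`):
    /// ```ignore
    /// let out = Sdpa.run_attention(&q, &k, &v, mask.as_ref(), None, &params)?;
    /// ```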
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
        if crate::using_flash_attn() && q.device().is_cuda() {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

        // We can use the Metal SDPA kernels (vector/full) if the mask broadcasts to the target
        // shape and all head dims match. If a mask is provided together with softcapping, fall
        // back to naive SDPA; softcapping on its own is implemented for the vector SDPA kernel.
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
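        // Head dims accepted by the Metal SDPA kernels; the full (non-decode) kernel currently
        // excludes 256.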
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            // Not sure why the full kernel doesn't like 256.
            // [32, 64, 72, 80, 96, 128, 256]
            &[32, 64, 72, 80, 96, 128]
        };
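        // Dispatch to the fused Metal SDPA kernel when the device, head dims, and mask allow it.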
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

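        // Expand the KV heads so the matmul-based paths below see one KV head per attention head.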
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

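        // A 2-D mask or NCCL-distributed execution always goes through the naive SDPA path.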
        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(q, &k, &v, mask, sdpa_params);
        }

        // TODO: bench?
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                // cuBLASLt batch matmul implementation requires inputs to be dims3
                let k = k.flatten(0, 1)?;
                let q = q.flatten(0, 1)?;
                let v = v.flatten(0, 1)?;
                let attention_bias = match mask {
                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                    }
                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                    Some(mask) if mask.rank() == 4 => {
                        Some(mask.broadcast_as(tgt_mask_shape)?.flatten(0, 1)?)
                    }
                    Some(mask) => {
                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                    }
                    None => None,
                };

                // If attention_bias is set, we fuse the add by giving it as the output matrix
                // and setting beta to 1.0
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Batch matrix multiplication
                // Fuse softmax scale and attention_bias add
                let mut attention_scores = cublaslt.batch_matmul(
                    &k,
                    &q,
                    attention_bias.as_ref(),
                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                    beta,
                    None,
                    None,
                )?;
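                // Softcapping: the matmul above folded the division by `softcap` into its scale,
                // so this computes `softcap * tanh(scores / softcap)` on the scaled scores.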
                if let Some(softcap) = sdpa_params.softcap {
                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                }
                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

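                // Multiply the attention probabilities by V to produce the attention output.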
                let context_layer = cublaslt.batch_matmul(
                    &v.t()?.contiguous()?,
                    &attention_scores,
                    // We save one allocation
                    Some(&q),
                    None,
                    None,
                    None,
                    None,
                )?;

                // Reshape to dims4
                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}