mistralrs_core/attention/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{attention::backends::cpu, pipeline::text_models_inputs_processor::FlashParams};

use candle_core::{DType, Device, Result, Tensor};

mod backends;

#[allow(unused)]
pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa};

/// Chunk size for attention computation to avoid OOM on long sequences
pub(crate) const ATTENTION_CHUNK_SIZE: usize = 1024;

/// Generic chunked attention computation that can be used by different backends
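///
/// The closure receives one query chunk at a time together with the full `k`/`v`
/// tensors and the matching slice of the mask. A minimal usage sketch, assuming a
/// hypothetical `plain_attention` backend operating on `(b, h, s, d)` tensors (it
/// is not part of this module):
///
/// ```ignore
/// let out = chunked_attention(&q, &k, &v, mask.as_ref(), |q_chunk, k, v, mask_chunk| {
///     // `plain_attention` is a placeholder for any per-chunk attention kernel.
///     plain_attention(q_chunk, k, v, mask_chunk)
/// })?;
/// ```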
pub(crate) fn chunked_attention<F>(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    attention_fn: F,
) -> Result<Tensor>
where
    F: Fn(&Tensor, &Tensor, &Tensor, Option<&Tensor>) -> Result<Tensor>,
{
    let seq_len = q.dim(2)?;

    if seq_len <= ATTENTION_CHUNK_SIZE {
        // For short sequences, use the regular path
        return attention_fn(q, k, v, mask);
    }

    // Chunk the query to avoid OOM on long sequences
    let num_chunks = seq_len.div_ceil(ATTENTION_CHUNK_SIZE);
    let mut attn_chunks = Vec::with_capacity(num_chunks);

    for chunk_idx in 0..num_chunks {
        let offset = chunk_idx * ATTENTION_CHUNK_SIZE;
        let chunk_len = ATTENTION_CHUNK_SIZE.min(seq_len - offset);

        // Extract query chunk
        let q_chunk = q.narrow(2, offset, chunk_len)?;

        // Extract mask chunk if present
        let mask_chunk = mask
            .map(|m| {
                match m.rank() {
                    2 => {
                        // For 2D masks (seq_len, seq_len), narrow along dimension 0
                        m.narrow(0, offset, chunk_len)
                    }
                    3 => {
                        // For 3D masks (batch, seq_len, seq_len), narrow along dimension 1
                        m.narrow(1, offset, chunk_len)
                    }
                    4 => {
                        // For 4D masks (batch, heads, seq_len, seq_len), narrow along dimension 2
                        m.narrow(2, offset, chunk_len)
                    }
                    _ => m.narrow(2, offset, chunk_len), // Default to dimension 2
                }
            })
            .transpose()?;

        // Compute attention for this chunk
        let att_chunk = attention_fn(&q_chunk, k, v, mask_chunk.as_ref())?;

        attn_chunks.push(att_chunk);
    }

    // Concatenate all chunks along the sequence dimension
    Tensor::cat(&attn_chunks, 2)
}

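/// Repeats each KV head `n_rep` times along the head dimension so grouped-query
/// attention can be computed with plain batched matmuls:
/// `(b_sz, n_kv_heads, seq_len, head_dim)` -> `(b_sz, n_kv_heads * n_rep, seq_len, head_dim)`.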
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

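/// Parameters controlling scaled dot-product attention.
///
/// The values below are an illustrative sketch only (a model with 32 query heads,
/// 8 KV heads, and `head_dim = 128`), not taken from any particular configuration:
///
/// ```ignore
/// let params = SdpaParams {
///     n_kv_groups: 32 / 8,                  // query heads per KV head
///     softcap: None,                        // no logit softcapping
///     softmax_scale: 1.0 / (128f32).sqrt(), // conventional 1/sqrt(head_dim)
///     sliding_window: None,                 // full (non-windowed) attention
/// };
/// ```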
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

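/// Zero-sized dispatcher for scaled dot-product attention; see
/// [`Sdpa::run_attention`] for how a backend is selected.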
pub struct Sdpa;

impl Sdpa {
    /// Computes softmax(QK^T * softmax_scale)V, where `softmax_scale` is typically 1/sqrt(d_k)
    ///
    /// Inputs:
    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
    /// - k: (b_sz, n_kv_heads, kv_len, head_dim)
    /// - v: (b_sz, n_kv_heads, kv_len, head_dim)
    ///
    /// The attention implementation is dispatched as follows:
    /// 1) If using flash attn (CUDA), use a flash attention V2/V3 kernel
    /// 2) If decoding and using a Metal device, use a fused kernel
    /// 3) Otherwise, use the "naive" SDPA implementation (with optimized mask+softmax+scale application)
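    ///
    /// A hedged usage sketch (tensor construction omitted; `q`, `k`, and `v` are
    /// assumed to already have the shapes above, and no flash params are passed):
    ///
    /// ```ignore
    /// let out = Sdpa.run_attention(&q, &k, &v, mask.as_ref(), None, &sdpa_params)?;
    /// ```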
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;

        let can_use_flash =
            q.device().is_cpu() || q.device().is_cuda() && crate::using_flash_attn();

        if can_use_flash {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;

            if q.device().is_cpu() {
                match q.dtype() {
                    DType::F32 => {
                        return cpu::run_flash_attn_cpu::<f32>(&q, &k, &v, mask, sdpa_params);
                    }
                    DType::F16 => {
                        return cpu::run_flash_attn_cpu::<half::f16>(&q, &k, &v, mask, sdpa_params)
                    }
                    DType::BF16 => {
                        return cpu::run_flash_attn_cpu::<half::bf16>(
                            &q,
                            &k,
                            &v,
                            mask,
                            sdpa_params,
                        );
                    }
                    _ => {
                        return Err(candle_core::Error::Msg("Unsupported data type".into()));
                    }
                }
            } else {
                return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
            }
        }

        self.run_attention_noflash(q, k, v, mask, sdpa_params)
    }

    /// Same as `run_attention`, but without flash attention
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention_noflash(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;

        // We can use Metal SDPA (vector/full) if the mask is the correct size and all head dims match.
        // If a mask is provided, softcapping isn't allowed - fall back to naive SDPA.
        // (Softcapping is implemented for vector SDPA.)
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
        let valid_head_dims: &[usize] = if seq_len == 1 {
            &[32, 64, 72, 80, 96, 128, 256]
        } else {
            // Not sure why the full kernel doesn't like 256.
            // [32, 64, 72, 80, 96, 128, 256]
            &[32, 64, 72, 80, 96, 128]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                mask,
                sdpa_params,
            );
        }

        // TODO: bench?
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            mistralrs_quant::cublaslt::CUBLASLT_CONTROLLER.get(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                // Use chunked attention for cuBLASLt path
                let k_flat = k.flatten(0, 1)?;
                let v_flat = v.flatten(0, 1)?;

                chunked_attention(q, &k, &v, mask, |q_chunk, _k, _v, mask_chunk| {
                    // cuBLASLt batch matmul implementation requires inputs to be dims3
                    let (chunk_b_sz, chunk_n_heads, chunk_seq_len, chunk_head_dim) =
                        q_chunk.dims4()?;
                    let q_flat = q_chunk.flatten(0, 1)?;

                    let attention_bias = match mask_chunk {
                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                            Some(mask.repeat((chunk_n_heads, 1, 1))?)
                        }
                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                        Some(mask) if mask.rank() == 4 => {
                            let tgt_shape =
                                vec![chunk_b_sz, chunk_n_heads, chunk_seq_len, k.dim(2)?];
                            Some(mask.broadcast_as(tgt_shape)?.flatten(0, 1)?)
                        }
                        Some(mask) => {
                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                        }
                        None => None,
                    };

                    // If attention_bias is set, we fuse the add by giving it as the output matrix
                    // and setting beta to 1.0
                    let beta = match attention_bias.is_some() {
                        true => Some(1.0),
                        false => None,
                    };

                    // Batch matrix multiplication
                    // Fuse softmax scale and attention_bias add
                    let mut attention_scores = cublaslt.batch_matmul(
                        &k_flat,
                        &q_flat,
                        attention_bias.as_ref(),
                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                        beta,
                        None,
                        None,
                    )?;
                    if let Some(softcap) = sdpa_params.softcap {
                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                    }
                    attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;

                    let context_layer = cublaslt.batch_matmul(
                        &v_flat.t()?.contiguous()?,
                        &attention_scores,
                        // We save one allocation
                        Some(&q_flat),
                        None,
                        None,
                        None,
                        None,
                    )?;

                    // Reshape to dims4
                    context_layer.reshape((chunk_b_sz, chunk_n_heads, chunk_seq_len, v_head_dim))
                })
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}