mistralrs_core/attention.rs

#![allow(clippy::cast_precision_loss)]

#[cfg(feature = "metal")]
use std::sync::atomic::AtomicUsize;

use crate::{pipeline::text_models_inputs_processor::FlashParams, MemoryUsage};

use candle_core::{Device, Result, Tensor};
use mistralrs_quant::MatMul;

#[cfg(feature = "metal")]
/// Initial, sentinel value is usize::MAX
static METAL_VERSION_CACHE: AtomicUsize = AtomicUsize::new(usize::MAX);

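// Flash Attention (v2) path. When `FlashParams` (cumulative sequence lengths) are
// provided, the varlen windowed kernel is used; otherwise the dense kernel is used.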
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    flash_params: Option<&crate::pipeline::text_models_inputs_processor::FlashParams>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    let (_b_sz, _n_attn_heads, seq_len, _head_dim) = q.dims4()?;
    let causal = seq_len > 1;

    use crate::pipeline::text_models_inputs_processor::FlashParams;

    if let Some(FlashParams {
        max_q,
        max_k,
        cumulative_seqlens_q,
        cumulative_seqlens_k,
    }) = flash_params
    {
        let qshape = q.shape();
        let q = q.flatten_to(1)?;
        let k = k.flatten_to(1)?;
        let v = v.flatten_to(1)?;

        let window_size_left = sdpa_params.sliding_window;
        let window_size_right = if causal { Some(0) } else { None };

        let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
        let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

        candle_flash_attn::flash_attn_varlen_windowed_softcap(
            &q,
            &k,
            &v,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            *max_q as usize,
            *max_k as usize,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            window_size_left,
            window_size_right,
        )?
        .reshape(qshape)
    } else {
        candle_flash_attn::flash_attn_softcap(
            q,
            k,
            v,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            causal,
        )
    }
}

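// Flash Attention (v3) path; mirrors the v2 path above, but the varlen kernel here has
// no fused softcap variant.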
#[cfg(feature = "flash-attn-v3")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    flash_params: Option<&crate::pipeline::text_models_inputs_processor::FlashParams>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    let (_b_sz, _n_attn_heads, seq_len, _head_dim) = q.dims4()?;
    let causal = seq_len > 1;

    use crate::pipeline::text_models_inputs_processor::FlashParams;

    if let Some(FlashParams {
        max_q,
        max_k,
        cumulative_seqlens_q,
        cumulative_seqlens_k,
    }) = flash_params
    {
        let qshape = q.shape();
        let q = q.flatten_to(1)?;
        let k = k.flatten_to(1)?;
        let v = v.flatten_to(1)?;

        let window_size_left = sdpa_params.sliding_window;
        let window_size_right = if causal { Some(0) } else { None };

        let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
        let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

        candle_flash_attn_v3::flash_attn_varlen_windowed(
            &q,
            &k,
            &v,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            *max_q as usize,
            *max_k as usize,
            sdpa_params.softmax_scale,
            window_size_left,
            window_size_right,
            true,
        )?
        .reshape(qshape)
    } else {
        candle_flash_attn_v3::flash_attn(q, k, v, sdpa_params.softmax_scale, causal, true)
    }
}

#[cfg(not(any(feature = "flash-attn", feature = "flash-attn-v3")))]
fn flash_attn(
    _: &Tensor,
    _: &Tensor,
    _: &Tensor,
    _: Option<&crate::pipeline::text_models_inputs_processor::FlashParams>,
    _: &SdpaParams,
) -> Result<Tensor> {
    unimplemented!("Compile with `--features flash-attn` or `--features flash-attn-v3`.")
}

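/// Repeat each KV head `n_rep` times so that grouped-query attention K/V tensors match
/// the number of query heads.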
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

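/// Whether the backend supports the fused attention softmax op. On Metal this shells out
/// to `xcrun` to query the Metal language version (cached after the first call); on other
/// backends it always returns true.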
fn supports_attn_softmax() -> Result<bool> {
    #[cfg(feature = "metal")]
    {
        use std::sync::atomic::Ordering;
        let cache = METAL_VERSION_CACHE.load(Ordering::Relaxed);

        let version = if cache != usize::MAX {
            cache
        } else {
            // echo "__METAL_VERSION__" | xcrun -sdk macosx metal -E -x metal -P -

            use std::process::{Command, Stdio};

            // Create the `echo` command and pipe its output into `xcrun`
            let mut echo = Command::new("echo")
                .arg("__METAL_VERSION__")
                .stdout(Stdio::piped())
                .spawn()
                .expect("Failed to start echo command");

            echo.wait()?;

            // Run the `xcrun` command, taking input from the `echo` command's output
            let output = Command::new("xcrun")
                .arg("-sdk")
                .arg("macosx")
                .arg("metal")
                .arg("-E")
                .arg("-x")
                .arg("metal")
                .arg("-P")
                .arg("-")
                .stdin(echo.stdout.unwrap())
                .output()
                .expect("Failed to run xcrun command");

            // Handle the output
            if output.status.success() {
                let version = String::from_utf8_lossy(&output.stdout)
                    .split('\n')
                    .nth(1)
                    .unwrap()
                    .trim()
                    .to_string()
                    .parse::<usize>()
                    .unwrap();
                METAL_VERSION_CACHE.store(version, Ordering::Relaxed);
                version
            } else {
                let stderr = String::from_utf8_lossy(&output.stderr);
                panic!("Error:\n{}", stderr);
            }
        };
        // The fused attention softmax is only supported for Metal language version >= 3.1
        // (__METAL_VERSION__ >= 310)
        Ok(version >= 310)
    }

    #[cfg(not(feature = "metal"))]
    Ok(true)
}

/// Not *really* sure why this is necessary but it is.
fn maybe_synchronize(device: &Device) -> Result<()> {
    // If less than 4 GiB is available, synchronize
    if MemoryUsage.get_memory_available(device)? < 4 * 1024 * (1024 * 1024) {
        device.synchronize()?;
    }
    Ok(())
}

/// Computes softmax(QK^T * scale)V, where `scale` is typically 1/sqrt(d_k)
fn naive_sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    maybe_synchronize(q.device())?;

    // Use the fused attention softmax when the mask is rank 2 or rank 3 and the backend supports it
    if mask.is_some_and(|mask| mask.rank() == 2 || mask.rank() == 3) && supports_attn_softmax()? {
        let mask = match mask {
            Some(mask) if mask.rank() == 3 || mask.rank() == 2 => mask.clone(),
            _ => candle_core::bail!("unsupported mask {mask:?}"),
        };

        let mut att = MatMul.matmul(q, &k.t()?)?;

        candle_nn::ops::inplace_attn_softmax_last_dim(
            &mut att,
            &mask.contiguous()?,
            sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0),
        )?;

        if let Some(softcap) = sdpa_params.softcap {
            att = (att.tanh()? * softcap as f64)?;
        }

        MatMul.matmul(&att, v)
    } else if let Some(mask) = mask {
        let mut att = MatMul.matmul_affine_mul(q, &k.t()?, sdpa_params.softmax_scale.into())?;
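        // Logit soft-capping: softcap * tanh(scores / softcap)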
        if let Some(softcap) = sdpa_params.softcap {
            att = (att / softcap as f64)?;
            att = att.tanh()?;
            att = (att * softcap as f64)?;
        }

        att = att.broadcast_add(mask)?;
        candle_nn::ops::inplace_softmax_last_dim(&mut att)?;

        MatMul.matmul(&att, v)
    } else {
        let mut att = MatMul.matmul_affine_mul(q, &k.t()?, sdpa_params.softmax_scale.into())?;
        if let Some(softcap) = sdpa_params.softcap {
            att = (att / softcap as f64)?;
            att = att.tanh()?;
            att = (att * softcap as f64)?;
        }

        candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
        MatMul.matmul(&att, v)
    }
}

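/// Parameters controlling scaled dot-product attention: the GQA group count used by
/// `repeat_kv`, whether to use flash attention, an optional logit softcap, the softmax
/// scale (typically 1/sqrt(head_dim)), and an optional sliding-window size.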
pub struct SdpaParams {
    pub n_kv_groups: usize,
    pub use_flash_attn: bool,
    pub softcap: Option<f32>,
    pub softmax_scale: f32,
    pub sliding_window: Option<usize>,
}

pub struct Sdpa;

impl Sdpa {
    /// Computes softmax(QK^T * scale)V, where `scale` is typically 1/sqrt(d_k)
    ///
    /// Inputs:
    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
    /// - k: (b_sz, n_kv_heads, kv_len, head_dim)
    /// - v: (b_sz, n_kv_heads, kv_len, head_dim)
    ///
    /// The attention implementation is dispatched as follows:
    /// 1) If `use_flash_attn == true` and the device is CUDA, use a flash attention kernel (v2 or v3, depending on the enabled feature)
    /// 2) If the device is Metal and the head dims and mask shape are supported, use a fused Metal SDPA kernel
    /// 3) Otherwise, use the "naive" SDPA implementation (with optimized mask + softmax + scale application)
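    ///
    /// # Example
    ///
    /// A minimal call-site sketch (illustrative only; the tensors, shapes, and scale
    /// below are hypothetical):
    ///
    /// ```ignore
    /// // q: (1, 32, 1, 128), k/v: (1, 8, kv_len, 128) => 32 / 8 = 4 KV groups
    /// let out = Sdpa.run_attention(
    ///     &q,
    ///     &k,
    ///     &v,
    ///     None, // no mask for a single-token decode step
    ///     None, // no FlashParams (only needed for the varlen flash-attention path)
    ///     &SdpaParams {
    ///         n_kv_groups: 4,
    ///         use_flash_attn: false,
    ///         softcap: None,
    ///         softmax_scale: 1.0 / (128f32).sqrt(),
    ///         sliding_window: None,
    ///     },
    /// )?;
    /// ```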
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
        if sdpa_params.use_flash_attn && q.device().is_cuda() {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }

        // We can use Metal SDPA (vector/full) if the mask broadcasts to the target shape and all head dims match.
        // If a mask is provided, softcapping isn't supported - fall back to naive SDPA in that case.
        // Softcapping is only implemented for the vector (no-mask) SDPA kernel.
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
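        // Head dims supported by the fused Metal kernels; the masked (full) kernel
        // supports fewer head dims than the vector (no-mask) kernel.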
        let valid_head_dims: &[usize] = if can_use_mask && mask.is_some() {
            &[64, 80, 128]
        } else {
            &[32, 64, 96, 128, 256]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

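        // Expand KV heads for grouped-query attention so K/V match the number of query heads.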
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(q, &k, &v, mask, sdpa_params);
        }

        // TODO: bench?
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            *mistralrs_quant::cublaslt::CUBLASLT_HANDLE.lock().unwrap(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                // cuBLASLt batch matmul implementation requires inputs to be dims3
                let k = k.flatten(0, 1)?;
                let q = q.flatten(0, 1)?;
                let v = v.flatten(0, 1)?;
                let attention_bias = match mask {
                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                    }
                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                    Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
                    Some(mask) => {
                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                    }
                    None => None,
                };

                // If attention_bias is set, we fuse the add by giving it as the output matrix
                // and setting beta to 1.0
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Batch matrix multiplication
                // Fuse softmax scale and attention_bias add
                let mut attention_scores = cublaslt.batch_matmul(
                    &k,
                    &q,
                    attention_bias.as_ref(),
                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                    beta,
                    None,
                    None,
                )?;
                if let Some(softcap) = sdpa_params.softcap {
                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                }
                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

                let context_layer = cublaslt.batch_matmul(
                    &v.t()?.contiguous()?,
                    &attention_scores,
                    // We save one allocation
                    Some(&q),
                    None,
                    None,
                    None,
                    None,
                )?;

                // Reshape to dims4
                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}