#![allow(clippy::cast_precision_loss)]

#[cfg(feature = "metal")]
use std::sync::atomic::AtomicUsize;

use crate::{pipeline::text_models_inputs_processor::FlashParams, MemoryUsage};

use candle_core::{Device, Result, Tensor};
use mistralrs_quant::MatMul;

#[cfg(feature = "metal")]
static METAL_VERSION_CACHE: AtomicUsize = AtomicUsize::new(usize::MAX);
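
/// Flash attention (v2) wrapper.
///
/// When `flash_params` is provided, the variable-length, windowed, soft-capped
/// kernel is run over the packed sequences using the cumulative sequence
/// lengths for the current device; otherwise the standard (optionally causal)
/// kernel is used.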
#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    flash_params: Option<&FlashParams>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    let (_b_sz, _n_attn_heads, seq_len, _head_dim) = q.dims4()?;
    // A causal mask is only needed when attending over more than one query position.
    let causal = seq_len > 1;

    if let Some(FlashParams {
        max_q,
        max_k,
        cumulative_seqlens_q,
        cumulative_seqlens_k,
    }) = flash_params
    {
        let qshape = q.shape();
        let q = q.flatten_to(1)?;
        let k = k.flatten_to(1)?;
        let v = v.flatten_to(1)?;

        let window_size_left = sdpa_params.sliding_window;
        let window_size_right = if causal { Some(0) } else { None };

        let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
        let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

        candle_flash_attn::flash_attn_varlen_windowed_softcap(
            &q,
            &k,
            &v,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            *max_q as usize,
            *max_k as usize,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            window_size_left,
            window_size_right,
        )?
        .reshape(qshape)
    } else {
        candle_flash_attn::flash_attn_softcap(
            q,
            k,
            v,
            sdpa_params.softmax_scale,
            sdpa_params.softcap,
            causal,
        )
    }
}
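
/// Flash attention (v3) wrapper.
///
/// Mirrors the v2 wrapper: the variable-length, windowed kernel when
/// `flash_params` is provided, the standard (optionally causal) kernel
/// otherwise. Unlike the v2 path, no softcap is passed through.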
#[cfg(feature = "flash-attn-v3")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    flash_params: Option<&FlashParams>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    let (_b_sz, _n_attn_heads, seq_len, _head_dim) = q.dims4()?;
    // A causal mask is only needed when attending over more than one query position.
    let causal = seq_len > 1;

    if let Some(FlashParams {
        max_q,
        max_k,
        cumulative_seqlens_q,
        cumulative_seqlens_k,
    }) = flash_params
    {
        let qshape = q.shape();
        let q = q.flatten_to(1)?;
        let k = k.flatten_to(1)?;
        let v = v.flatten_to(1)?;

        let window_size_left = sdpa_params.sliding_window;
        let window_size_right = if causal { Some(0) } else { None };

        let cumulative_seqlens_q = &cumulative_seqlens_q[&q.device().location()];
        let cumulative_seqlens_k = &cumulative_seqlens_k[&q.device().location()];

        candle_flash_attn_v3::flash_attn_varlen_windowed(
            &q,
            &k,
            &v,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            *max_q as usize,
            *max_k as usize,
            sdpa_params.softmax_scale,
            window_size_left,
            window_size_right,
            true,
        )?
        .reshape(qshape)
    } else {
        candle_flash_attn_v3::flash_attn(q, k, v, sdpa_params.softmax_scale, causal, true)
    }
}
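
/// Stub used when neither `flash-attn` nor `flash-attn-v3` is enabled; always panics.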
#[cfg(not(any(feature = "flash-attn", feature = "flash-attn-v3")))]
fn flash_attn(
    _: &Tensor,
    _: &Tensor,
    _: &Tensor,
    _: Option<&FlashParams>,
    _: &SdpaParams,
) -> Result<Tensor> {
    unimplemented!("Compile with `--features flash-attn` or `--features flash-attn-v3`.")
}
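
/// Repeats each KV head `n_rep` times so grouped-query attention can be computed
/// with plain batched matmuls: `(b, n_kv_heads, seq, d)` becomes
/// `(b, n_kv_heads * n_rep, seq, d)`. A no-op when `n_rep == 1`.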
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(x)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}
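
/// Reports whether the fused in-place attention softmax may be used.
///
/// On Metal, the installed Metal language version is probed by preprocessing
/// `__METAL_VERSION__` (equivalent to
/// `echo __METAL_VERSION__ | xcrun -sdk macosx metal -E -x metal -P -`),
/// cached in `METAL_VERSION_CACHE`, and Metal 3.1+ (a reported value of at
/// least 310) is required. On non-Metal backends this always returns `true`.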
fn supports_attn_softmax() -> Result<bool> {
    #[cfg(feature = "metal")]
    {
        use std::sync::atomic::Ordering;
        let cache = METAL_VERSION_CACHE.load(Ordering::Relaxed);

        let version = if cache != usize::MAX {
            cache
        } else {
            use std::process::{Command, Stdio};

            // Pipe `__METAL_VERSION__` through the Metal preprocessor to read the
            // installed Metal language version.
            let mut echo = Command::new("echo")
                .arg("__METAL_VERSION__")
                .stdout(Stdio::piped())
                .spawn()
                .expect("Failed to start `echo` command");

            echo.wait()?;

            let output = Command::new("xcrun")
                .arg("-sdk")
                .arg("macosx")
                .arg("metal")
                .arg("-E")
                .arg("-x")
                .arg("metal")
                .arg("-P")
                .arg("-")
                .stdin(echo.stdout.unwrap())
                .output()
                .expect("Failed to run `xcrun` command");

            if output.status.success() {
                let version = String::from_utf8_lossy(&output.stdout)
                    .split('\n')
                    .nth(1)
                    .unwrap()
                    .trim()
                    .to_string()
                    .parse::<usize>()
                    .unwrap();
                METAL_VERSION_CACHE.store(version, Ordering::Relaxed);
                version
            } else {
                let stderr = String::from_utf8_lossy(&output.stderr);
                panic!("Failed to query the Metal version via `xcrun`:\n{stderr}");
            }
        };
        // Metal 3.1+ (version >= 310) is required for the fused attention softmax.
        Ok(version >= 310)
    }

    #[cfg(not(feature = "metal"))]
    Ok(true)
}
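
/// Synchronizes the device before large attention allocations when it reports
/// less than 4 GiB of available memory, so that pending work (and the memory it
/// holds) can be retired first.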
fn maybe_synchronize(device: &Device) -> Result<()> {
    if MemoryUsage.get_memory_available(device)? < 4 * 1024 * (1024 * 1024) {
        device.synchronize()?;
    }
    Ok(())
}
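
/// Unfused scaled dot-product attention fallback.
///
/// Computes `softmax(QK^T * scale (+ mask)) * V` with candle matmuls, using the
/// fused in-place attention softmax when the mask has rank 2 or 3 and the
/// backend supports it, and a plain in-place softmax otherwise. A tanh softcap
/// is applied when `softcap` is set.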
fn naive_sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    sdpa_params: &SdpaParams,
) -> Result<Tensor> {
    maybe_synchronize(q.device())?;

    if mask.is_some_and(|mask| mask.rank() == 2 || mask.rank() == 3) && supports_attn_softmax()? {
        // Fused path: scale, mask, and softmax are applied in a single in-place op.
        let mask = match mask {
            Some(mask) if mask.rank() == 3 || mask.rank() == 2 => mask.clone(),
            _ => candle_core::bail!("unsupported mask {mask:?}"),
        };

        let mut att = MatMul.matmul(q, &k.t()?)?;

        candle_nn::ops::inplace_attn_softmax_last_dim(
            &mut att,
            &mask.contiguous()?,
            sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0),
        )?;

        if let Some(softcap) = sdpa_params.softcap {
            att = (att.tanh()? * softcap as f64)?;
        }

        MatMul.matmul(&att, v)
    } else if let Some(mask) = mask {
        // Masked path: the scale is fused into the matmul; the softcap (if any) and
        // the mask are applied explicitly before an in-place softmax.
        let mut att = MatMul.matmul_affine_mul(q, &k.t()?, sdpa_params.softmax_scale.into())?;
        if let Some(softcap) = sdpa_params.softcap {
            att = (att / softcap as f64)?;
            att = att.tanh()?;
            att = (att * softcap as f64)?;
        }

        att = att.broadcast_add(mask)?;
        candle_nn::ops::inplace_softmax_last_dim(&mut att)?;

        MatMul.matmul(&att, v)
    } else {
        // Unmasked path.
        let mut att = MatMul.matmul_affine_mul(q, &k.t()?, sdpa_params.softmax_scale.into())?;
        if let Some(softcap) = sdpa_params.softcap {
            att = (att / softcap as f64)?;
            att = att.tanh()?;
            att = (att * softcap as f64)?;
        }

        candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
        MatMul.matmul(&att, v)
    }
}
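
/// Per-call configuration for scaled dot-product attention.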
pub struct SdpaParams {
    /// Number of query heads per KV head (grouped-query attention).
    pub n_kv_groups: usize,
    /// Prefer the flash attention kernels on CUDA when compiled in.
    pub use_flash_attn: bool,
    /// Optional tanh soft-capping value for the attention scores.
    pub softcap: Option<f32>,
    /// Scale applied to `QK^T`, typically `1 / sqrt(head_dim)`.
    pub softmax_scale: f32,
    /// Sliding (local) attention window size, if any.
    pub sliding_window: Option<usize>,
}
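
/// Scaled dot-product attention dispatcher.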
pub struct Sdpa;

impl Sdpa {
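    /// Computes scaled dot-product attention, `softmax(QK^T * scale (+ mask)) * V`,
    /// dispatching to flash attention (CUDA), the fused Metal SDPA kernel, a
    /// cuBLASLt path, or the naive fallback depending on device, enabled features,
    /// head dimensions, and mask shape.
    ///
    /// `q` is expected as `(b_sz, n_attn_heads, seq_len, head_dim)`, with `k`/`v`
    /// shaped likewise over the KV heads; KV heads are expanded by `n_kv_groups`
    /// where the backend needs it.
    ///
    /// A minimal usage sketch; the shapes, scale, and absence of a mask below are
    /// illustrative assumptions, not requirements:
    ///
    /// ```ignore
    /// let params = SdpaParams {
    ///     n_kv_groups: 4, // e.g. 32 query heads sharing 8 KV heads
    ///     use_flash_attn: false,
    ///     softcap: None,
    ///     softmax_scale: 1.0 / (128f32).sqrt(),
    ///     sliding_window: None,
    /// };
    /// // q: (1, 32, 1024, 128), k/v: (1, 8, 1024, 128), prepared by the caller.
    /// let out = Sdpa.run_attention(&q, &k, &v, None, None, &params)?;
    /// assert_eq!(out.dims4()?, (1, 32, 1024, 128));
    /// ```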
    #[allow(unused_variables, clippy::too_many_arguments)]
    pub fn run_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
        flash_params: Option<&FlashParams>,
        sdpa_params: &SdpaParams,
    ) -> Result<Tensor> {
        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
        let (_, _, _, k_head_dim) = k.dims4()?;
        let (_, _, _, v_head_dim) = v.dims4()?;
        if sdpa_params.use_flash_attn && q.device().is_cuda() {
            // Flash attention expects (b_sz, seq_len, n_heads, head_dim).
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
        }
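
        // Use the fused Metal SDPA kernel when q/k/v all live on Metal, the head
        // dims match and are supported, and any mask broadcasts to
        // (b_sz, n_attn_heads, seq_len, kv_seq_len) without a non-trivial softcap.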
        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
        let can_use_mask = mask.is_none_or(|mask| {
            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
        });
        let valid_head_dims: &[usize] = if can_use_mask && mask.is_some() {
            &[64, 80, 128]
        } else {
            &[32, 64, 96, 128, 256]
        };
        if [q, k, v].into_iter().all(|x| x.device().is_metal())
            && all_head_dims_match
            && valid_head_dims.contains(&head_dim)
            && can_use_mask
        {
            let mask = match mask {
                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
                None => None,
            };
            return candle_nn::ops::sdpa(
                q,
                k,
                v,
                mask.as_ref(),
                false,
                sdpa_params.softmax_scale,
                sdpa_params.softcap.unwrap_or(1.0),
            );
        }

        // Expand KV heads to match the query heads for the matmul-based paths (GQA).
        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;

        // 2D masks and NCCL tensor-parallel runs always take the naive path.
        if mask.is_some_and(|x| x.rank() == 2) || mistralrs_quant::distributed::use_nccl() {
            return naive_sdpa(q, &k, &v, mask, sdpa_params);
        }
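
        // On CUDA with a cuBLASLt handle available, compute attention as two fused
        // batched GEMMs (scores, then context); otherwise fall back to `naive_sdpa`.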
        #[allow(unused)]
        if let (Device::Cuda(_), Some(cublaslt)) = (
            q.device(),
            *mistralrs_quant::cublaslt::CUBLASLT_HANDLE.lock().unwrap(),
        ) {
            #[cfg(feature = "cuda")]
            {
                maybe_synchronize(q.device())?;

                // cuBLASLt batched matmul expects rank-3 inputs.
                let k = k.flatten(0, 1)?;
                let q = q.flatten(0, 1)?;
                let v = v.flatten(0, 1)?;
                let attention_bias = match mask {
                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                    }
                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
                    Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
                    Some(mask) => {
                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
                    }
                    None => None,
                };

                // If a bias is supplied, fuse the add into the GEMM via beta = 1.0.
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };

                // Scores GEMM: the softmax scale (divided by any softcap) and the
                // bias add are fused into the matmul.
                let mut attention_scores = cublaslt.batch_matmul(
                    &k,
                    &q,
                    attention_bias.as_ref(),
                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
                    beta,
                    None,
                    None,
                )?;
                if let Some(softcap) = sdpa_params.softcap {
                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                }
                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

                // Context GEMM; `q` is passed as the output buffer to save an allocation.
                let context_layer = cublaslt.batch_matmul(
                    &v.t()?.contiguous()?,
                    &attention_scores,
                    Some(&q),
                    None,
                    None,
                    None,
                    None,
                )?;

                // Restore the (b_sz, n_attn_heads, seq_len, v_head_dim) layout.
                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle_core::bail!("`cuda` feature is not enabled")
            }
        } else {
            naive_sdpa(q, &k, &v, mask, sdpa_params)
        }
    }
}