#![allow(clippy::cast_precision_loss)]
#[cfg(feature = "metal")]
use std::sync::atomic::AtomicUsize;
use crate::{
cublaslt::CUBLASLT_HANDLE,
layers::{get_use_matmul_via_f16, MatMul},
pipeline::text_models_inputs_processor::FlashParams,
};
use candle_core::{Device, Result, Tensor};
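/// Cached Metal language version reported by `xcrun` (e.g. 310 for Metal 3.1);
/// `usize::MAX` means the version has not been probed yet.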
#[cfg(feature = "metal")]
static METAL_VERSION_CACHE: AtomicUsize = AtomicUsize::new(usize::MAX);
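/// Flash-attention wrapper: uses the varlen windowed kernel when `FlashParams`
/// (cumulative sequence lengths) are provided, otherwise the dense softcapped kernel.
/// Expects `q`/`k`/`v` in (b_sz, seq_len, n_attn_heads, head_dim) layout.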
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
flash_params: Option<&FlashParams>,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
// q/k/v arrive here already transposed to (b_sz, seq_len, n_attn_heads, head_dim),
// so causal masking is keyed off the sequence dimension (prefill vs. single-token decode).
let (_b_sz, seq_len, _n_attn_heads, _head_dim) = q.dims4()?;
let causal = seq_len > 1;
if let Some(FlashParams {
max_q,
max_k,
cumulative_seqlens_q,
cumulative_seqlens_k,
}) = flash_params
{
let qshape = q.shape();
let q = q.flatten_to(1)?;
let k = k.flatten_to(1)?;
let v = v.flatten_to(1)?;
let window_size_left = sdpa_params.sliding_window;
let window_size_right = if causal { Some(0) } else { None };
candle_flash_attn::flash_attn_varlen_windowed_softcap(
&q,
&k,
&v,
cumulative_seqlens_q,
cumulative_seqlens_k,
*max_q as usize,
*max_k as usize,
sdpa_params.softmax_scale,
sdpa_params.softcap,
window_size_left,
window_size_right,
)?
.reshape(qshape)
} else {
candle_flash_attn::flash_attn_softcap(
q,
k,
v,
sdpa_params.softmax_scale,
sdpa_params.softcap,
causal,
)
}
}
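/// Stub used when the `flash-attn` feature is disabled.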
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(
_: &Tensor,
_: &Tensor,
_: &Tensor,
_: Option<&FlashParams>,
_: &SdpaParams,
) -> Result<Tensor> {
unimplemented!("Compile with '--features flash-attn'")
}
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}
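/// Fallback scaled dot-product attention built from plain matmuls and softmax.
/// When the mask can be broadcast to a full attention bias (and, on Metal, the
/// language version is at least 3.1), the scale and mask add are fused into the matmul.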
fn naive_sdpa(
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
head_dim: usize,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
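// Decide whether the fused bias + in-place softmax path below can be used;
// on Metal it requires language version 3.1 (310) or newer.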
#[cfg(feature = "metal")]
let supports_attn_softmax = {
use std::sync::atomic::Ordering;
let cache = METAL_VERSION_CACHE.load(Ordering::Relaxed);
let version = if cache != usize::MAX {
cache
} else {
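// Probe the toolchain once: preprocess `__METAL_VERSION__` through the Metal compiler
// (`echo __METAL_VERSION__ | xcrun -sdk macosx metal -E -x metal -P -`) and cache the
// expanded value, e.g. 310 for Metal 3.1.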
use std::process::{Command, Stdio};
let mut echo = Command::new("echo")
.arg("__METAL_VERSION__")
.stdout(Stdio::piped())
.spawn()
.expect("Failed to start echo command");
echo.wait()?;
let output = Command::new("xcrun")
.arg("-sdk")
.arg("macosx")
.arg("metal")
.arg("-E")
.arg("-x")
.arg("metal")
.arg("-P")
.arg("-")
.stdin(echo.stdout.unwrap())
.output()
.expect("Failed to run xcrun command");
if output.status.success() {
let version = String::from_utf8_lossy(&output.stdout)
.split('\n')
.nth(1)
.expect("unexpected output from the Metal preprocessor")
.trim()
.parse::<usize>()
.expect("failed to parse the Metal language version");
METAL_VERSION_CACHE.store(version, Ordering::Relaxed);
version
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
panic!("Failed to query the Metal compiler version via xcrun:\n{stderr}");
}
};
version >= 310
};
#[cfg(not(feature = "metal"))]
let supports_attn_softmax = true;
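// Fused path: materialize the mask as a full attention bias, fold the softmax scale
// (and softcap divisor) into the matmul, and run the softmax in place.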
if mask.is_some_and(|mask| mask.rank() == 2 || (mask.rank() == 3 && mask.dims()[0] == 1))
&& supports_attn_softmax
{
let n_attn_heads = q.dim(1)?;
let bs = q.dim(0)?;
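// Broadcast the mask up to a (bs, n_attn_heads, q_len, k_len) bias. Only rank-2 and
// rank-3-with-a-leading-1 masks reach this branch (see the guard above); the other
// arms are defensive.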
let attention_bias = match mask {
Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
mask.unsqueeze(0)?.repeat((bs, n_attn_heads, 1, 1))?
}
Some(mask) if mask.rank() == 3 => mask.unsqueeze(0)?,
Some(mask) if mask.rank() == 2 => {
mask.unsqueeze(0)?
.unsqueeze(0)?
.repeat((bs, n_attn_heads, 1, 1))?
}
Some(mask) if mask.rank() == 4 => mask.clone(),
_ => candle_core::bail!("unsupported mask {mask:?}"),
};
let mut att = attention_bias;
q.matmul_with_alpha_beta(
&k.t()?,
&mut att,
Some((sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)) as f64),
)?;
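// Apply the soft cap after the fused matmul + bias add: masked positions (large
// negative bias) saturate to -softcap, the minimum capped score, so they carry
// effectively zero weight after the softmax.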
if let Some(softcap) = sdpa_params.softcap {
att = (att.tanh()? * softcap as f64)?;
}
candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
MatMul.matmul(&att, v)
} else if let Some(mask) = mask {
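// Unfused masked path: scale by 1/sqrt(head_dim), optionally soft-cap, add the mask, softmax.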
let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;
if let Some(softcap) = sdpa_params.softcap {
att = (att / softcap as f64)?;
att = att.tanh()?;
att = (att * softcap as f64)?;
}
att = att.broadcast_add(mask)?;
candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
MatMul.matmul(&att, v)
} else {
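// Unfused path without a mask.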
let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?;
if let Some(softcap) = sdpa_params.softcap {
att = (att / softcap as f64)?;
att = att.tanh()?;
att = (att * softcap as f64)?;
}
candle_nn::ops::inplace_softmax_last_dim(&mut att)?;
MatMul.matmul(&att, v)
}
}
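/// Parameters controlling scaled dot-product attention.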
pub struct SdpaParams {
/// Number of query heads per KV head; `k`/`v` are expanded by this factor for
/// grouped-query attention.
pub n_kv_groups: usize,
/// Route attention through the flash-attn kernels (requires the `flash-attn` feature).
pub use_flash_attn: bool,
/// Optional attention-logit soft cap: scores become `softcap * tanh(scores / softcap)`.
pub softcap: Option<f32>,
/// Scale applied to `q @ k^T` before the softmax, typically `1 / sqrt(head_dim)`.
pub softmax_scale: f32,
/// Sliding-window size, forwarded as the left window to the windowed flash-attention kernel.
pub sliding_window: Option<usize>,
}
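/// Scaled dot-product attention dispatcher that picks between flash attention, the
/// fused Metal SDPA kernel, cuBLASLt batched GEMMs, and a naive fallback.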
pub struct Sdpa;
impl Sdpa {
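/// Computes `softmax(q k^T * scale + mask) v`, expecting `q`/`k`/`v` in
/// (b_sz, n_attn_heads, seq_len, head_dim) layout, and dispatches to the fastest
/// backend available for the current device and enabled features.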
#[allow(unused_variables, clippy::too_many_arguments)]
pub fn run_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
flash_params: Option<&FlashParams>,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
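// Flash-attention path: the kernels expect (b_sz, seq_len, n_heads, head_dim),
// so transpose on the way in and back on the way out.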
if sdpa_params.use_flash_attn {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
}
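// Fast path for single-token decode on Metal: use the fused SDPA kernel directly.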
if q.device().is_metal() && seq_len == 1 {
return candle_nn::ops::sdpa(
q,
k,
v,
sdpa_params.softmax_scale,
sdpa_params.softcap.unwrap_or(1.0),
);
}
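// Expand KV heads for grouped-query attention before the matmul-based paths.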
let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
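// On CUDA with a cuBLASLt handle (and matmul-via-f16 disabled), fuse the scale
// and mask add into batched GEMMs.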
if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
if !get_use_matmul_via_f16() {
#[cfg(feature = "cuda")]
{
let k = k.flatten(0, 1)?;
let q = q.flatten(0, 1)?;
let v = v.flatten(0, 1)?;
let attention_bias = match mask {
Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
Some(mask.repeat((n_attn_heads, 1, 1))?)
}
Some(mask) if mask.rank() == 3 => Some(mask.clone()),
Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
Some(mask) => {
candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
}
None => None,
};
let beta = attention_bias.is_some().then_some(1.0);
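// One fused batched GEMM: alpha * (q k^T) + attention bias, with
// alpha = softmax_scale / softcap; the bias is added via the beta/C-matrix path.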
let mut attention_scores = cublaslt.batch_matmul(
&k,
&q,
attention_bias.as_ref(),
Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
beta,
None,
None,
)?;
if let Some(softcap) = sdpa_params.softcap {
attention_scores = (attention_scores.tanh()? * softcap as f64)?;
}
candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;
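// Weighted sum of values: probs @ v, passing `q` as the output buffer to save an allocation.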
let context_layer = cublaslt.batch_matmul(
&v.t()?.contiguous()?,
&attention_scores,
Some(&q),
None,
None,
None,
None,
)?;
context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
}
#[cfg(not(feature = "cuda"))]
{
candle_core::bail!("`cuda` feature is not enabled")
}
} else {
naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
}
} else {
naive_sdpa(q, &k, &v, mask, head_dim, sdpa_params)
}
}
}