diffusion_rs_common/nn/attention.rs
use crate::core::{Result, Tensor};
/// Computes softmax(QK^T * scale + M)V, where `M` is the attention mask applied as a bias
/// before the softmax (0 for unmasked positions, -inf for masked ones) and `scale` is
/// typically 1/sqrt(d_k).
///
/// The attention implementation is automatically accelerated and dispatched as follows:
/// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel.
/// 2) Otherwise, use an SDPA path that fuses the softmax scale and the attention-bias
///    addition into the QK^T matmul.
///
/// Note that there may be minute differences in output because floating point operations
/// are not associative.
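///
/// As used below (see `dims4` and `broadcast_as`), `q`, `k`, and `v` are expected to have
/// shape `(batch, num_heads, seq_len, head_dim)`, and `mask`, when present, must broadcast
/// to `(batch, num_heads, seq_len, seq_len)`.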
#[allow(unused_variables, clippy::too_many_arguments)]
pub fn scaled_dot_product_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
    seq_len: usize,
) -> Result<Tensor> {
    let att = match mask {
        Some(mask) => {
            let (b, n, s, _h) = q.dims4()?;
            // Fused path: accumulate `scale * q @ k^T` into the broadcast mask,
            // i.e. att = q @ k^T * scale + mask.
            let mut mask_and_output = mask.broadcast_as((b, n, s, s))?.contiguous()?;
            q.contiguous()?.matmul_with_alpha_beta(
                &k.t()?.contiguous()?,
                &mut mask_and_output,
                Some(scale),
            )?;
            mask_and_output
        }
        // No mask: a plain scaled GEMM, att = q @ k^T * scale.
        None => q
            .contiguous()?
            .matmul_with_alpha(&k.t()?.contiguous()?, Some(scale))?,
    };
    let att = crate::nn::ops::softmax_last_dim(&att)?;
    // Convert `v` to contiguous as matmul doesn't support strided inputs for now.
    att.matmul(&v.contiguous()?)
}
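
// A minimal, hypothetical usage sketch (not part of the original file): it assumes
// candle-style constructors such as `Tensor::randn` and `Device::Cpu` are re-exported
// through `crate::core`, which may not match this crate exactly. It only checks that the
// output keeps the `(batch, num_heads, seq_len, head_dim)` layout of the inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Device;

    #[test]
    fn sdpa_output_shape() -> Result<()> {
        let dev = Device::Cpu;
        let (b, heads, seq, head_dim): (usize, usize, usize, usize) = (1, 2, 4, 8);
        // Q, K, V are laid out as (batch, num_heads, seq_len, head_dim).
        let q = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
        let v = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
        // The conventional softmax scale for attention is 1/sqrt(head_dim).
        let scale = 1.0 / (head_dim as f64).sqrt();
        let out = scaled_dot_product_attention(&q, &k, &v, scale, None, seq)?;
        assert_eq!(out.dims(), &[b, heads, seq, head_dim]);
        Ok(())
    }
}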