pub fn scaled_dot_product_attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
mask: Option<&Tensor>,
seq_len: usize,
) -> Result<Tensor>
Computes softmax(QK^T * scale + M)V, where scale is typically 1/sqrt(d_k). M
is the attention mask, applied as an additive bias (0 for unmasked positions, -inf for masked positions).
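For reference, a naive, unfused sketch of the same computation, assuming the candle Tensor API (`candle_core::Tensor`, `candle_nn::ops::softmax_last_dim`) and omitting the `seq_len` argument, might look like the following. The helper name `sdpa_reference` is hypothetical and is not the accelerated implementation this function dispatches to.

```rust
use candle_core::{Result, Tensor};
use candle_nn::ops::softmax_last_dim;

/// Unfused reference sketch: softmax(Q K^T * scale + M) V.
/// Illustrative only; the real function fuses these steps or dispatches
/// to a Flash Attention V2 kernel.
fn sdpa_reference(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
) -> Result<Tensor> {
    // Attention scores: Q K^T, scaled (scale is typically 1 / sqrt(d_k)).
    let mut scores = (q.matmul(&k.t()?)? * scale)?;
    // Additive mask: 0 for unmasked positions, -inf for masked positions.
    if let Some(mask) = mask {
        scores = scores.broadcast_add(mask)?;
    }
    // Softmax over the key dimension, then weight the value vectors.
    softmax_last_dim(&scores)?.matmul(v)
}
```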
The attention implementation is automatically accelerated and dispatched as follows:

- If `use_flash_attn == true`, a Flash Attention V2 kernel is used.
- Otherwise, SDPA is used, with the softmax scale and attention-bias application fused.
Note that there may be minute differences in output because floating point operations are not associative.
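A hedged usage sketch follows. It assumes q/k/v of shape `(batch, num_heads, seq_len, head_dim)`, a causal mask that broadcasts over the batch and head dimensions, and `scale = 1/sqrt(head_dim)`; adapt the import of `scaled_dot_product_attention` to this crate's actual path.

```rust
use candle_core::{Device, Result, Tensor};

fn example() -> Result<Tensor> {
    let device = Device::Cpu;
    let (batch, heads, seq_len, head_dim) = (1, 8, 16, 64);

    // Random activations standing in for real Q, K, V projections.
    let q = Tensor::randn(0f32, 1f32, (batch, heads, seq_len, head_dim), &device)?;
    let k = Tensor::randn(0f32, 1f32, (batch, heads, seq_len, head_dim), &device)?;
    let v = Tensor::randn(0f32, 1f32, (batch, heads, seq_len, head_dim), &device)?;

    // Causal mask as an additive bias: 0 on/below the diagonal, -inf above it.
    // Assumes a (seq_len, seq_len) mask broadcasts over batch and heads.
    let mask_data: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), &device)?;

    let scale = 1.0 / (head_dim as f64).sqrt();
    scaled_dot_product_attention(&q, &k, &v, scale, Some(&mask), seq_len)
}
```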