diffusion_rs_common::nn::attention

Function scaled_dot_product_attention

pub fn scaled_dot_product_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
    seq_len: usize,
) -> Result<Tensor>

Computes softmax(QK^T * scale + M)V, where scale is typically 1/sqrt(d_k). M is the attention mask, applied as an additive bias before the softmax (0 for unmasked positions, -inf for masked positions).
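
As a concrete reference for the formula above, here is a minimal single-head sketch in plain Rust over f32 slices. It is only an illustration of softmax(QK^T * scale + M)V, not the crate's accelerated implementation; the helper name and row-major layout are assumptions made for the example.

// Naive single-head scaled dot product attention over row-major f32 data.
// q, k, v: seq_len x d_k matrices; mask: optional seq_len x seq_len additive bias.
// Illustrative reference only; not the crate's accelerated code path.
fn naive_sdpa(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    mask: Option<&[f32]>,
    seq_len: usize,
    d_k: usize,
) -> Vec<f32> {
    let scale = 1.0 / (d_k as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * d_k];
    for i in 0..seq_len {
        // Scores for query i against every key, plus the additive mask bias.
        let mut scores = vec![0.0f32; seq_len];
        for j in 0..seq_len {
            let dot: f32 = (0..d_k).map(|c| q[i * d_k + c] * k[j * d_k + c]).sum();
            scores[j] = dot * scale + mask.map_or(0.0, |m| m[i * seq_len + j]);
        }
        // Numerically stable softmax over the scores.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        // Weighted sum of the value rows.
        for j in 0..seq_len {
            let w = exps[j] / sum;
            for c in 0..d_k {
                out[i * d_k + c] += w * v[j * d_k + c];
            }
        }
    }
    out
}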

The attention implementation is automatically accelerated and dispatched as follows:

  1. If use_flash_attn == true, use a Flash Attention V2 kernel
  2. Otherwise, use SDPA with fusion of softmax scale and attention bias application

Note that the two paths may produce minutely different outputs because floating-point operations are not associative.
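
For orientation, a hedged usage sketch follows. It assumes the crate re-exports candle-style Tensor and Device types under diffusion_rs_common::core (an assumption; adjust the import paths to the actual re-exports), uses a (batch, heads, seq_len, head_dim) layout, and passes a scale of 1/sqrt(head_dim) with no mask.

// Usage sketch; the import paths below are assumed, not confirmed by this page.
use diffusion_rs_common::core::{Device, Result, Tensor};
use diffusion_rs_common::nn::attention::scaled_dot_product_attention;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, heads, seq_len, head_dim) layout is assumed for this example.
    let (b, h, seq_len, head_dim) = (1usize, 8usize, 16usize, 64usize);
    let q = Tensor::randn(0f32, 1f32, (b, h, seq_len, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b, h, seq_len, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b, h, seq_len, head_dim), &dev)?;

    // Conventional softmax scale; no attention mask in this example.
    let scale = 1.0 / (head_dim as f64).sqrt();
    let out = scaled_dot_product_attention(&q, &k, &v, scale, None, seq_len)?;

    println!("attention output shape: {:?}", out.shape());
    Ok(())
}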