diffusion_rs_common::nn::ops

Function sdpa

pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f32,
    softcapping: f32,
) -> Result<Tensor>

Scaled dot product attention with a fused kernel.

Computes softmax(qk^T*scale)v.
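For reference, a minimal unfused sketch of the same computation (ignoring softcapping), assuming the crate's candle-style `Tensor` lives at `diffusion_rs_common::core` and that a `softmax_last_dim` helper is available in this ops module; both paths are assumptions here:

```rust
use diffusion_rs_common::core::{Result, Tensor};
use diffusion_rs_common::nn::ops::softmax_last_dim;

/// Hypothetical unfused equivalent of `sdpa` (no softcapping),
/// assuming qhead == kv_head so no GQA broadcasting is needed.
fn sdpa_unfused(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    // q k^T over the last two dims: (bs, head, seq, kv_seq)
    let attn = (q.matmul(&k.t()?)? * scale as f64)?;
    // Softmax over the kv_seq dimension, then weight the values: (bs, head, seq, v_hidden)
    softmax_last_dim(&attn)?.matmul(v)
}
```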

Input shapes:

  • q: (bs, qhead, seq, hidden)
  • k: (bs, kv_head, kv_seq, hidden)
  • v: (bs, kv_head, kv_seq, v_hidden)
  • scale is applied before softmax.
  • If softcapping != 1.0:
    • Computation is: softmax(tanh(qk^T*scale/cap)*cap)v

Output shape: (bs, qhead, seq, v_hidden)

Supported head dims: 32, 64, 96, 128, 256.
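
A minimal usage sketch with these shapes, assuming a Metal device is available (the fused kernel targets Metal; see the section below) and that `diffusion_rs_common::core` exposes the candle-style `Device`/`Tensor` types:

```rust
use diffusion_rs_common::core::{Device, Result, Tensor};
use diffusion_rs_common::nn::ops::sdpa;

fn run() -> Result<()> {
    let dev = Device::new_metal(0)?;
    // q: (bs, qhead, seq, hidden); k/v: (bs, kv_head, kv_seq, hidden) with head dim 64.
    let q = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 8, 16, 64), &dev)?;

    // Typical scale of 1/sqrt(head_dim); a softcapping of 1.0 leaves the tanh cap disabled.
    let out = sdpa(&q, &k, &v, 1.0 / (64f32).sqrt(), 1.0)?;
    assert_eq!(out.dims(), &[1, 8, 16, 64]);
    Ok(())
}
```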

§On Metal:

  • If seq == 1:
    • Use a vectorized kernel
    • Supports seq != kv_seq (cross attn. support)
    • Supports GQA when qhead is a multiple of kv_head
  • Otherwise:
    • Use an alternate kernel
    • Requires seq == kv_seq
    • GQA is not supported (requires qhead == kv_head)
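
For example, a single-token decode step that stays within the seq == 1 constraints above (GQA with qhead a multiple of kv_head, and kv_seq longer than seq) might look like the following sketch; the head counts and lengths are illustrative only:

```rust
use diffusion_rs_common::core::{Device, Result, Tensor};
use diffusion_rs_common::nn::ops::sdpa;

fn decode_step() -> Result<()> {
    let dev = Device::new_metal(0)?;
    // Single-token decode: seq == 1 against a cache of kv_seq == 128.
    // GQA: qhead (8) is a multiple of kv_head (2), which the seq == 1 vectorized kernel supports.
    let q = Tensor::randn(0f32, 1., (1, 8, 1, 64), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 2, 128, 64), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 2, 128, 64), &dev)?;
    let out = sdpa(&q, &k, &v, 1.0 / (64f32).sqrt(), 1.0)?;
    assert_eq!(out.dims(), &[1, 8, 1, 64]);
    Ok(())
}
```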