pub fn sdpa(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f32,
softcapping: f32,
) -> Result<Tensor>
Scaled dot product attention with a fused kernel.
Computes softmax(qk^T*scale)v.
Input shapes:
- `q`: (bs, qhead, seq, hidden)
- `k`: (bs, kv_head, kv_seq, hidden)
- `v`: (bs, kv_head, kv_seq, v_hidden)

`scale` is applied before the softmax.

If `softcapping` != 1.0, the computation is: softmax(tanh(qk^T*scale/cap)*cap)v
Output shape: (bs, qhead, seq, v_hidden)
Supported head dims: 32, 64, 96, 128, 256.
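For reference, the following is a hedged, non-fused sketch of the same math using plain candle ops. The helper name `naive_sdpa`, the `candle_core`/`candle_nn` imports, and the assumption that `qhead == kv_head` (no GQA handling) are illustrative only; it just restates the formula above, including the soft-capping branch.

```rust
// Hedged reference sketch (not the fused kernel): the same math in plain candle ops.
// Assumes qhead == kv_head; `naive_sdpa` is an illustrative name, not part of this API.
use candle_core::{D, Result, Tensor};
use candle_nn::ops::softmax;

fn naive_sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    // qk^T * scale
    let mut att = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? * scale as f64)?;
    // Optional soft-capping: tanh(att / cap) * cap
    if softcapping != 1.0 {
        att = ((att / softcapping as f64)?.tanh()? * softcapping as f64)?;
    }
    // Softmax over the kv_seq dimension, then weight the values
    softmax(&att, D::Minus1)?.matmul(v)
}
```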
On Metal:
- If `seq == 1`:
  - Uses a vectorized kernel
  - Supports `seq != kv_seq` (cross-attention support)
  - Supports GQA when `qhead` is a multiple of `kv_head`
- Otherwise:
  - Uses an alternate kernel
  - Requires `seq == kv_seq`
  - GQA is not supported (requires `qhead == kv_head`)
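A minimal usage sketch for the `seq == 1` (decode) path described above. Everything beyond the documented shapes is an assumption: the `candle_core` imports, the Metal device index, and that `sdpa` is in scope from the crate that exports it.

```rust
// Usage sketch for the seq == 1 decode path on Metal, using the shapes documented above.
// Assumptions: `sdpa` is in scope, `candle_core` provides Tensor/Device, and Metal device 0 exists.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_metal(0)?; // assumption: a Metal device is available
    let (bs, qhead, kv_head, seq, kv_seq, hidden) = (1, 32, 8, 1, 256, 128);

    // q: (bs, qhead, seq, hidden); seq == 1 selects the vectorized kernel
    let q = Tensor::randn(0f32, 1f32, (bs, qhead, seq, hidden), &dev)?;
    // k, v: (bs, kv_head, kv_seq, hidden); GQA since qhead = 4 * kv_head
    let k = Tensor::randn(0f32, 1f32, (bs, kv_head, kv_seq, hidden), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (bs, kv_head, kv_seq, hidden), &dev)?;

    let scale = 1.0 / (hidden as f32).sqrt();
    let out = sdpa(&q, &k, &v, scale, 1.0)?; // softcapping == 1.0 disables capping

    // Output: (bs, qhead, seq, v_hidden); here v_hidden == hidden
    assert_eq!(out.dims4()?, (bs, qhead, seq, hidden));
    Ok(())
}
```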