pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f32,
    softcapping: f32,
) -> Result<Tensor>
Scaled dot product attention with a fused kernel.
Computes softmax(qk^T*scale)v.
Input shapes:
- q: (bs, qhead, seq, hidden)
- k: (bs, kv_head, kv_seq, hidden)
- v: (bs, kv_head, kv_seq, v_hidden)
- scale is applied before softmax.
- If softcapping != 1.0:
  - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v, where cap = softcapping.
Output shape: (bs, qhead, seq, v_hidden)
Supported head dims: 32, 64, 96, 128, 256.
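For reference, a minimal calling sketch, assuming this sdpa function is in scope and that Tensor, Device, and Result come from candle_core; the helper name, the random inputs, and the concrete dimensions are illustrative assumptions only.

use candle_core::{Device, Result, Tensor};

fn sdpa_example(device: &Device) -> Result<Tensor> {
    // Assumed shapes: bs = 1, qhead = kv_head = 8, seq = kv_seq = 16,
    // hidden = v_hidden = 64 (one of the supported head dims).
    let q = Tensor::randn(0f32, 1.0, (1, 8, 16, 64), device)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 16, 64), device)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 16, 64), device)?;

    // scale is typically 1/sqrt(head_dim); softcapping = 1.0 keeps the plain softmax path.
    let scale = 1.0 / (64f32).sqrt();
    let out = sdpa(&q, &k, &v, scale, 1.0)?;

    // Output shape: (bs, qhead, seq, v_hidden) = (1, 8, 16, 64).
    Ok(out)
}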
On Metal:
- If seq == 1:
  - Use a vectorized kernel
  - Supports seq != kv_seq (cross attention support)
  - Supports GQA when qhead is a multiple of kv_head
- Otherwise:
  - Use an alternate kernel
  - Requires seq == kv_seq
  - GQA is not supported (requires qhead == kv_head)

A second sketch, this time of the vectorized Metal decode path described above: a hypothetical GQA call with seq == 1, qhead (32) a multiple of kv_head (8), and softcapping enabled. The shapes, the cap value, and the helper name are assumptions, not part of this documentation; see the code after this list.
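use candle_core::{Device, Result, Tensor};

fn gqa_decode_step(device: &Device) -> Result<Tensor> {
    // seq == 1 selects the vectorized kernel; kv_seq may differ from seq on this path,
    // and GQA works because qhead (32) is a multiple of kv_head (8).
    let q = Tensor::randn(0f32, 1.0, (1, 32, 1, 128), device)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 512, 128), device)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 512, 128), device)?;

    // softcapping != 1.0 switches to softmax(tanh(qk^T*scale/cap)*cap)v, with cap = 50.0 here.
    let scale = 1.0 / (128f32).sqrt();
    sdpa(&q, &k, &v, scale, 50.0)
}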