pub fn attn_softmax_last_dim(
    xs: &Tensor,
    mask: &Tensor,
    scale: f32,
) -> Result<Tensor>
Softmax over the last dimension, fused with a broadcast addition of `mask` and a multiplication by `scale`. Equivalent to:

```rust,ignore
diffusion_rs_common::nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?
```

- `xs` must be a rank-4 tensor.
- `mask` must be a rank-2 matrix.
- The last 2 dimensions of `xs` must match the dimensions of `mask`.
Note: if the last dimension of `xs` is a multiple of 4, a vectorized implementation is used.
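The crate's vectorized kernel is not shown on this page. As an illustration only, the hypothetical `softmax_row_4wide` below shows why a last dimension that is a multiple of 4 is convenient: each row splits into whole 4-lane chunks, so the max, exp-sum, and normalization passes need no scalar remainder loop. Plain `[f32; 4]` arrays stand in for SIMD registers here; a real kernel would presumably use SIMD intrinsics, and this is not claimed to match the crate's implementation.

```rust
/// Hypothetical illustration (not the crate's kernel): softmax over one row
/// whose length is a multiple of 4, processed in 4-lane chunks.
fn softmax_row_4wide(row: &mut [f32]) {
    assert!(row.len() % 4 == 0, "4-wide path requires len % 4 == 0");

    // Pass 1: lane-wise running max over 4-element chunks, then reduce lanes.
    let mut max4 = [f32::NEG_INFINITY; 4];
    for chunk in row.chunks_exact(4) {
        for (m, &v) in max4.iter_mut().zip(chunk) {
            *m = m.max(v);
        }
    }
    let max = max4.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Pass 2: exponentiate in place and accumulate lane-wise partial sums.
    let mut sum4 = [0.0f32; 4];
    for chunk in row.chunks_exact_mut(4) {
        for (s, v) in sum4.iter_mut().zip(chunk.iter_mut()) {
            *v = (*v - max).exp();
            *s += *v;
        }
    }
    let inv = 1.0 / sum4.iter().sum::<f32>();

    // Pass 3: normalize every element by the reduced sum.
    for v in row.iter_mut() {
        *v *= inv;
    }
}
```

Because the chunks are exact, the lane-wise partial maxima and sums need only a single 4-way reduction at the end of each pass.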