pub fn attn_softmax_last_dim(
xs: &Tensor,
mask: &Tensor,
scale: f32,
) -> Result<Tensor>
Softmax with fused broadcast addition of a mask and scale. Equivalent to:
diffusion_rs_common::nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?
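To make this equivalence concrete without depending on the crate, here is a minimal, unfused sketch over plain `f32` slices. It assumes row-major storage, with `xs` of shape `[b, h, q, k]` and `mask` of shape `[q, k]`. The function name `attn_softmax_reference` is hypothetical and is not part of `diffusion_rs_common`. It follows the documented order of operations: add the broadcast mask, multiply by `scale`, then apply a max-subtracted softmax over the last dimension.

```rust
/// Hypothetical reference (unfused) version of the masked, scaled softmax.
/// Not part of `diffusion_rs_common`; for illustration only.
/// `xs` is a row-major buffer of shape [b, h, q, k]; `mask` is row-major [q, k].
fn attn_softmax_reference(xs: &[f32], dims: [usize; 4], mask: &[f32], scale: f32) -> Vec<f32> {
    let [b, h, q, k] = dims;
    assert!(q > 0 && k > 0, "q and k must be non-zero");
    assert_eq!(xs.len(), b * h * q * k, "xs length must match dims");
    assert_eq!(mask.len(), q * k, "mask must have shape [q, k]");

    let mut out = vec![0f32; xs.len()];
    // Each length-k row is one independent softmax.
    for (row_idx, (row, out_row)) in xs.chunks_exact(k).zip(out.chunks_exact_mut(k)).enumerate() {
        // Rows are ordered (batch, head, query), so the query index is row_idx % q.
        // The same mask row is broadcast across every batch and head.
        let qi = row_idx % q;
        let m_row = &mask[qi * k..(qi + 1) * k];

        // (x + mask) * scale, matching the documented equivalence.
        for ((o, &x), &m) in out_row.iter_mut().zip(row).zip(m_row) {
            *o = (x + m) * scale;
        }

        // Subtract the row max before exponentiating for numerical stability.
        let max = out_row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0f32;
        for o in out_row.iter_mut() {
            *o = (*o - max).exp();
            sum += *o;
        }
        for o in out_row.iter_mut() {
            *o /= sum;
        }
    }
    out
}
```

A mask entry of `f32::NEG_INFINITY` (with a positive `scale`) drives the corresponding probability to exactly zero, which is how causal or padding masks are typically expressed in this pattern.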
- `xs` must be a rank-4 tensor.
- `mask` must be a rank-2 matrix.
- The last 2 dimensions of `xs` must match the dimensions of `mask`.
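These preconditions can be expressed as a small check on dimension lists. The sketch below is a hypothetical helper (`check_attn_shapes`), not a function exported by the crate, and it assumes the shapes are available as `&[usize]` slices.

```rust
/// Hypothetical shape check mirroring the documented preconditions.
/// Not part of `diffusion_rs_common`; for illustration only.
fn check_attn_shapes(xs_dims: &[usize], mask_dims: &[usize]) -> Result<(), String> {
    // xs must be rank 4, e.g. [batch, heads, q_len, k_len].
    if xs_dims.len() != 4 {
        return Err(format!("xs must be rank 4, got rank {}", xs_dims.len()));
    }
    // mask must be rank 2, e.g. [q_len, k_len].
    if mask_dims.len() != 2 {
        return Err(format!("mask must be rank 2, got rank {}", mask_dims.len()));
    }
    // The trailing two dims of xs must equal the mask dims exactly.
    if &xs_dims[2..] != mask_dims {
        return Err(format!(
            "last two dims of xs {:?} must equal mask dims {:?}",
            &xs_dims[2..],
            mask_dims
        ));
    }
    Ok(())
}
```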
Note: if the last dimension of `xs` is a multiple of 4, a vectorized implementation is used.
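One plausible reason for the multiple-of-4 condition is that a row whose length divides evenly into 4-element lanes can be processed entirely in 4-wide chunks, with no leftover scalar tail. The sketch below illustrates the idea with a row-max reduction, one step of a numerically stable softmax. `row_max` is hypothetical and is not the crate's vectorized kernel, whose internals are not documented here.

```rust
/// Hypothetical illustration of 4-lane processing; not the crate's kernel.
/// Computes the maximum of a row, using four lane accumulators when the
/// row length is a multiple of 4 and a plain scalar loop otherwise.
fn row_max(row: &[f32]) -> f32 {
    if row.len() % 4 == 0 {
        // One accumulator per lane, analogous to a 4-wide SIMD register.
        let mut lanes = [f32::NEG_INFINITY; 4];
        for chunk in row.chunks_exact(4) {
            for (lane, &v) in lanes.iter_mut().zip(chunk) {
                *lane = f32::max(*lane, v);
            }
        }
        // Horizontal reduction across the four lanes.
        lanes.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    } else {
        // Scalar fallback for lengths that do not split evenly into 4-wide chunks.
        row.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }
}
```

Both paths return the same result; the 4-lane path simply avoids a remainder loop, which is what makes a fixed-width vector kernel straightforward when the last dimension is a multiple of 4.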