diffusion_rs_common::nn::ops

Function attn_softmax_last_dim

pub fn attn_softmax_last_dim(
    xs: &Tensor,
    mask: &Tensor,
    scale: f32,
) -> Result<Tensor>

Softmax with fused broadcast addition of a mask and scale. Equivalent to:

diffusion_rs_common::nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?
  • xs must be a rank-4 tensor
  • mask must be a rank-2 matrix
  • The last 2 dimensions of xs must match the dimensions of mask.

Note: if the last dimension of xs is a multiple of 4, a vectorized implementation is used.
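
The equivalence and shape rules above can be sketched in plain Rust on flat slices. This is only a reference for the semantics as stated in the description, not the crate's implementation: the helper name attn_softmax_last_dim_ref, the (b, h, q, k) dimension names, and the row-major layout are assumptions made for illustration, and the sketch is scalar rather than vectorized.

```rust
/// Hypothetical reference helper (not part of the crate): softmax over the
/// last dimension of `(xs + mask) * scale`.
/// Assumes `xs` is row-major with shape [b, h, q, k] and `mask` has shape [q, k].
fn attn_softmax_last_dim_ref(
    xs: &[f32],
    mask: &[f32],
    dims: (usize, usize, usize, usize),
    scale: f32,
) -> Vec<f32> {
    let (b, h, q, k) = dims;
    assert!(k > 0, "last dimension must be non-empty");
    assert_eq!(xs.len(), b * h * q * k, "xs must hold b*h*q*k elements");
    assert_eq!(mask.len(), q * k, "mask must hold q*k elements");

    let mut out = vec![0f32; xs.len()];
    for (row_idx, (src, dst)) in xs.chunks(k).zip(out.chunks_mut(k)).enumerate() {
        // Each row of length k belongs to query position `row_idx % q`;
        // the mask is broadcast over the two leading dimensions.
        let qi = row_idx % q;
        let mask_row = &mask[qi * k..(qi + 1) * k];

        // Fused step from the stated equivalence: add the mask, then scale.
        let logits: Vec<f32> = src
            .iter()
            .zip(mask_row)
            .map(|(x, m)| (x + m) * scale)
            .collect();

        // Numerically stable softmax: subtract the row maximum before exp.
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        for (d, e) in dst.iter_mut().zip(&exps) {
            *d = e / sum;
        }
    }
    out
}
```

Calling this helper with dims (1, 1, 2, 2), an all-zero xs, and a mask whose second entry is f32::NEG_INFINITY sends all of the first row's weight to its unmasked position, while the second row stays uniform.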