Function fused_glu 

pub fn fused_glu(
    a: &Tensor,
    b: &Tensor,
    activation: GluActivationType,
) -> Result<Tensor>

Fused GLU activation: output = activation(a) * b

This fuses the activation function application and the element-wise multiplication into a single pass, reducing memory traffic and eliminating the intermediate tensor allocation.
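
For comparison, the unfused form materializes the activation into a temporary tensor and then multiplies, as in the sketch below. The silu and mul method names follow a candle-style Tensor API and are assumptions, not confirmed by this crate:

pub fn unfused_glu_silu(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    // Pass 1: apply the activation, allocating an intermediate tensor.
    // (`silu` is an assumed method name on this Tensor type.)
    let activated = a.silu()?;
    // Pass 2: element-wise multiply, reading the intermediate back from memory.
    activated.mul(b)
}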

Supported on CUDA (optimized kernel), Metal (optimized kernel), and CPU (rayon parallelism).

Args:
    a: Input tensor to apply the activation to
    b: Tensor to multiply with the activated values
    activation: The activation function to apply (SiLU, GELU, or ReLU)

Returns: A tensor with the same shape as the inputs.
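
A minimal usage sketch; the GluActivationType::Silu variant spelling and the variable names are assumptions for illustration:

// Gated feed-forward (SwiGLU-style) use: output = SiLU(gate_proj) * up_proj.
// `gate_proj` and `up_proj` are assumed to be same-shape projection outputs;
// the `Silu` variant name is assumed, check the enum definition for the exact spelling.
let gated = fused_glu(&gate_proj, &up_proj, GluActivationType::Silu)?;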