diffusion_rs_common::nn::loss

Function binary_cross_entropy_with_logit

pub fn binary_cross_entropy_with_logit(
    inp: &Tensor,
    target: &Tensor,
) -> Result<Tensor>

The binary cross-entropy with logit loss.

Arguments

  • [inp]: The input tensor of dimensions N, C where N is the batch size and C the number of categories. This is expected to contain raw logits, i.e. values to which no sigmoid has been applied.
  • [target]: The ground truth labels as a tensor of u32 of dimension N, C where N is the batch size and C the number of categories.

The resulting tensor is a scalar containing the average value over the batch.
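
To make the computed quantity concrete, the following is a minimal plain-Rust sketch of the standard binary cross-entropy on logits. It is not this crate's implementation: it works on flattened f64 slices rather than on Tensor values, the helper names bce_with_logit_elem and bce_with_logit_mean are hypothetical, and it assumes the average is taken over all N·C elements.

```rust
/// Per-element loss for logit `x` and label `y` (0.0 or 1.0).
/// Mathematically equal to -[y * ln(sigmoid(x)) + (1 - y) * ln(1 - sigmoid(x))],
/// written in the numerically stable form max(x, 0) - x * y + ln(1 + exp(-|x|)).
fn bce_with_logit_elem(x: f64, y: f64) -> f64 {
    x.max(0.0) - x * y + (-x.abs()).exp().ln_1p()
}

/// Mean loss over a flattened N x C batch (assumption: mean over every element).
fn bce_with_logit_mean(logits: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(logits.len(), targets.len(), "shapes must match");
    assert!(!logits.is_empty(), "batch must not be empty");
    let total: f64 = logits
        .iter()
        .zip(targets)
        .map(|(&x, &y)| bce_with_logit_elem(x, y))
        .sum();
    total / logits.len() as f64
}
```

A logit of 0 corresponds to a predicted probability of 0.5, so each such element contributes ln 2 regardless of its label; a large positive logit paired with label 1 contributes a loss close to 0.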