Function cross_entropy_loss
pub fn cross_entropy_loss(
    inp: &Tensor,
    target: &Tensor,
) -> Result<Tensor, Error>
The cross-entropy loss.
Arguments
- [inp]: The input tensor of dimensions (N, C), where N is the batch size and C is the number of categories. This is expected to contain raw logits.
- [target]: The ground truth labels as a tensor of u32 of dimension N.
The resulting tensor is a scalar containing the loss averaged over the batch.
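To make the computation concrete, here is a minimal plain-Rust sketch of the standard cross-entropy definition: apply log-softmax to each row of raw logits, take the negative log-probability of the target class, and average over the batch. It is not this crate's implementation. It assumes the function follows that standard definition. The helper name `cross_entropy_reference`, the row-major `&[f32]` layout with an explicit `c` parameter (standing in for a `(N, C)` Tensor), and the `String` error type are all hypothetical choices made for illustration.

```rust
/// Hypothetical reference helper (not the crate's API): mean cross-entropy
/// over a batch of raw logits.
/// `logits` holds N rows of `c` values each, stored row-major (shape (N, C)).
/// `targets` holds N class indices, mirroring the u32 target tensor.
fn cross_entropy_reference(logits: &[f32], c: usize, targets: &[u32]) -> Result<f32, String> {
    if c == 0 {
        return Err("number of categories must be non-zero".into());
    }
    if targets.is_empty() {
        return Err("batch must contain at least one example".into());
    }
    if logits.len() != targets.len() * c {
        return Err(format!(
            "expected {} logits for a ({}, {}) input, got {}",
            targets.len() * c,
            targets.len(),
            c,
            logits.len()
        ));
    }

    let mut total = 0.0f32;
    for (row, &t) in logits.chunks(c).zip(targets) {
        let t = t as usize;
        if t >= c {
            return Err(format!("target {t} out of range for {c} categories"));
        }
        // log-sum-exp with the row maximum subtracted, for numerical stability.
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let lse = max + row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
        // -log(softmax(row)[t]) == logsumexp(row) - row[t]
        total += lse - row[t];
    }

    // Average the per-example losses over the batch, giving a scalar.
    Ok(total / targets.len() as f32)
}
```

For example, a batch of two rows of all-zero logits over four categories gives a mean loss of ln(4), since every class then has probability 1/4 regardless of the target. Subtracting the row maximum before exponentiating is a common design choice that keeps `exp` from overflowing on large logits.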