//! Batch normalization.
//!
//! Normalizes activations per feature using batch statistics during training
//! and running estimates during evaluation, following Ioffe & Szegedy,
//! "Batch Normalization: Accelerating Deep Network Training by Reducing
//! Internal Covariate Shift" (2015), <https://arxiv.org/abs/1502.03167>.
use crate::core::{DType, Result, Tensor, Var};
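
/// Configuration for a [`BatchNorm`] layer.
///
/// A config can be built field by field from `Default`, or from an epsilon
/// alone via the `From<f64>` impl below. A minimal sketch using only items
/// defined in this module:
///
/// ```ignore
/// // Defaults: eps = 1e-5, remove_mean = true, affine = true, momentum = 0.1.
/// let cfg = BatchNormConfig::default();
/// // Override just the epsilon, keeping the remaining defaults.
/// let cfg = BatchNormConfig { eps: 1e-4, ..Default::default() };
/// // Equivalent shorthand through the `From<f64>` impl.
/// let cfg: BatchNormConfig = 1e-4.into();
/// ```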
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    /// Small constant added to the variance for numerical stability.
    pub eps: f64,
    /// Whether to subtract the batch mean before normalizing; the standard
    /// batch-norm behavior.
    pub remove_mean: bool,
    /// Whether to apply a learnable per-feature scale (`weight`) and shift
    /// (`bias`) after normalization.
    pub affine: bool,
    /// Weight given to the current batch statistics when updating the running
    /// estimates; must lie in `[0, 1]`.
    pub momentum: f64,
}

impl Default for BatchNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
            momentum: 0.1,
        }
    }
}

impl From<f64> for BatchNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            ..Default::default()
        }
    }
}
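
/// A batch normalization layer.
///
/// Normalizes each feature (dimension 1 of the input) to zero mean and unit
/// variance, optionally followed by a learnable per-feature scale and shift.
/// In training mode the statistics of the current batch are used and the
/// running estimates are updated in place; in evaluation mode the running
/// estimates are used for normalization instead.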
#[derive(Clone, Debug)]
pub struct BatchNorm {
    running_mean: Var,
    running_var: Var,
    weight_and_bias: Option<(Tensor, Tensor)>,
    remove_mean: bool,
    eps: f64,
    momentum: f64,
}

impl BatchNorm {
    /// Checks that the stored tensors have the expected `[num_features]`
    /// shape and that the hyper-parameters are within range.
    fn check_validity(&self, num_features: usize) -> Result<()> {
        if self.eps < 0. {
            crate::bail!("batch-norm eps cannot be negative {}", self.eps)
        }
        if !(0.0..=1.0).contains(&self.momentum) {
            crate::bail!(
                "batch-norm momentum must be between 0 and 1, is {}",
                self.momentum
            )
        }
        if self.running_mean.dims() != [num_features] {
            crate::bail!(
                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
                self.running_mean.shape(),
            )
        }
        if self.running_var.dims() != [num_features] {
            crate::bail!(
                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
                self.running_var.shape(),
            )
        }
        if let Some((weight, bias)) = self.weight_and_bias.as_ref() {
            if weight.dims() != [num_features] {
                crate::bail!(
                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
                    weight.shape(),
                )
            }
            if bias.dims() != [num_features] {
                crate::bail!(
                    "batch-norm bias has unexpected shape {:?} should have shape [{num_features}]",
                    bias.shape(),
                )
            }
        }
        Ok(())
    }
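
    /// Creates a batch-norm layer with a learnable weight and bias from
    /// existing running statistics, using the default momentum of 0.1.
    ///
    /// A minimal construction sketch; `Device`, `Tensor::zeros`, and
    /// `Tensor::ones` are assumed here to follow candle-style signatures and
    /// may need adapting to this crate's actual tensor constructors.
    ///
    /// ```ignore
    /// let dev = Device::Cpu; // assumed device type
    /// let n = 4; // number of features
    /// let bn = BatchNorm::new(
    ///     n,
    ///     Tensor::zeros(n, DType::F32, &dev)?, // running mean
    ///     Tensor::ones(n, DType::F32, &dev)?,  // running variance
    ///     Tensor::ones(n, DType::F32, &dev)?,  // weight (scale)
    ///     Tensor::zeros(n, DType::F32, &dev)?, // bias (shift)
    ///     1e-5,
    /// )?;
    /// ```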
    pub fn new(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        weight: Tensor,
        bias: Tensor,
        eps: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            momentum: 0.1,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

    /// Creates a batch-norm layer without a learnable affine transform,
    /// using the default momentum of 0.1.
    pub fn new_no_bias(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        eps: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: None,
            remove_mean: true,
            eps,
            momentum: 0.1,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

    /// Like [`BatchNorm::new`], but with an explicit momentum for the
    /// running-statistics update.
    pub fn new_with_momentum(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        weight: Tensor,
        bias: Tensor,
        eps: f64,
        momentum: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            momentum,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

    /// Like [`BatchNorm::new_no_bias`], but with an explicit momentum.
    pub fn new_no_bias_with_momentum(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        eps: f64,
        momentum: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: None,
            remove_mean: true,
            eps,
            momentum,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

    pub fn running_mean(&self) -> &Tensor {
        self.running_mean.as_tensor()
    }

    pub fn running_var(&self) -> &Tensor {
        self.running_var.as_tensor()
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
        self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
    }

    pub fn momentum(&self) -> f64 {
        self.momentum
    }
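
    /// Applies batch normalization in training mode and updates the running
    /// statistics in place.
    ///
    /// For each feature `c`, with `E[x_c]` and `Var[x_c]` computed over all
    /// remaining dimensions of the batch:
    ///
    /// ```text
    /// y_c            = (x_c - E[x_c]) / sqrt(Var[x_c] + eps) * weight_c + bias_c
    /// running_mean_c = (1 - momentum) * running_mean_c + momentum * E[x_c]
    /// running_var_c  = (1 - momentum) * running_var_c
    ///                + momentum * Var[x_c] * n / (n - 1)
    /// ```
    ///
    /// where `n` is the number of elements per feature. The `n / (n - 1)`
    /// factor (Bessel's correction) stores an unbiased variance estimate,
    /// matching PyTorch's `BatchNorm` convention.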
    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
        let num_features = self.running_mean.as_tensor().dim(0)?;
        let x_dtype = x.dtype();
        // Accumulate statistics in f32 for half-precision inputs.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        if x.rank() < 2 {
            crate::bail!(
                "batch-norm input tensor must have at least two dimensions ({:?})",
                x.shape()
            )
        }
        if x.dim(1)? != num_features {
            crate::bail!(
                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                x.shape(),
                num_features
            )
        }
        let x = x.to_dtype(internal_dtype)?;
        // Move the feature dimension first, then flatten everything else so
        // the per-feature statistics reduce over a single dimension.
        let x = x.transpose(0, 1)?;
        let x_dims_post_transpose = x.dims();
        let x = x.flatten_from(1)?.contiguous()?;
        let x = if self.remove_mean {
            let mean_x = x.mean_keepdim(1)?;
            // Exponential moving average of the batch mean.
            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
                + (mean_x.flatten_all()? * self.momentum)?)?;
            self.running_mean.set(&updated_running_mean)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        // Biased variance of the (centered) activations.
        let norm_x = x.sqr()?.mean_keepdim(1)?;
        let updated_running_var = {
            let batch_size = x.dim(1)? as f64;
            let running_var_weight = 1.0 - self.momentum;
            // Bessel's correction: store an unbiased variance estimate.
            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
            ((self.running_var.as_tensor() * running_var_weight)?
                + (&norm_x.flatten_all()? * norm_x_weight)?)?
        };
        self.running_var.set(&updated_running_var)?;
        let x = x
            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
            .to_dtype(x_dtype)?;
        let x = match &self.weight_and_bias {
            None => x,
            Some((weight, bias)) => {
                // Per-feature scale and shift, broadcast over the flattened dims.
                let weight = weight.reshape(((), 1))?;
                let bias = bias.reshape(((), 1))?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
            }
        };
        // Restore the original layout.
        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
    }
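
    /// Applies batch normalization in evaluation mode, normalizing with the
    /// detached running statistics rather than per-batch statistics; the
    /// running estimates are left unchanged.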
    pub fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
        // Broadcast shape: [1, num_features, 1, ...] matching the input rank.
        let target_shape: Vec<usize> = x
            .dims()
            .iter()
            .enumerate()
            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
            .collect();
        let target_shape = target_shape.as_slice();
        // Detach the running statistics so no gradients flow through them.
        let x = x
            .broadcast_sub(
                &self
                    .running_mean
                    .as_detached_tensor()
                    .reshape(target_shape)?,
            )?
            .broadcast_div(
                &(self
                    .running_var
                    .as_detached_tensor()
                    .reshape(target_shape)?
                    + self.eps)?
                    .sqrt()?,
            )?;
        match &self.weight_and_bias {
            None => Ok(x),
            Some((weight, bias)) => {
                let weight = weight.reshape(target_shape)?;
                let bias = bias.reshape(target_shape)?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)
            }
        }
    }
}

impl crate::nn::ModuleT for BatchNorm {
    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            self.forward_train(x)
        } else {
            self.forward_eval(x)
        }
    }
}
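
/// Creates a [`BatchNorm`] layer whose parameters are loaded or initialized
/// through `vb`: `running_mean` and `bias` default to zeros, `running_var`
/// and `weight` to ones.
///
/// A minimal usage sketch; the `vb.pp("bn1")` prefixing call is assumed to
/// follow candle's `VarBuilder` API and may differ in this crate:
///
/// ```ignore
/// use crate::nn::ModuleT;
/// let bn = batch_norm(64, BatchNormConfig::default(), vb.pp("bn1"))?;
/// let y = bn.forward_t(&x, true)?; // train = true
/// ```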
pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: crate::nn::VarBuilder,
) -> Result<BatchNorm> {
    use crate::nn::Init;
    let config = config.into();
    if config.eps < 0. {
        crate::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    if !(0.0..=1.0).contains(&config.momentum) {
        crate::bail!(
            "batch-norm momentum must be between 0 and 1, is {}",
            config.momentum
        )
    }
    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
    let weight_and_bias = if config.affine {
        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
        Some((weight, bias))
    } else {
        None
    };
    Ok(BatchNorm {
        running_mean: Var::from_tensor(&running_mean)?,
        running_var: Var::from_tensor(&running_var)?,
        weight_and_bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
        momentum: config.momentum,
    })
}