use crate::core::{Result, Tensor};
use crate::nn::BatchNorm;
/// Hyper-parameters of a [`Conv1d`] layer.
///
/// Each field is handed unchanged to `Tensor::conv1d` when the layer's
/// `forward` runs; the precise meaning of each value is defined by that
/// operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv1dConfig {
    /// Padding argument of the convolution (default `0`).
    pub padding: usize,
    /// Stride argument of the convolution (default `1`).
    pub stride: usize,
    /// Dilation argument of the convolution (default `1`).
    pub dilation: usize,
    /// Number of groups (default `1`); the `conv1d` factories also divide
    /// `in_channels` by this value when sizing the weight.
    pub groups: usize,
}

impl Default for Conv1dConfig {
    /// Returns `padding = 0`, `stride = 1`, `dilation = 1`, `groups = 1`.
    fn default() -> Self {
        let (padding, stride, dilation, groups) = (0, 1, 1, 1);
        Conv1dConfig {
            padding,
            stride,
            dilation,
            groups,
        }
    }
}
/// One-dimensional convolution layer: a kernel tensor, an optional bias and
/// the [`Conv1dConfig`] used when the layer is applied.
#[derive(Clone, Debug)]
pub struct Conv1d {
/// Kernel passed to `Tensor::conv1d` in `forward`. The `conv1d` factories
/// create it with shape `(out_channels, in_channels / groups, kernel_size)`;
/// `new` accepts any tensor without checking.
weight: Tensor,
/// Optional bias. `forward` requires it to be one-dimensional (`dims1`) and
/// reshapes it to `(1, len, 1)` before a broadcast add.
bias: Option<Tensor>,
/// Padding / stride / dilation / groups forwarded to the convolution.
config: Conv1dConfig,
}
impl Conv1d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv1dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::nn::Module for Conv1d {
    /// Applies `Tensor::conv1d` with the stored kernel and configuration,
    /// then adds the bias when one is present.
    ///
    /// The bias must be one-dimensional; it is reshaped to `(1, len, 1)` so
    /// that the add broadcasts over the first and last output dimensions.
    ///
    /// # Errors
    ///
    /// Propagates any error from the convolution, from `dims1`/`reshape` on
    /// the bias, or from the broadcast add.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let cfg = &self.config;
        let convolved = x.conv1d(
            &self.weight,
            cfg.padding,
            cfg.stride,
            cfg.dilation,
            cfg.groups,
        )?;
        let bias = match self.bias.as_ref() {
            Some(bias) => bias,
            // No bias: the convolution output is the result.
            None => return Ok(convolved),
        };
        let channels = bias.dims1()?;
        let bias = bias.reshape((1, channels, 1))?;
        let shifted = convolved.broadcast_add(&bias)?;
        Ok(shifted)
    }
}
/// Hyper-parameters of a [`ConvTranspose1d`] layer.
///
/// Each field is handed unchanged to `Tensor::conv_transpose1d` when the
/// layer's `forward` runs; the precise meaning of each value is defined by
/// that operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
    /// Padding argument of the transposed convolution (default `0`).
    pub padding: usize,
    /// Output-padding argument of the transposed convolution (default `0`).
    pub output_padding: usize,
    /// Stride argument of the transposed convolution (default `1`).
    pub stride: usize,
    /// Dilation argument of the transposed convolution (default `1`).
    pub dilation: usize,
    /// Number of groups (default `1`); the `conv_transpose1d` factories also
    /// divide `out_channels` by this value when sizing the weight.
    pub groups: usize,
}

impl Default for ConvTranspose1dConfig {
    /// Returns `padding = 0`, `output_padding = 0`, `stride = 1`,
    /// `dilation = 1`, `groups = 1`.
    fn default() -> Self {
        let (padding, output_padding) = (0, 0);
        let (stride, dilation, groups) = (1, 1, 1);
        ConvTranspose1dConfig {
            padding,
            output_padding,
            stride,
            dilation,
            groups,
        }
    }
}
/// One-dimensional transposed convolution layer: a kernel tensor, an optional
/// bias and the [`ConvTranspose1dConfig`] used when the layer is applied.
#[derive(Clone, Debug)]
pub struct ConvTranspose1d {
/// Kernel passed to `Tensor::conv_transpose1d` in `forward`. The
/// `conv_transpose1d` factories create it with shape
/// `(in_channels, out_channels / groups, kernel_size)`; `new` accepts any
/// tensor without checking.
weight: Tensor,
/// Optional bias. `forward` requires it to be one-dimensional (`dims1`) and
/// reshapes it to `(1, len, 1)` before a broadcast add.
bias: Option<Tensor>,
/// Padding / output padding / stride / dilation / groups forwarded to the
/// transposed convolution.
config: ConvTranspose1dConfig,
}
impl ConvTranspose1d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::nn::Module for ConvTranspose1d {
    /// Applies `Tensor::conv_transpose1d` with the stored kernel and
    /// configuration, then adds the bias when one is present.
    ///
    /// The bias must be one-dimensional; it is reshaped to `(1, len, 1)` so
    /// that the add broadcasts over the first and last output dimensions.
    ///
    /// # Errors
    ///
    /// Propagates any error from the transposed convolution, from
    /// `dims1`/`reshape` on the bias, or from the broadcast add.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let cfg = &self.config;
        let upsampled = x.conv_transpose1d(
            &self.weight,
            cfg.padding,
            cfg.output_padding,
            cfg.stride,
            cfg.dilation,
            cfg.groups,
        )?;
        let bias = match self.bias.as_ref() {
            Some(bias) => bias,
            // No bias: the transposed-convolution output is the result.
            None => return Ok(upsampled),
        };
        let channels = bias.dims1()?;
        let bias = bias.reshape((1, channels, 1))?;
        let shifted = upsampled.broadcast_add(&bias)?;
        Ok(shifted)
    }
}
/// Hyper-parameters of a [`Conv2d`] layer.
///
/// Each field is handed unchanged to `Tensor::conv2d` when the layer's
/// `forward` runs; the precise meaning of each value is defined by that
/// operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
    /// Padding argument of the convolution (default `0`).
    pub padding: usize,
    /// Stride argument of the convolution (default `1`).
    pub stride: usize,
    /// Dilation argument of the convolution (default `1`).
    pub dilation: usize,
    /// Number of groups (default `1`); the `conv2d` factories also divide
    /// `in_channels` by this value when sizing the weight.
    pub groups: usize,
}

impl Default for Conv2dConfig {
    /// Returns `padding = 0`, `stride = 1`, `dilation = 1`, `groups = 1`.
    fn default() -> Self {
        let (padding, stride, dilation, groups) = (0, 1, 1, 1);
        Conv2dConfig {
            padding,
            stride,
            dilation,
            groups,
        }
    }
}
/// Two-dimensional convolution layer: a kernel tensor, an optional bias and
/// the [`Conv2dConfig`] used when the layer is applied.
#[derive(Clone, Debug)]
pub struct Conv2d {
/// Kernel passed to `Tensor::conv2d` in `forward`. The `conv2d` factories
/// create it with shape
/// `(out_channels, in_channels / groups, kernel_size, kernel_size)`;
/// `absorb_bn` requires it to be four-dimensional (`dims4`).
weight: Tensor,
/// Optional bias. `forward` requires it to be one-dimensional (`dims1`) and
/// reshapes it to `(1, len, 1, 1)` before a broadcast add.
bias: Option<Tensor>,
/// Padding / stride / dilation / groups forwarded to the convolution.
config: Conv2dConfig,
}
impl Conv2d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: Conv2dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
let weight = self
.weight()
.broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
let bias = match &self.bias {
None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
};
Ok(Self {
weight,
bias: Some(bias),
config: self.config,
})
} else {
crate::bail!("batch norm does not have weight_and_bias")
}
}
}
impl crate::nn::Module for Conv2d {
    /// Applies `Tensor::conv2d` with the stored kernel and configuration,
    /// then adds the bias when one is present.
    ///
    /// The bias must be one-dimensional; it is reshaped to `(1, len, 1, 1)`
    /// so that the add broadcasts over the first and the last two output
    /// dimensions.
    ///
    /// # Errors
    ///
    /// Propagates any error from the convolution, from `dims1`/`reshape` on
    /// the bias, or from the broadcast add.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let cfg = &self.config;
        let convolved = x.conv2d(
            &self.weight,
            cfg.padding,
            cfg.stride,
            cfg.dilation,
            cfg.groups,
        )?;
        let bias = match self.bias.as_ref() {
            Some(bias) => bias,
            // No bias: the convolution output is the result.
            None => return Ok(convolved),
        };
        let channels = bias.dims1()?;
        let bias = bias.reshape((1, channels, 1, 1))?;
        let shifted = convolved.broadcast_add(&bias)?;
        Ok(shifted)
    }
}
/// Hyper-parameters of a [`ConvTranspose2d`] layer.
///
/// Each field is handed unchanged to `Tensor::conv_transpose2d` when the
/// layer's `forward` runs; the precise meaning of each value is defined by
/// that operation. Unlike the 1-D variant there is no `groups` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
    /// Padding argument of the transposed convolution (default `0`).
    pub padding: usize,
    /// Output-padding argument of the transposed convolution (default `0`).
    pub output_padding: usize,
    /// Stride argument of the transposed convolution (default `1`).
    pub stride: usize,
    /// Dilation argument of the transposed convolution (default `1`).
    pub dilation: usize,
}

impl Default for ConvTranspose2dConfig {
    /// Returns `padding = 0`, `output_padding = 0`, `stride = 1`,
    /// `dilation = 1`.
    fn default() -> Self {
        let (padding, output_padding) = (0, 0);
        let (stride, dilation) = (1, 1);
        ConvTranspose2dConfig {
            padding,
            output_padding,
            stride,
            dilation,
        }
    }
}
/// Two-dimensional transposed convolution layer: a kernel tensor, an optional
/// bias and the [`ConvTranspose2dConfig`] used when the layer is applied.
#[derive(Clone, Debug)]
pub struct ConvTranspose2d {
/// Kernel passed to `Tensor::conv_transpose2d` in `forward`. The
/// `conv_transpose2d` factories create it with shape
/// `(in_channels, out_channels, kernel_size, kernel_size)`; `new` accepts
/// any tensor without checking.
weight: Tensor,
/// Optional bias. `forward` requires it to be one-dimensional (`dims1`) and
/// reshapes it to `(1, len, 1, 1)` before a broadcast add.
bias: Option<Tensor>,
/// Padding / output padding / stride / dilation forwarded to the
/// transposed convolution.
config: ConvTranspose2dConfig,
}
impl ConvTranspose2d {
pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
Self {
weight,
bias,
config,
}
}
pub fn config(&self) -> &ConvTranspose2dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::nn::Module for ConvTranspose2d {
    /// Applies `Tensor::conv_transpose2d` with the stored kernel and
    /// configuration, then adds the bias when one is present.
    ///
    /// The bias must be one-dimensional; it is reshaped to `(1, len, 1, 1)`
    /// so that the add broadcasts over the first and the last two output
    /// dimensions.
    ///
    /// # Errors
    ///
    /// Propagates any error from the transposed convolution, from
    /// `dims1`/`reshape` on the bias, or from the broadcast add.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let cfg = &self.config;
        let upsampled = x.conv_transpose2d(
            &self.weight,
            cfg.padding,
            cfg.output_padding,
            cfg.stride,
            cfg.dilation,
        )?;
        let bias = match self.bias.as_ref() {
            Some(bias) => bias,
            // No bias: the transposed-convolution output is the result.
            None => return Ok(upsampled),
        };
        let channels = bias.dims1()?;
        let bias = bias.reshape((1, channels, 1, 1))?;
        let shifted = upsampled.broadcast_add(&bias)?;
        Ok(shifted)
    }
}
pub fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: crate::nn::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
let bound = 1. / (in_channels as f64).sqrt();
let init_bs = crate::nn::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: crate::nn::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
Ok(Conv1d::new(ws, None, cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::nn::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::nn::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
pub fn conv_transpose1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: ConvTranspose1dConfig,
vb: crate::nn::VarBuilder,
) -> Result<ConvTranspose1d> {
let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
let init = crate::nn::Init::Uniform {
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
Ok(ConvTranspose1d::new(ws, None, cfg))
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv2dConfig,
vb: crate::nn::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
kernel_size,
kernel_size,
),
"weight",
init_ws,
)?;
let bound = 1. / (in_channels as f64).sqrt();
let init_bs = crate::nn::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv2d::new(ws, Some(bs), cfg))
}
pub fn conv2d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv2dConfig,
vb: crate::nn::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
kernel_size,
kernel_size,
),
"weight",
init_ws,
)?;
Ok(Conv2d::new(ws, None, cfg))
}
/// Builds a [`ConvTranspose2d`] layer with a bias from variables obtained
/// through `vb`.
///
/// Both tensors use the same uniform initialization hint over `[-b, b]`
/// with `b = 1 / sqrt(out_channels) / kernel_size`:
///
/// * the weight is requested as `"weight"` with shape
///   `(in_channels, out_channels, kernel_size, kernel_size)`;
/// * the bias is requested as `"bias"` with shape `out_channels`.
///
/// # Errors
///
/// Propagates a failure of either `get_with_hints` call.
pub fn conv_transpose2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose2dConfig,
    vb: crate::nn::VarBuilder,
) -> Result<ConvTranspose2d> {
    // Same arithmetic as `1 / sqrt(out) / k`, evaluated left to right.
    let inv_sqrt_out = 1. / (out_channels as f64).sqrt();
    let limit = inv_sqrt_out / kernel_size as f64;
    let uniform = crate::nn::Init::Uniform {
        lo: -limit,
        up: limit,
    };

    let kernel_shape = (in_channels, out_channels, kernel_size, kernel_size);
    let kernel = vb.get_with_hints(kernel_shape, "weight", uniform)?;
    let bias = vb.get_with_hints(out_channels, "bias", uniform)?;

    Ok(ConvTranspose2d::new(kernel, Some(bias), cfg))
}
/// Builds a bias-free [`ConvTranspose2d`] layer from a variable obtained
/// through `vb`.
///
/// The weight is requested as `"weight"` with shape
/// `(in_channels, out_channels, kernel_size, kernel_size)` and a uniform
/// initialization hint over `[-b, b]` with
/// `b = 1 / sqrt(out_channels) / kernel_size`.
///
/// # Errors
///
/// Propagates a failure of the `get_with_hints` call.
pub fn conv_transpose2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: ConvTranspose2dConfig,
    vb: crate::nn::VarBuilder,
) -> Result<ConvTranspose2d> {
    // Same arithmetic as `1 / sqrt(out) / k`, evaluated left to right.
    let inv_sqrt_out = 1. / (out_channels as f64).sqrt();
    let limit = inv_sqrt_out / kernel_size as f64;
    let uniform = crate::nn::Init::Uniform {
        lo: -limit,
        up: limit,
    };

    let kernel_shape = (in_channels, out_channels, kernel_size, kernel_size);
    let kernel = vb.get_with_hints(kernel_shape, "weight", uniform)?;

    Ok(ConvTranspose2d::new(kernel, None, cfg))
}