diffusion_rs_common/nn/linear.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
//! Linear layer
//!
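//! This layer applies a linear transformation to the incoming data, `y = x@w.t() + b`, where
//! `w` has shape `(out_c, in_c)`. The bias is optional. Use the `forward` method to apply the
//! layer. It accepts input with a batch dimension (of shape `(b_sz, in_c)`) or without one (of
//! shape `(in_c,)`), producing output of shape `(b_sz, out_c)` or `(out_c,)` respectively. Inputs
//! with extra leading dimensions, such as `(b_sz, seq_len, in_c)` or `(b1, b2, seq_len, in_c)`,
//! are also accepted; the weight is broadcast over those leading dimensions.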
//!
//! ```rust
//! use diffusion_rs_common::core::{Tensor, Device::Cpu};
//! use diffusion_rs_common::nn::{Linear, Module};
//! # fn main() -> diffusion_rs_common::core::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
//! let layer = Linear::new(w, None); // Use no bias.
//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
//! # Ok(()) }
//! ```
use crate::core::{Result, Tensor};
/// A fully connected layer computing `x @ weight.t() + bias`.
///
/// Build one directly with [`Linear::new`], or through the [`linear()`], [`linear_no_bias()`]
/// and [`linear_b()`] helpers in this module.
#[derive(Clone, Debug)]
pub struct Linear {
    /// Weight matrix; the helper constructors in this module request it with shape
    /// `(out_dim, in_dim)`.
    weight: Tensor,
    /// Optional bias, added with broadcasting after the matrix product.
    bias: Option<Tensor>,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl super::Module for Linear {
    /// Applies `x @ weight.t() + bias`.
    ///
    /// For 3-D and 4-D inputs the weight is broadcast over the leading dimensions, so the
    /// batched matmul sees operands of equal rank. A 1-D input `(in_c,)` is lifted to
    /// `(1, in_c)` for the product, and the added row dimension is removed afterwards. The
    /// output is therefore `(out_c,)`, as the module documentation states.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying tensor operations (transpose, broadcast,
    /// matmul, squeeze/unsqueeze, bias addition), for example on a shape mismatch.
    fn forward(&self, x: &Tensor) -> crate::core::Result<Tensor> {
        let w = match *x.dims() {
            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
            _ => self.weight.t()?,
        };
        let ys = if x.dims().len() == 1 {
            // `matmul` works on operands of rank >= 2 (NOTE(review): per the tensor
            // `matmul` contract), so treat the vector as a single-row batch and drop
            // that row dimension again once the product is formed.
            x.unsqueeze(0)?.matmul(&w)?.squeeze(0)?
        } else {
            x.matmul(&w)?
        };
        match &self.bias {
            None => Ok(ys),
            Some(bias) => ys.broadcast_add(bias),
        }
    }
}
/// Creates or initializes a linear layer that has a bias.
///
/// The tensors are fetched from `vb` under the default names `"weight"` and `"bias"`:
/// - `"weight"` has shape `(out_dim, in_dim)` and uses `DEFAULT_KAIMING_NORMAL` as its
///   initialization hint;
/// - `"bias"` has shape `(out_dim,)` and uses a uniform hint with bounds
///   `-1/sqrt(in_dim)` and `1/sqrt(in_dim)`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::nn::VarBuilder) -> Result<Linear> {
    let weight = vb.get_with_hints(
        (out_dim, in_dim),
        "weight",
        crate::nn::init::DEFAULT_KAIMING_NORMAL,
    )?;

    // Bias bound scales with the fan-in, like the usual dense-layer default.
    let limit = (in_dim as f64).sqrt().recip();
    let bias = vb.get_with_hints(
        out_dim,
        "bias",
        crate::nn::Init::Uniform {
            lo: -limit,
            up: limit,
        },
    )?;

    Ok(Linear::new(weight, Some(bias)))
}
/// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::nn::VarBuilder) -> Result<Linear> {
let init_ws = crate::nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}
/// Creates a linear layer, with or without a bias depending on the `bias` flag.
///
/// Delegates to [`linear_no_bias()`] when `bias` is `false` and to [`linear()`] otherwise.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::nn::VarBuilder,
) -> Result<Linear> {
    if !bias {
        return linear_no_bias(in_dim, out_dim, vb);
    }
    linear(in_dim, out_dim, vb)
}