/// Trait for Recurrent Neural Networks.
pub trait RNN {
/// State type consumed and produced by `step`; must be `Clone`.
type State: Clone;
// Required methods
/// A zero state from which the recurrent network is usually initialized.
fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
/// Applies a single step of the recurrent network.
///
/// The input should have dimensions `[batch_size, features]`.
fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>;
/// Converts a sequence of states to a tensor.
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
// Provided methods
/// Returns one state per processed step as a `Vec<Self::State>`.
///
/// NOTE(review): the body is elided in this listing. Presumably this applies
/// `step` repeatedly over a sequence input, starting from `zero_state` —
/// confirm against the implementation, including the expected input shape.
fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> { ... }
/// Like `seq`, but takes an explicit `init_state` argument.
///
/// NOTE(review): the body is elided in this listing. Presumably `init_state`
/// replaces the zero state as the starting point — confirm against the
/// implementation.
fn seq_init(
&self,
input: &Tensor,
init_state: &Self::State,
) -> Result<Vec<Self::State>> { ... }
}
Expand description
Trait for Recurrent Neural Networks.
Required Associated Types§
type State: Clone
Required Methods§
source fn zero_state(&self, batch_dim: usize) -> Result<Self::State>
fn zero_state(&self, batch_dim: usize) -> Result<Self::State>
A zero state from which the recurrent network is usually initialized.
source fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>
fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>
Applies a single step of the recurrent network.
The input should have dimensions [batch_size, features].
source fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>
Converts a sequence of states to a tensor.