diffusion_rs_common::nn::rnn

Trait RNN

source
/// Trait for Recurrent Neural Networks.
///
/// Implementors supply the single-step update (`step`), the initial state
/// (`zero_state`), and a way to turn a list of states into a tensor
/// (`states_to_tensor`). The sequence-level helpers `seq` and `seq_init` are
/// provided on top of these.
pub trait RNN {
    /// The recurrent state that is passed into, and returned by, each call to `step`.
    type State: Clone;

    // Required methods
    /// A zero state from which the recurrent network is usually initialized.
    ///
    /// NOTE(review): `batch_dim` presumably matches the `batch_size` of the
    /// inputs later fed to `step`. Confirm against implementors.
    fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
    /// Applies a single step of the recurrent network.
    ///
    /// The input should have dimensions `[batch_size, features]`. Returns the
    /// new state computed from `input` and the previous `state`.
    fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>;
    /// Converts a sequence of states to a tensor.
    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;

    // Provided methods
    /// Applies multiple steps of the recurrent network.
    ///
    /// The input should have dimensions `[batch_size, seq_len, features]`.
    /// The initial state is the result of applying `zero_state`.
    fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> { ... }
    /// Applies multiple steps of the recurrent network, starting from the
    /// provided `init_state`.
    ///
    /// The input should have dimensions `[batch_size, seq_len, features]`.
    fn seq_init(
        &self,
        input: &Tensor,
        init_state: &Self::State,
    ) -> Result<Vec<Self::State>> { ... }
}
Expand description

Trait for Recurrent Neural Networks.

Required Associated Types§

type State: Clone

The recurrent state that is passed into, and returned by, each call to step.

Required Methods§

source

fn zero_state(&self, batch_dim: usize) -> Result<Self::State>

A zero state from which the recurrent network is usually initialized.

source

fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>

Applies a single step of the recurrent network.

The input should have dimensions [batch_size, features].

source

fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>

Converts a sequence of states to a tensor.

Provided Methods§

source

fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>>

Applies multiple steps of the recurrent network.

The input should have dimensions [batch_size, seq_len, features]. The initial state is the result of applying zero_state.

source

fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>>

Applies multiple steps of the recurrent network, starting from the provided init_state.

The input should have dimensions [batch_size, seq_len, features].

Implementors§