diffusion_rs_common::nn::var_builder

Trait SimpleBackend

source
/// A backend from which tensors can be retrieved by name.
///
/// Implementors must be `Send + Sync`. On this page, `HashMap<String, Tensor>`
/// is listed as an implementor, and `Box<dyn SimpleBackend + '_>` implements
/// `Backend` with `Hints = Init`.
pub trait SimpleBackend: Send + Sync {
    // Required methods: every implementor has to provide all three.

    /// Retrieve a tensor based on a target name and shape.
    ///
    /// `s` is the target shape and `name` the tensor name. `h`, `dtype` and
    /// `dev` are not described on this page.
    // NOTE(review): presumably `h` is an initialization hint and `dtype`/`dev`
    // select the element type and device of the returned tensor; confirm
    // against the implementations.
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;

    /// Retrieve a tensor based on the name.
    ///
    /// Unlike `get`, this method takes no target shape and no `Init` argument.
    fn get_unchecked(
        &self,
        name: &str,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;

    // NOTE(review): undocumented on this page. Presumably reports whether the
    // backend holds a tensor called `name`; confirm against the implementations.
    fn contains_tensor(&self, name: &str) -> bool;
}

Required Methods§

source

fn get( &self, s: Shape, name: &str, h: Init, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor based on a target name and shape.

source

fn get_unchecked( &self, name: &str, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor based on the name.

source

fn contains_tensor(&self, name: &str) -> bool

Trait Implementations§

source§

impl Backend for Box<dyn SimpleBackend + '_>

source§

type Hints = Init

source§

fn get( &self, s: Shape, name: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor with some target shape.
source§

fn get_unchecked( &self, name: &str, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor based on the name.
source§

fn contains_tensor(&self, name: &str) -> bool

Implementations on Foreign Types§

source§

impl SimpleBackend for HashMap<String, Tensor>

source§

fn get( &self, s: Shape, name: &str, _: Init, dtype: DType, dev: &Device, ) -> Result<Tensor>

source§

fn get_unchecked( &self, name: &str, dtype: DType, dev: &Device, ) -> Result<Tensor>

source§

fn contains_tensor(&self, name: &str) -> bool

Implementors§