diffusion_rs_common::nn::var_builder

Trait SimpleBackend

Source
/// A tensor source that can be used behind a trait object.
///
/// Implementors must be `Send + Sync`. A boxed `dyn SimpleBackend` implements
/// `Backend`, using `Init` as its hint type.
pub trait SimpleBackend: Send + Sync {
    /// Retrieve a tensor based on a target name and shape.
    ///
    /// `h` is an initialization hint; whether it is consulted is up to the
    /// implementor (the `HashMap<String, Tensor>` implementation ignores it).
    fn get(&self, s: Shape, name: &str, h: Init, dtype: DType, dev: &Device) -> Result<Tensor>;

    /// Retrieve a tensor based on the name alone; no target shape is supplied.
    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;

    /// Presumably reports whether a tensor called `name` is available from this
    /// backend. NOTE(review): no upstream description exists; confirm against
    /// implementors.
    fn contains_tensor(&self, name: &str) -> bool;
}

Required Methods

Source

fn get(&self, s: Shape, name: &str, h: Init, dtype: DType, dev: &Device) -> Result<Tensor>

Retrieve a tensor based on a target name and shape.

Source

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>

Retrieve a tensor based on the name.

Source

fn contains_tensor(&self, name: &str) -> bool

Trait Implementations

Source§

impl Backend for Box<dyn SimpleBackend + '_>

Source§

type Hints = Init

Source§

fn get(&self, s: Shape, name: &str, h: Self::Hints, dtype: DType, dev: &Device) -> Result<Tensor>

Retrieve a tensor with some target shape.
Source§

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>

Retrieve a tensor based on the name.
Source§

fn contains_tensor(&self, name: &str) -> bool

Implementations on Foreign Types

Source§

impl SimpleBackend for HashMap<String, Tensor>

Source§

fn get(&self, s: Shape, name: &str, _: Init, dtype: DType, dev: &Device) -> Result<Tensor>

Source§

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>

Source§

fn contains_tensor(&self, name: &str) -> bool

Implementors