diffusion_rs_common::nn::var_builder

Trait Backend

pub trait Backend: Send + Sync {
    type Hints: Default;

    // Required methods
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;
    fn get_unchecked(
        &self,
        name: &str,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;
    fn contains_tensor(&self, name: &str) -> bool;
}

A trait that defines how tensor data is retrieved.

Implementations typically load tensors from disk storage in a specific format, or initialize them randomly. A specialized version of this trait, SimpleBackend, is sufficient most of the time. Its main restriction is that it does not allow backend-specific arguments other than initialization hints.

Required Associated Types

type Hints: Default

Required Methods

fn get( &self, s: Shape, name: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor with some target shape.

fn get_unchecked( &self, name: &str, dtype: DType, dev: &Device, ) -> Result<Tensor>

Retrieve a tensor by name, with no target shape to check against.

fn contains_tensor(&self, name: &str) -> bool
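As a rough illustration of how the three required methods relate, here is a minimal sketch of an in-memory backend. Everything in it is a stand-in: the simplified `Tensor` struct, the `String` error type, the reduced method signatures (no `DType` or `Device`), and the `MapBackend` name are hypothetical and are not the crate's actual types. It also assumes that `get` validates the requested shape while `get_unchecked` does not, which the method names suggest but this page does not state.

```rust
use std::collections::HashMap;

// Stand-in for the crate's Tensor type (hypothetical, for illustration only).
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

// Simplified copy of the trait: the real one also takes DType and &Device
// and uses the crate's own Shape and Result types.
trait Backend: Send + Sync {
    type Hints: Default;

    fn get(&self, s: &[usize], name: &str, h: Self::Hints) -> Result<Tensor, String>;
    fn get_unchecked(&self, name: &str) -> Result<Tensor, String>;
    fn contains_tensor(&self, name: &str) -> bool;
}

// Hypothetical backend that serves tensors from a HashMap.
struct MapBackend {
    tensors: HashMap<String, Tensor>,
}

impl Backend for MapBackend {
    // This backend needs no initialization hints.
    type Hints = ();

    fn get(&self, s: &[usize], name: &str, _h: Self::Hints) -> Result<Tensor, String> {
        // Assumed behavior: look the tensor up, then verify the target shape.
        let t = self.get_unchecked(name)?;
        if t.shape.as_slice() != s {
            return Err(format!(
                "shape mismatch for {}: expected {:?}, got {:?}",
                name, s, t.shape
            ));
        }
        Ok(t)
    }

    fn get_unchecked(&self, name: &str) -> Result<Tensor, String> {
        // Lookup by name only; no shape validation.
        self.tensors
            .get(name)
            .cloned()
            .ok_or_else(|| format!("cannot find tensor {}", name))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.tensors.contains_key(name)
    }
}

fn main() {
    let mut tensors = HashMap::new();
    tensors.insert(
        "weight".to_string(),
        Tensor { shape: vec![2, 2], data: vec![0.0; 4] },
    );
    let backend = MapBackend { tensors };

    println!("has weight: {}", backend.contains_tensor("weight"));
    match backend.get(&[2, 2], "weight", ()) {
        Ok(t) => println!("loaded {} values", t.data.len()),
        Err(e) => println!("error: {}", e),
    }
}
```

Routing `get` through `get_unchecked` keeps the lookup logic in one place, so the shape check is the only difference between the two paths in this sketch.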

Implementations on Foreign Types

impl Backend for Box<dyn SimpleBackend + '_>

type Hints = Init

fn get( &self, s: Shape, name: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result<Tensor>

fn get_unchecked( &self, name: &str, dtype: DType, dev: &Device, ) -> Result<Tensor>

fn contains_tensor(&self, name: &str) -> bool

Implementors

impl Backend for ShardedSafeTensors

Get part of a tensor, typically used to do Tensor Parallelism sharding.

dim is the dimension to slice into.
rank is the rank of the current process.
world_size is the total number of ranks in the process group.

For a tensor of size (1024, 1024):

get_sharded("tensor", 0, 0, 2) means tensor.i((..512))
get_sharded("tensor", 0, 1, 2) means tensor.i((512..))
get_sharded("tensor", 1, 0, 2) means tensor.i((.., ..512))
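The index arithmetic behind these examples can be sketched as follows. `shard_range` is a hypothetical helper written for illustration, not a function of this crate, and it assumes the sliced dimension must divide evenly by world_size, which this page does not state.

```rust
use std::ops::Range;

/// Hypothetical helper: the index range owned by `rank` when a dimension of
/// length `dim_size` is split into `world_size` equal contiguous blocks.
/// Returns `None` for an invalid rank or a dimension that does not divide
/// evenly (an assumption of this sketch).
fn shard_range(dim_size: usize, rank: usize, world_size: usize) -> Option<Range<usize>> {
    if world_size == 0 || rank >= world_size || dim_size % world_size != 0 {
        return None;
    }
    let block = dim_size / world_size;
    Some(rank * block..(rank + 1) * block)
}

fn main() {
    // A (1024, 1024) tensor split across 2 ranks along one dimension.
    for rank in 0..2 {
        match shard_range(1024, rank, 2) {
            Some(r) => println!("rank {}: {}..{}", rank, r.start, r.end),
            None => println!("rank {}: invalid shard request", rank),
        }
    }
}
```

The same range applies whichever dimension is sliced; only the axis it is applied to changes, as in the dim = 1 example above.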