pub trait Backend: Send + Sync {
    type Hints: Default;

    // Required methods
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;
    fn get_unchecked(
        &self,
        name: &str,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;
    fn contains_tensor(&self, name: &str) -> bool;
}
A trait that defines how tensor data is retrieved.
Typically this would use disk storage in some specific format, or random initialization.
Note that there is a specialized version of this trait (SimpleBackend) that can be used most
of the time. The main restriction is that it doesn’t allow for specific args (besides
initialization hints).
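To make the division of labour between the three methods concrete, here is a minimal, self-contained sketch. It does not use the real crate: `ToyTensor`, `ToyBackend`, `MapBackend` and `sample_backend` are hypothetical stand-ins, the dtype and device arguments are left out, and it assumes that `get` validates the requested shape while `get_unchecked` does not, which is what the method descriptions suggest.

```rust
use std::collections::HashMap;

// Stand-in for a tensor; the real trait returns the crate's own Tensor type.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

// Simplified mirror of the trait above (dtype and device omitted).
trait ToyBackend {
    type Hints: Default;
    fn get(&self, s: &[usize], name: &str, h: Self::Hints) -> Result<ToyTensor, String>;
    fn get_unchecked(&self, name: &str) -> Result<ToyTensor, String>;
    fn contains_tensor(&self, name: &str) -> bool;
}

// Hypothetical in-memory store keyed by tensor name.
struct MapBackend {
    tensors: HashMap<String, ToyTensor>,
}

impl ToyBackend for MapBackend {
    // This backend needs no initialization hints.
    type Hints = ();

    fn get(&self, s: &[usize], name: &str, _h: ()) -> Result<ToyTensor, String> {
        // Look the tensor up by name, then check it against the target shape.
        let t = self.get_unchecked(name)?;
        if t.shape.as_slice() != s {
            return Err(format!(
                "shape mismatch for {}: expected {:?}, got {:?}",
                name, s, t.shape
            ));
        }
        Ok(t)
    }

    fn get_unchecked(&self, name: &str) -> Result<ToyTensor, String> {
        // Name lookup only; no shape validation.
        self.tensors
            .get(name)
            .cloned()
            .ok_or_else(|| format!("cannot find tensor {}", name))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.tensors.contains_key(name)
    }
}

// Builds a store holding a single 2x2 tensor named "weight".
fn sample_backend() -> MapBackend {
    let mut tensors = HashMap::new();
    tensors.insert(
        "weight".to_string(),
        ToyTensor { shape: vec![2, 2], data: vec![0.0; 4] },
    );
    MapBackend { tensors }
}
```

With this toy store, `sample_backend().get(&[2, 2], "weight", ())` succeeds, while asking for shape `[4]` or for an unknown name returns an error. A real implementor would additionally honour the requested dtype and device.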
Required Associated Types
type Hints: Default
Required Methods
fn get(
    &self,
    s: Shape,
    name: &str,
    h: Self::Hints,
    dtype: DType,
    dev: &Device,
) -> Result<Tensor>
Retrieve a tensor with some target shape.
fn get_unchecked(
    &self,
    name: &str,
    dtype: DType,
    dev: &Device,
) -> Result<Tensor>
Retrieve a tensor based on the name.
fn contains_tensor(&self, name: &str) -> bool
Implementations on Foreign Types
impl Backend for Box<dyn SimpleBackend + '_>
Implementors
impl Backend for ShardedSafeTensors
Get part of a tensor, typically used to do Tensor Parallelism sharding.
dim corresponds to the dimension to slice into.
rank is the rank of the current process.
world_size is the total number of ranks in the process group.
For example, if the tensor is of size (1024, 1024):
get_sharded("tensor", 0, 0, 2) means tensor.i((..512))
get_sharded("tensor", 0, 1, 2) means tensor.i((512..))
get_sharded("tensor", 1, 0, 2) means tensor.i((.., ..512))
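The index arithmetic behind these examples can be sketched in plain Rust. `shard_range` is a hypothetical helper, not part of the crate; it assumes the dimension is split into equal contiguous blocks, one per rank, and it rejects sizes that do not divide evenly by world_size (how the library itself treats that case is not stated here).

```rust
use std::ops::Range;

/// Hypothetical helper: the index range that `rank` would own along a
/// dimension of length `dim_size`, assuming equal contiguous blocks.
fn shard_range(dim_size: usize, rank: usize, world_size: usize) -> Option<Range<usize>> {
    // Reject impossible configurations and uneven splits.
    if world_size == 0 || rank >= world_size || dim_size % world_size != 0 {
        return None;
    }
    // Each rank owns one block of `dim_size / world_size` consecutive indices.
    let block = dim_size / world_size;
    Some(rank * block..(rank + 1) * block)
}
```

For a dimension of length 1024 and a world size of 2, rank 0 maps to `0..512` and rank 1 to `512..1024`, matching the `..512` and `512..` slices listed above.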