pub struct VarBuilderArgs<'a, B: Backend> {
pub dtype: DType,
/* private fields */
}
Expand description
A structure used to retrieve variables. These variables can either come from storage or be generated via some form of initialization.
The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
Fields§
§dtype: DType
Implementations§
source§impl<B: Backend> VarBuilderArgs<'_, B>
impl<B: Backend> VarBuilderArgs<'_, B>
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self
sourcepub fn set_prefix(&self, prefix: impl ToString) -> Self
pub fn set_prefix(&self, prefix: impl ToString) -> Self
Returns a new `VarBuilder` with the prefix set to `prefix`.
sourcepub fn push_prefix<S: ToString>(&self, s: S) -> Self
pub fn push_prefix<S: ToString>(&self, s: S) -> Self
Returns a new `VarBuilder`, adding `s` to the current prefix. This can be thought of as `cd`-ing into a directory.
sourcepub fn contains_tensor(&self, tensor_name: &str) -> bool
pub fn contains_tensor(&self, tensor_name: &str) -> bool
This returns true only if a tensor with the passed-in name is available. E.g. when passed `a`, true is returned if `prefix.a` exists, but false is returned if only `prefix.a.b` exists.
sourcepub fn get_with_hints<S: Into<Shape>>(
&self,
s: S,
name: &str,
hints: B::Hints,
) -> Result<Tensor>
pub fn get_with_hints<S: Into<Shape>>( &self, s: S, name: &str, hints: B::Hints, ) -> Result<Tensor>
Retrieve the tensor associated with the given name at the current path.
sourcepub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor>
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor>
Retrieve the tensor associated with the given name at the current path.
sourcepub fn get_unchecked(&self, name: &str) -> Result<Tensor>
pub fn get_unchecked(&self, name: &str) -> Result<Tensor>
Retrieve the tensor associated with the given name at the current path.
sourcepub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor>
pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor>
Retrieve the tensor associated with the given name & dtype at the current path.
sourcepub fn get_with_hints_dtype<S: Into<Shape>>(
&self,
s: S,
name: &str,
hints: B::Hints,
dtype: DType,
) -> Result<Tensor>
pub fn get_with_hints_dtype<S: Into<Shape>>( &self, s: S, name: &str, hints: B::Hints, dtype: DType, ) -> Result<Tensor>
Retrieve the tensor associated with the given name & dtype at the current path.
sourcepub fn set_device(self, device: Device) -> Self
pub fn set_device(self, device: Device) -> Self
Set the device of the VarBuilder.
source§impl<'a> VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>
impl<'a> VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>
sourcepub fn from_backend(
backend: Box<dyn SimpleBackend + 'a>,
dtype: DType,
device: Device,
) -> Self
pub fn from_backend( backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device, ) -> Self
Initializes a `VarBuilder` using a custom backend.
It is preferred to use one of the more specific constructors. This constructor is provided to allow downstream users to define their own backends.
sourcepub fn zeros(dtype: DType, dev: &Device) -> Self
pub fn zeros(dtype: DType, dev: &Device) -> Self
Initializes a `VarBuilder` that uses zeros for any tensor.
sourcepub fn from_tensors(
ts: HashMap<String, Tensor>,
dtype: DType,
dev: &Device,
) -> Self
pub fn from_tensors( ts: HashMap<String, Tensor>, dtype: DType, dev: &Device, ) -> Self
Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is returned if no tensor is available under the requested path or on shape mismatches.
sourcepub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self
Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and initialized on new paths; the same tensor is used if the same path is requested multiple times. This is commonly used when initializing a model before training.
Note that it is possible to load the tensor values after model creation using the `load` method on `varmap`; this can be used to start model training from an existing checkpoint.
sourcepub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>(
paths: &[P],
dtype: DType,
dev: &Device,
) -> Result<Self>
pub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>( paths: &[P], dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors files.
§Safety
The unsafe is inherited from [`memmap2::MmapOptions`].
sourcepub fn from_buffered_safetensors(
data: Vec<u8>,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_buffered_safetensors( data: Vec<u8>, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` from a binary buffer in the safetensor format.
sourcepub fn from_slice_safetensors(
data: &'a [u8],
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_slice_safetensors( data: &'a [u8], dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` from a binary slice in the safetensor format.
sourcepub fn from_npz<P: AsRef<Path>>(
p: P,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_npz<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
sourcepub fn from_pth<P: AsRef<Path>>(
p: P,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_pth<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
sourcepub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(
self,
f: F,
) -> Self
pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>( self, f: F, ) -> Self
Gets a `VarBuilder` that applies a renaming function to the tensor names it is queried for before passing the new names to the inner `VarBuilder`.
use diffusion_rs_common::core::{Tensor, DType, Device};
let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
let tensors: std::collections::HashMap<_, _> = [
("foo".to_string(), a),
]
.into_iter()
.collect();
let vb = diffusion_rs_common::nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("bar"));
let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
assert!(vb.contains_tensor("bar"));
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "bar").is_ok());
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("baz"));
pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self
Trait Implementations§
Auto Trait Implementations§
impl<'a, B> Freeze for VarBuilderArgs<'a, B>
impl<'a, B> RefUnwindSafe for VarBuilderArgs<'a, B> where
B: RefUnwindSafe,
impl<'a, B> Send for VarBuilderArgs<'a, B>
impl<'a, B> Sync for VarBuilderArgs<'a, B>
impl<'a, B> Unpin for VarBuilderArgs<'a, B>
impl<'a, B> UnwindSafe for VarBuilderArgs<'a, B> where
B: RefUnwindSafe,
Blanket Implementations§
source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
source§unsafe fn clone_to_uninit(&self, dst: *mut T)
unsafe fn clone_to_uninit(&self, dst: *mut T)
This is a nightly-only experimental API. (`clone_to_uninit`)
source§impl<T> IntoEither for T
impl<T> IntoEither for T
source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more