diffusion_rs_common::nn::var_builder

Struct VarBuilderArgs

source
pub struct VarBuilderArgs<'a, B: Backend> {
    pub dtype: DType,
    /* private fields */
}
Expand description

A structure used to retrieve variables. These variables can either come from storage or be generated via some form of initialization.

The way to retrieve variables is defined in the backend embedded in the VarBuilder.

Fields§

§dtype: DType

Implementations§

source§

impl<B: Backend> VarBuilderArgs<'_, B>

source

pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self

source

pub fn prefix(&self) -> String

Returns the prefix of the VarBuilder.

source

pub fn root(&self) -> Self

Returns a new VarBuilder using the root path.

source

pub fn set_prefix(&self, prefix: impl ToString) -> Self

Returns a new VarBuilder with the prefix set to prefix.

source

pub fn push_prefix<S: ToString>(&self, s: S) -> Self

Return a new VarBuilder adding `s` to the current prefix. This can be thought of as `cd`-ing into a directory.

source

pub fn pp<S: ToString>(&self, s: S) -> Self

Short alias for push_prefix.

source

pub fn device(&self) -> &Device

The device used by default.

source

pub fn dtype(&self) -> DType

The dtype used by default.

source

pub fn to_dtype(&self, dtype: DType) -> Self

Clone the VarBuilder, changing its dtype.

source

pub fn contains_tensor(&self, tensor_name: &str) -> bool

This returns true only if a tensor with the passed-in name is available. E.g. when passed `a`, true is returned if `prefix.a` exists, but false is returned if only `prefix.a.b` exists.

source

pub fn get_with_hints<S: Into<Shape>>( &self, s: S, name: &str, hints: B::Hints, ) -> Result<Tensor>

Retrieve the tensor associated with the given name at the current path.

source

pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor>

Retrieve the tensor associated with the given name at the current path.

source

pub fn get_unchecked(&self, name: &str) -> Result<Tensor>

Retrieve the tensor associated with the given name at the current path.

source

pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor>

Retrieve the tensor associated with the given name & dtype at the current path.

source

pub fn get_with_hints_dtype<S: Into<Shape>>( &self, s: S, name: &str, hints: B::Hints, dtype: DType, ) -> Result<Tensor>

Retrieve the tensor associated with the given name & dtype at the current path.

source

pub fn set_device(self, device: Device) -> Self

Set the device of the VarBuilder.

source

pub fn set_dtype(self, dtype: DType) -> Self

Set the dtype of the VarBuilder.

source§

impl<'a> VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>

source

pub fn from_backend( backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device, ) -> Self

Initializes a VarBuilder using a custom backend.

It is preferred to use one of the more specific constructors. This constructor is provided to allow downstream users to define their own backends.

source

pub fn zeros(dtype: DType, dev: &Device) -> Self

Initializes a VarBuilder that uses zeros for any tensor.

source

pub fn from_tensors( ts: HashMap<String, Tensor>, dtype: DType, dev: &Device, ) -> Self

Initializes a VarBuilder that retrieves tensors stored in a hashtable. An error is returned if no tensor is available under the requested path or on shape mismatches.

source

pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self

Initializes a VarBuilder using a VarMap. The requested tensors are created and initialized on new paths; the same tensor is used if the same path is requested multiple times. This is commonly used when initializing a model before training.

Note that it is possible to load the tensor values after model creation using the load method on varmap, this can be used to start model training from an existing checkpoint.

source

pub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>( paths: &[P], dtype: DType, dev: &Device, ) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a collection of safetensors files.

§Safety

The `unsafe` is inherited from `memmap2::MmapOptions`.

source

pub fn from_buffered_safetensors( data: Vec<u8>, dtype: DType, dev: &Device, ) -> Result<Self>

Initializes a VarBuilder from a binary buffer in the safetensor format.

source

pub fn from_slice_safetensors( data: &'a [u8], dtype: DType, dev: &Device, ) -> Result<Self>

Initializes a VarBuilder from a binary slice in the safetensor format.

source

pub fn from_npz<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a numpy npz file.

source

pub fn from_pth<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a pytorch pth file.

source

pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>( self, f: F, ) -> Self

Gets a VarBuilder that applies a renaming function to each tensor name it is queried for, before passing the new name to the inner VarBuilder.

use diffusion_rs_common::core::{Tensor, DType, Device};

let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
let tensors: std::collections::HashMap<_, _> = [
    ("foo".to_string(), a),
]
.into_iter()
.collect();
let vb = diffusion_rs_common::nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("bar"));
let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
assert!(vb.contains_tensor("bar"));
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "bar").is_ok());
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("baz"));
source

pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self

Trait Implementations§

source§

impl<B: Backend> Clone for VarBuilderArgs<'_, B>

source§

fn clone(&self) -> Self

Returns a copy of the value. Read more
1.0.0 · source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more

Auto Trait Implementations§

§

impl<'a, B> Freeze for VarBuilderArgs<'a, B>

§

impl<'a, B> RefUnwindSafe for VarBuilderArgs<'a, B>
where B: RefUnwindSafe,

§

impl<'a, B> Send for VarBuilderArgs<'a, B>

§

impl<'a, B> Sync for VarBuilderArgs<'a, B>

§

impl<'a, B> Unpin for VarBuilderArgs<'a, B>

§

impl<'a, B> UnwindSafe for VarBuilderArgs<'a, B>
where B: RefUnwindSafe,

Blanket Implementations§

source§

impl<T> Any for T
where T: 'static + ?Sized,

source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
source§

impl<T> Borrow<T> for T
where T: ?Sized,

source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
source§

impl<T> CloneToUninit for T
where T: Clone,

source§

unsafe fn clone_to_uninit(&self, dst: *mut T)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dst. Read more
source§

impl<T> From<T> for T

source§

fn from(t: T) -> T

Returns the argument unchanged.

source§

impl<T, U> Into<U> for T
where U: From<T>,

source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

source§

impl<T> IntoEither for T

source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
§

impl<T> Pointable for T

§

const ALIGN: usize = _

The alignment of pointer.
§

type Init = T

The type for initializers.
§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
source§

impl<T> Same for T

source§

type Output = T

Should always be Self
source§

impl<T> ToOwned for T
where T: Clone,

source§

type Owned = T

The resulting type after obtaining ownership.
source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

source§

type Error = Infallible

The type returned in the event of a conversion error.
source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

§

fn vzip(self) -> V

§

impl<T> ErasedDestructor for T
where T: 'static,

§

impl<T> MaybeSendSync for T