diffusion_rs_common::nn::var_builder

Type Alias VarBuilder

pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;

A simple VarBuilder. It is less generic than VarBuilderArgs but should cover most common use cases.

Aliased Type

struct VarBuilder<'a> {
    pub dtype: DType,
    /* private fields */
}

Fields

dtype: DType

Implementations

impl<'a> VarBuilder<'a>

pub fn from_backend(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self

Initializes a VarBuilder using a custom backend.

It is preferred to use one of the more specific constructors. This constructor is provided to allow downstream users to define their own backends.
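
As a rough illustration of what a boxed backend looks like, here is a minimal, standard-library-only sketch. The `ToyBackend` trait, `ZerosBackend`, and `BoxedBuilder` are hypothetical stand-ins invented for this example; the real `SimpleBackend` trait works with tensors, shapes, dtypes, and devices and has a different signature.

```rust
/// Hypothetical, simplified backend trait. The real `SimpleBackend`
/// has a different signature; this only mirrors the general idea.
trait ToyBackend {
    /// Returns `len` values for the tensor called `name`, if available.
    fn get(&self, len: usize, name: &str) -> Option<Vec<f32>>;
    /// Reports whether a tensor called `name` can be provided.
    fn contains(&self, name: &str) -> bool;
}

/// A custom backend that answers every request with zeros.
struct ZerosBackend;

impl ToyBackend for ZerosBackend {
    fn get(&self, len: usize, _name: &str) -> Option<Vec<f32>> {
        Some(vec![0.0; len])
    }

    fn contains(&self, _name: &str) -> bool {
        true
    }
}

/// Toy builder that owns any backend behind a boxed trait object,
/// in the spirit of `Box<dyn SimpleBackend + 'a>` in the alias above.
struct BoxedBuilder<'a> {
    backend: Box<dyn ToyBackend + 'a>,
}

impl<'a> BoxedBuilder<'a> {
    fn from_backend(backend: Box<dyn ToyBackend + 'a>) -> Self {
        Self { backend }
    }

    fn get(&self, len: usize, name: &str) -> Option<Vec<f32>> {
        self.backend.get(len, name)
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.backend.contains(name)
    }
}
```

With these toy types, `BoxedBuilder::from_backend(Box::new(ZerosBackend))` yields a builder whose lookups all succeed. Boxing the backend keeps the builder type free of a backend type parameter, which is the trade-off the alias makes against the more generic VarBuilderArgs.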

pub fn zeros(dtype: DType, dev: &Device) -> Self

Initializes a VarBuilder that uses zeros for any tensor.

pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self

Initializes a VarBuilder that retrieves tensors stored in a hash map. An error is returned if no tensor is available under the requested path or if the stored tensor's shape does not match the requested shape.
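
The two failure cases can be pictured with a small standard-library-only model. `ToyTensor`, `LookupError`, and `MapBuilder` are hypothetical names made up for this sketch; the library's own error type and lookup code are not shown here and may differ.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a tensor: a shape plus flat data.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Hypothetical error type covering the two documented failure cases.
#[derive(Debug, PartialEq)]
enum LookupError {
    Missing { path: String },
    ShapeMismatch { path: String, expected: Vec<usize>, found: Vec<usize> },
}

/// Toy model of a hash-map-backed builder (not the library's code).
struct MapBuilder {
    tensors: HashMap<String, ToyTensor>,
}

impl MapBuilder {
    /// Looks up `path` and checks that the stored shape matches `shape`.
    fn get(&self, shape: &[usize], path: &str) -> Result<&ToyTensor, LookupError> {
        let tensor = self
            .tensors
            .get(path)
            .ok_or_else(|| LookupError::Missing { path: path.to_string() })?;
        if tensor.shape.as_slice() != shape {
            return Err(LookupError::ShapeMismatch {
                path: path.to_string(),
                expected: shape.to_vec(),
                found: tensor.shape.clone(),
            });
        }
        Ok(tensor)
    }
}
```

In this model, requesting `"foo"` with the stored shape succeeds, requesting an unknown path gives `LookupError::Missing`, and requesting `"foo"` with a different shape gives `LookupError::ShapeMismatch`.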

pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self

Initializes a VarBuilder using a VarMap. A tensor is created and initialized the first time its path is requested; the same tensor is returned if the same path is requested again. This is commonly used when initializing a model before training.

Note that tensor values can be loaded after model creation using the load method on the VarMap. This can be used to start model training from an existing checkpoint.
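
The create-on-first-request behavior and the later in-place load can be sketched with the standard library alone. `ToyVar` and `ToyVarMap` are hypothetical stand-ins; in particular, zero initialization and the lenient `load` below are simplifying assumptions, not a description of how VarMap initializes or loads variables.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for a trainable variable: shared, mutable data.
type ToyVar = Arc<Mutex<Vec<f32>>>;

/// Toy model of a variable map keyed by path (not the library's VarMap).
#[derive(Default)]
struct ToyVarMap {
    vars: Mutex<HashMap<String, ToyVar>>,
}

impl ToyVarMap {
    /// Returns the variable stored under `path`. On the first request it
    /// creates one with `len` zeros (an assumption made for simplicity).
    fn get_or_init(&self, path: &str, len: usize) -> ToyVar {
        let mut vars = self.vars.lock().unwrap();
        vars.entry(path.to_string())
            .or_insert_with(|| Arc::new(Mutex::new(vec![0.0; len])))
            .clone()
    }

    /// Overwrites the values of already-created variables in place,
    /// imitating loading a checkpoint after the model has been built.
    /// Paths missing from the map are silently skipped in this toy version.
    fn load(&self, checkpoint: &HashMap<String, Vec<f32>>) {
        let vars = self.vars.lock().unwrap();
        for (path, values) in checkpoint {
            if let Some(var) = vars.get(path) {
                *var.lock().unwrap() = values.clone();
            }
        }
    }
}
```

Because every request for the same path returns a handle to the same shared storage, values written by `load` are visible through handles the model obtained earlier. That sharing is what makes loading after model creation meaningful.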

pub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>(paths: &[P], dtype: DType, dev: &Device) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a collection of safetensors files.

Safety

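This function is unsafe because it memory-maps the files; the safety requirements are inherited from memmap2::MmapOptions. In particular, the files must not be modified while they are mapped.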

pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self>

Initializes a VarBuilder from an owned binary buffer in the safetensors format.

pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self>

Initializes a VarBuilder from a borrowed binary slice in the safetensors format.

pub fn from_npz<P: AsRef<Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a NumPy npz file.

pub fn from_pth<P: AsRef<Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self>

Initializes a VarBuilder that retrieves tensors stored in a PyTorch pth file.

pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self

Gets a VarBuilder that applies a renaming function to each tensor name it is queried for before passing the new name to the inner VarBuilder.

use diffusion_rs_common::core::{Tensor, DType, Device};

// Build a 2x3 tensor and register it under the name "foo".
let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
let tensors: std::collections::HashMap<_, _> = [
    ("foo".to_string(), a),
]
.into_iter()
.collect();
let vb = diffusion_rs_common::nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
// Before renaming, only "foo" resolves.
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("bar"));
// Redirect queries for "bar" to "foo"; every other name passes through unchanged.
let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
// Both "bar" and "foo" now resolve to the same stored tensor.
assert!(vb.contains_tensor("bar"));
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "bar").is_ok());
assert!(vb.get((2, 3), "foo").is_ok());
// "baz" is not remapped and still has no tensor behind it.
assert!(!vb.contains_tensor("baz"));
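
The renaming step can be pictured as a thin wrapper that rewrites each queried name before consulting the inner store. The sketch below uses only the standard library; `RenamingStore` is a hypothetical type invented for illustration and says nothing about how the library implements rename_f internally.

```rust
use std::collections::HashMap;

/// Toy wrapper that applies `rename` to every queried name before
/// looking it up in `inner` (a hypothetical stand-in for the inner builder).
struct RenamingStore<F: Fn(&str) -> String> {
    inner: HashMap<String, Vec<f32>>,
    rename: F,
}

impl<F: Fn(&str) -> String> RenamingStore<F> {
    fn new(inner: HashMap<String, Vec<f32>>, rename: F) -> Self {
        Self { inner, rename }
    }

    /// Checks for the renamed name, not the name the caller passed in.
    fn contains_tensor(&self, name: &str) -> bool {
        self.inner.contains_key(&(self.rename)(name))
    }

    /// Fetches the data stored under the renamed name, if any.
    fn get(&self, name: &str) -> Option<&Vec<f32>> {
        self.inner.get(&(self.rename)(name))
    }
}
```

Constructed with the same closure as in the example above, a store holding only `"foo"` reports that it contains both `"bar"` and `"foo"`, and still reports that it does not contain `"baz"`.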

pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self