/// A simple variable builder whose backend is a boxed `SimpleBackend` trait object.
///
/// This is less generic than `VarBuilderArgs` but should cover most common use cases.
pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
Expand description
A simple `VarBuilder`; it is less generic than `VarBuilderArgs` but should cover most common use cases.
Aliased Type
struct VarBuilder<'a> {
pub dtype: DType,
/* private fields */
}
Fields
dtype: DType
Implementations
source§impl<'a> VarBuilder<'a>
impl<'a> VarBuilder<'a>
sourcepub fn from_backend(
backend: Box<dyn SimpleBackend + 'a>,
dtype: DType,
device: Device,
) -> Self
pub fn from_backend( backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device, ) -> Self
Initializes a `VarBuilder` using a custom backend.
Prefer one of the more specific constructors where possible; this constructor exists so that downstream users can define their own backends.
sourcepub fn zeros(dtype: DType, dev: &Device) -> Self
pub fn zeros(dtype: DType, dev: &Device) -> Self
Initializes a `VarBuilder` that uses zeros for every requested tensor.
sourcepub fn from_tensors(
ts: HashMap<String, Tensor>,
dtype: DType,
dev: &Device,
) -> Self
pub fn from_tensors( ts: HashMap<String, Tensor>, dtype: DType, dev: &Device, ) -> Self
Initializes a `VarBuilder` that retrieves tensors stored in a hash map. An error is returned if no tensor is available under the requested path, or if the stored tensor's shape does not match the requested shape.
sourcepub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self
Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and initialized on new paths, and the same tensor is reused if the same path is requested multiple times. This is commonly used when initializing a model before training.
Note that tensor values can be loaded after model creation using the `load` method on `varmap`; this can be used to start model training from an existing checkpoint.
sourcepub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>(
paths: &[P],
dtype: DType,
dev: &Device,
) -> Result<Self>
pub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>( paths: &[P], dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors files.
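Safety
The `unsafe` is inherited from [`memmap2::MmapOptions`]: the files are memory-mapped, so the caller must ensure they are not modified while the mapping is in use.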
sourcepub fn from_buffered_safetensors(
data: Vec<u8>,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_buffered_safetensors( data: Vec<u8>, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` from a binary buffer in the safetensors format.
sourcepub fn from_slice_safetensors(
data: &'a [u8],
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_slice_safetensors( data: &'a [u8], dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` from a binary slice in the safetensors format.
sourcepub fn from_npz<P: AsRef<Path>>(
p: P,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_npz<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a NumPy `.npz` file.
sourcepub fn from_pth<P: AsRef<Path>>(
p: P,
dtype: DType,
dev: &Device,
) -> Result<Self>
pub fn from_pth<P: AsRef<Path>>( p: P, dtype: DType, dev: &Device, ) -> Result<Self>
Initializes a `VarBuilder` that retrieves tensors stored in a PyTorch `.pth` file.
sourcepub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(
self,
f: F,
) -> Self
pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>( self, f: F, ) -> Self
Returns a `VarBuilder` that applies a renaming function to each tensor name it is queried for, then passes the new name to the inner `VarBuilder`.
use diffusion_rs_common::core::{Tensor, DType, Device};
// Six values from `arange(0, 6)`, reshaped into a 2x3 tensor.
let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
// Store that tensor under the single name "foo".
let tensors: std::collections::HashMap<_, _> = [
("foo".to_string(), a),
]
.into_iter()
.collect();
let vb = diffusion_rs_common::nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
// Before renaming, only "foo" resolves.
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "foo").is_ok());
assert!(!vb.contains_tensor("bar"));
// Map the name "bar" to "foo"; every other name is passed through unchanged.
let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
// "bar" (renamed to "foo") and "foo" (passed through) now both reach the stored tensor.
assert!(vb.contains_tensor("bar"));
assert!(vb.contains_tensor("foo"));
assert!(vb.get((2, 3), "bar").is_ok());
assert!(vb.get((2, 3), "foo").is_ok());
// A name that maps to nothing stored is still absent.
assert!(!vb.contains_tensor("baz"));