diffusion_rs_common/nn/
var_map.rsuse crate::core::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// A named collection of trainable variables ([`Var`]) keyed by their string path.
///
/// The map lives behind an `Arc<Mutex<..>>`, so cloning a `VarMap` only bumps the
/// reference count: every clone observes and mutates the same underlying storage.
#[derive(Clone)]
pub struct VarMap {
    // Shared, mutex-guarded storage: variable name -> variable.
    data: Arc<Mutex<HashMap<String, Var>>>,
}
impl VarMap {
    /// Creates an empty `VarMap`.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let data = Arc::new(Mutex::new(HashMap::new()));
        Self { data }
    }

    /// Returns a clone of every variable currently stored in the map.
    ///
    /// The order follows `HashMap` iteration order and is therefore unspecified.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn all_vars(&self) -> Vec<Var> {
        let tensor_data = self.data.lock().unwrap();
        tensor_data.values().cloned().collect()
    }

    /// Writes every variable to `path` in the safetensors format, using each variable's
    /// name as its tensor key.
    ///
    /// The map's lock is held while the file is written.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or the file write fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        let tensor_data = self.data.lock().unwrap();
        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
        Ok(())
    }

    /// Overwrites each variable already in the map with the tensor of the same name read
    /// from the safetensors file at `path`, loaded onto that variable's device.
    ///
    /// Only names already present in the map are read; extra tensors in the file are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be memory-mapped, if loading a tensor for one of
    /// the map's names fails (e.g. it is missing from the file), or if `Var::set` rejects the
    /// loaded tensor. Variables processed before the failing one have already been updated.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
        let path = path.as_ref();
        // SAFETY: memory-mapping a file is only sound if the file is not modified or
        // truncated while the mapping is alive. We assume nothing else writes to `path`
        // during this call. The mapping is a local that is dropped when this function returns.
        let safetensors = unsafe { crate::core::safetensors::MmapedSafetensors::new(path)? };
        let tensor_data = self.data.lock().unwrap();
        for (name, var) in tensor_data.iter() {
            let loaded = safetensors.load(name, var.device())?;
            if let Err(err) = var.set(&loaded) {
                crate::bail!("error setting {name} using data from {path:?}: {err}",)
            }
        }
        Ok(())
    }

    /// Sets the value of the single existing variable `name` to `value`.
    ///
    /// Equivalent to calling [`VarMap::set`] with a one-element iterator.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is not in the map or if `Var::set` rejects `value`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
        self.set(std::iter::once((name, value)))
    }

    /// Sets the value of each existing variable named in `iter` to the paired tensor.
    ///
    /// Pairs are applied in iteration order while the map's lock is held.
    ///
    /// # Errors
    ///
    /// Returns an error at the first name that is not in the map, or the first value that
    /// `Var::set` rejects. Pairs applied before the failing one are not rolled back.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
        &mut self,
        iter: I,
    ) -> Result<()> {
        let tensor_data = self.data.lock().unwrap();
        for (name, value) in iter {
            let name = name.as_ref();
            match tensor_data.get(name) {
                None => crate::bail!("cannot find {name} in VarMap"),
                Some(var) => {
                    if let Err(err) = var.set(value.as_ref()) {
                        crate::bail!("error setting {name}: {err}",)
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns the variable stored under `path`, creating it first if it does not exist.
    ///
    /// If `path` is already present, its stored shape must equal `shape`, and `init`,
    /// `dtype` and `device` are ignored. No dtype or device check is performed on the
    /// existing variable. Otherwise a new variable is built with
    /// `init.var(shape, dtype, device)` and inserted under `path`.
    ///
    /// The returned value is a clone of the variable's `as_tensor()` handle.
    ///
    /// # Errors
    ///
    /// Returns an error on a shape mismatch with an existing entry, or if `init.var` fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get<S: Into<Shape>>(
        &self,
        shape: S,
        path: &str,
        init: crate::nn::Init,
        dtype: DType,
        device: &Device,
    ) -> Result<Tensor> {
        let shape = shape.into();
        let mut tensor_data = self.data.lock().unwrap();
        if let Some(tensor) = tensor_data.get(path) {
            let tensor_shape = tensor.shape();
            if &shape != tensor_shape {
                crate::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
            }
            return Ok(tensor.as_tensor().clone());
        }
        let var = init.var(shape, dtype, device)?;
        let tensor = var.as_tensor().clone();
        tensor_data.insert(path.to_string(), var);
        Ok(tensor)
    }

    /// Always fails: a `VarMap` needs a shape and an initializer to create a variable, so
    /// callers must use [`VarMap::get`] instead.
    ///
    /// # Errors
    ///
    /// Unconditionally returns an error.
    pub fn get_unchecked(&self, _path: &str, _dtype: DType, _device: &Device) -> Result<Tensor> {
        crate::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`.");
    }

    /// Gives direct access to the underlying mutex-guarded name-to-variable map.
    pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
        &self.data
    }
}