// diffusion_rs_common/varbuilder_loading.rs
use std::{
collections::HashMap,
sync::Arc,
thread::{self, JoinHandle},
};
use crate::{
core::{safetensors::MmapedSafetensors, DType, Device, Result, Tensor},
ModelSource,
};
use crate::{
safetensors::BytesSafetensors,
varbuilder::{SimpleBackend, VarBuilderArgs},
FileData, VarBuilder,
};
use super::progress::IterWithProgress;
/// Minimal interface over a tensor container: list the stored tensor names and load one of
/// them by name.
trait TensorLoaderBackend {
/// Returns the names of the tensors this backend exposes.
fn get_names(&self) -> Vec<String>;
/// Loads the tensor called `name` onto `device`.
///
/// `dtype` is an optional dtype hint; whether it is honored is up to the implementor.
fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
/// [`TensorLoaderBackend`] backed by an [`MmapedSafetensors`] handle.
struct SafetensorBackend(MmapedSafetensors);

impl TensorLoaderBackend for SafetensorBackend {
    /// Gathers the name of every tensor reported by the wrapped safetensors handle.
    fn get_names(&self) -> Vec<String> {
        let mut names = Vec::new();
        for (tensor_name, _) in self.0.tensors() {
            names.push(tensor_name);
        }
        names
    }

    /// Loads `name` onto `device`. The dtype hint is not used by this backend.
    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
        let Self(safetensors) = self;
        safetensors.load(name, device)
    }
}
/// [`TensorLoaderBackend`] backed by a [`BytesSafetensors`] view that borrows its data for `'a`.
struct BytesSafetensorBackend<'a>(BytesSafetensors<'a>);

impl TensorLoaderBackend for BytesSafetensorBackend<'_> {
    /// Gathers the name of every tensor reported by the wrapped safetensors view.
    fn get_names(&self) -> Vec<String> {
        let mut names = Vec::new();
        for (tensor_name, _) in self.0.tensors() {
            names.push(tensor_name);
        }
        names
    }

    /// Loads `name` onto `device`. The dtype hint is not used by this backend.
    fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
        let Self(safetensors) = self;
        safetensors.load(name, device)
    }
}
/// Loads every tensor source in `paths` on its own thread and merges the results into a single
/// [`VarBuilder`].
///
/// One worker thread is spawned per entry of `paths`. After all workers have finished, their
/// per-file tensor maps are merged in `paths` order, so when a tensor name occurs in several
/// files the value from the later file wins.
///
/// # Arguments
/// * `paths` - tensor sources to load; each one gets a dedicated worker thread.
/// * `dtype` - dtype handed to the resulting `VarBuilder` (and forwarded to the loader);
///   `DType::BF16` is used for the `VarBuilder` when this is `None`.
/// * `device` - device the tensors are loaded onto and the `VarBuilder` is created for.
/// * `silent` - forwarded unchanged to `load_tensors_from_path`.
/// * `src` - shared model source, forwarded to every worker.
///
/// # Errors
/// Returns the first error, in `paths` order, reported by a worker.
///
/// # Panics
/// Panics if a worker thread panicked.
pub fn from_mmaped_safetensors<'a>(
    paths: Vec<FileData>,
    dtype: Option<DType>,
    device: &Device,
    silent: bool,
    src: Arc<ModelSource>,
) -> Result<VarBuilderArgs<'a, Box<dyn SimpleBackend>>> {
    // Dtype given to the `VarBuilder` when the caller does not request one.
    const DEFAULT_DTYPE: DType = DType::BF16;

    // The explicit handle type documents what each worker returns; it is long enough to trip
    // clippy's complexity lint, hence the allow.
    #[allow(clippy::type_complexity)]
    let handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = paths
        .into_iter()
        .map(|path| {
            let device = device.clone();
            let src = Arc::clone(&src);
            thread::spawn(move || {
                Common.load_tensors_from_path(&path, &device, dtype, silent, src)
            })
        })
        .collect();

    // `join` blocks until its worker is done, so no polling loop is required. Every worker is
    // joined before any result is inspected, which guarantees that an early error return below
    // never leaves threads running in the background.
    let outcomes: Vec<_> = handles.into_iter().map(JoinHandle::join).collect();

    let mut ws = HashMap::new();
    for outcome in outcomes {
        ws.extend(outcome.expect("tensor loading thread panicked")?);
    }

    Ok(VarBuilder::from_tensors(
        ws,
        dtype.unwrap_or(DEFAULT_DTYPE),
        device,
    ))
}
/// Shared loading routine: implementors inherit `load_tensors_from_path` as a default method.
trait LoadTensors {
    /// Loads every tensor found in `path` onto `device` and returns them keyed by tensor name.
    ///
    /// Only the `safetensors` extension is accepted. Depending on the `FileData` variant the
    /// data is read from a byte range of the DDUF file held by `src`, from an owned byte buffer,
    /// or from a file on disk.
    ///
    /// # Arguments
    /// * `path` - the tensor source to read.
    /// * `device` - device passed to the backend for every tensor load.
    /// * `dtype` - forwarded to the backend's `load_name`.
    /// * `silent` - forwarded to `with_progress` when iterating over the tensor names.
    /// * `src` - model source; must be `ModelSource::Dduf` when `path` is `FileData::Dduf`.
    ///
    /// # Errors
    /// Fails when the extension is not `safetensors`, when `path` is `FileData::Dduf` but `src`
    /// is not a DDUF source, or when creating the backend or loading a tensor fails.
    ///
    /// # Panics
    /// Panics when `path` has no extension, when the extension is not valid UTF-8, or when the
    /// `start..end` range of a `FileData::Dduf` lies outside the DDUF buffer.
    fn load_tensors_from_path(
        &self,
        path: &FileData,
        device: &Device,
        dtype: Option<DType>,
        silent: bool,
        src: Arc<ModelSource>,
    ) -> Result<HashMap<String, Tensor>> {
        let tensors: Box<dyn TensorLoaderBackend> = match path
            .extension()
            .expect("Expected extension")
            .to_str()
            .expect("Expected to convert")
        {
            "safetensors" => match path {
                FileData::Dduf { name: _, start, end } => {
                    // A byte range is only meaningful relative to the DDUF file in `src`.
                    let ModelSource::Dduf { file, name: _ } = &*src else {
                        crate::bail!("expected dduf model source!");
                    };
                    Box::new(BytesSafetensorBackend(BytesSafetensors::new(
                        &file.get_ref()[*start..*end],
                    )?))
                }
                FileData::DdufOwned { name: _, data } => {
                    Box::new(BytesSafetensorBackend(BytesSafetensors::new(data)?))
                }
                FileData::Path(path) => {
                    // SAFETY: NOTE(review) - presumably `MmapedSafetensors::new` is unsafe
                    // because it memory-maps the file, which requires that the file is not
                    // modified while mapped; confirm against its documentation.
                    Box::new(SafetensorBackend(unsafe {
                        crate::core::safetensors::MmapedSafetensors::new(path)?
                    }))
                }
            },
            other => crate::bail!(
                "Unexpected extension `{other}`, this should have been handled by `get_model_paths`."
            ),
        };

        let mut loaded_tensors = HashMap::new();
        for name in tensors.get_names().into_iter().with_progress(silent) {
            let tensor = tensors.load_name(&name, device, dtype)?;
            loaded_tensors.insert(name, tensor);
        }
        Ok(loaded_tensors)
    }
}
/// Unit-struct [`LoadTensors`] implementor; its impl is empty, so it uses the trait's default
/// `load_tensors_from_path` as-is.
struct Common;
impl LoadTensors for Common {}