diffusion_rs_common/
safetensors.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
use crate::core::safetensors::Load;
use crate::core::{Device, Error, Result, Tensor};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;

/// A safetensors archive parsed from an in-memory byte buffer.
///
/// The parsed view borrows the caller's bytes for the `'data` lifetime, so
/// the buffer must outlive this value.
pub struct BytesSafetensors<'data> {
    /// Parsed safetensors view over the borrowed buffer.
    safetensors: SafeTensors<'data>,
}

impl<'a> BytesSafetensors<'a> {
    pub fn new(bytes: &'a [u8]) -> Result<BytesSafetensors<'a>> {
        let st = safetensors::SafeTensors::deserialize(bytes).map_err(Error::from)?;
        Ok(Self { safetensors: st })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.tensor(name)?)
    }
}