diffusion_rs_common::core

Trait CustomOp2

source
pub trait CustomOp2 {
    // Required methods
    fn name(&self) -> &'static str;
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    // Provided methods
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> { ... }
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> { ... }
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> { ... }
}

Required Methods§

source

fn name(&self) -> &'static str

source

fn cpu_fwd( &self, s1: &CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout, ) -> Result<(CpuStorage, Shape)>

The forward pass, as run on a CPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

Provided Methods§

source

fn cuda_fwd( &self, _: &CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout, ) -> Result<(CudaStorage, Shape)>

The forward pass, as run on a CUDA GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

source

fn metal_fwd( &self, _: &MetalStorage, _: &Layout, _: &MetalStorage, _: &Layout, ) -> Result<(MetalStorage, Shape)>

The forward pass, as run on a Metal GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.

source

fn bwd( &self, _arg1: &Tensor, _arg2: &Tensor, _res: &Tensor, _grad_res: &Tensor, ) -> Result<(Option<Tensor>, Option<Tensor>)>

Implementors§