pub trait CustomOp1 {
    // Required methods
    fn name(&self) -> &'static str;
    fn cpu_fwd(
        &self,
        storage: &CpuStorage,
        layout: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    // Provided methods
    fn cuda_fwd(
        &self,
        _storage: &CudaStorage,
        _layout: &Layout,
    ) -> Result<(CudaStorage, Shape)> { ... }
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> { ... }
    fn bwd(
        &self,
        _arg: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<Option<Tensor>> { ... }
}
Unary ops that can be defined in user-land.
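To illustrate how a user-land op fits the required/provided split, here is a sketch that mirrors the shape of CustomOp1 with a simplified trait. Everything in it (UnaryOpSketch, CpuBuf, GpuBuf, Dims, Square) is a hypothetical stand-in, not candle's API; the real trait works on CpuStorage, Layout and Shape. For illustration it assumes that a provided GPU method which is not overridden reports an error naming the op.

```rust
// Hypothetical stand-ins for candle's storage and shape types; not the real API.
#[derive(Debug, Clone, PartialEq)]
pub struct CpuBuf(pub Vec<f32>);
#[derive(Debug)]
pub struct GpuBuf; // placeholder: holds no real device memory
pub type Dims = Vec<usize>;

/// A simplified trait mirroring the required/provided split of CustomOp1.
pub trait UnaryOpSketch {
    // Required: every op names itself and implements a CPU forward pass.
    fn name(&self) -> &'static str;
    fn cpu_fwd(&self, input: &CpuBuf, dims: &[usize]) -> Result<(CpuBuf, Dims), String>;

    // Provided: ops that do not override this fall back to an error naming the op
    // (assumed behavior for this sketch).
    fn gpu_fwd(&self, _input: &GpuBuf, _dims: &[usize]) -> Result<(GpuBuf, Dims), String> {
        Err(format!("no gpu implementation for {}", self.name()))
    }
}

/// A user-land op that squares each element (assumes contiguous input).
pub struct Square;

impl UnaryOpSketch for Square {
    fn name(&self) -> &'static str {
        "square"
    }

    fn cpu_fwd(&self, input: &CpuBuf, dims: &[usize]) -> Result<(CpuBuf, Dims), String> {
        let out = input.0.iter().map(|v| v * v).collect();
        // Element-wise op: the output shape equals the input shape.
        Ok((CpuBuf(out), dims.to_vec()))
    }
}
```

Only name and cpu_fwd had to be written; the GPU path comes for free as a descriptive error until the op chooses to override it.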
Required Methods
fn name(&self) -> &'static str
fn cpu_fwd(
    &self,
    storage: &CpuStorage,
    layout: &Layout,
) -> Result<(CpuStorage, Shape)>
The forward pass, as run on a CPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.
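A minimal sketch of what using the layout to access the storage involves, assuming a layout described by dims, per-dimension strides (in elements) and a start offset. Under that assumption the element at multi-index (i_0, ..., i_{n-1}) sits at start_offset + Σ i_k · stride_k. StridedLayout and unary_strided are hypothetical names, not candle's Layout; the function visits the logical elements in row-major order and writes a fresh contiguous output, which is one simple way a cpu_fwd could produce a new storage together with its shape.

```rust
/// Hypothetical stand-in for a strided layout; not candle's `Layout`.
pub struct StridedLayout {
    pub dims: Vec<usize>,
    pub strides: Vec<usize>, // in elements, one per dim
    pub start_offset: usize,
}

/// Applies `f` to every logical element in row-major order and returns
/// a new contiguous buffer plus its dims.
pub fn unary_strided(
    storage: &[f32],
    layout: &StridedLayout,
    f: impl Fn(f32) -> f32,
) -> (Vec<f32>, Vec<usize>) {
    let n: usize = layout.dims.iter().product();
    let mut out = Vec::with_capacity(n);
    let mut idx = vec![0usize; layout.dims.len()]; // current multi-index
    for _ in 0..n {
        // offset = start_offset + sum over k of idx[k] * strides[k]
        let offset = layout.start_offset
            + idx
                .iter()
                .zip(&layout.strides)
                .map(|(i, s)| i * s)
                .sum::<usize>();
        out.push(f(storage[offset]));
        // Advance the multi-index like an odometer, last dim fastest.
        for d in (0..idx.len()).rev() {
            idx[d] += 1;
            if idx[d] < layout.dims[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    (out, layout.dims.clone())
}
```

For example, a transposed view of a 2×3 buffer (dims [3, 2], strides [1, 3]) is read column by column from the underlying storage, so indexing the raw buffer directly would give the elements in the wrong order.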
Provided Methods
fn cuda_fwd(
    &self,
    _storage: &CudaStorage,
    _layout: &Layout,
) -> Result<(CudaStorage, Shape)>
The forward pass, as run on a CUDA GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.
fn metal_fwd(
    &self,
    _storage: &MetalStorage,
    _layout: &Layout,
) -> Result<(MetalStorage, Shape)>
The forward pass, as run on a Metal GPU device. Note that the storage can use arbitrary strides, offsets, etc., so the associated layout should be used to access it.
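On either GPU backend, a kernel written for flat buffers can only index device memory directly when the layout is dense. The sketch below is a hypothetical helper over plain slices (not a candle function): it returns the element range covered by a contiguous row-major layout, or None otherwise. A cuda_fwd or metal_fwd implementation could use a check like this to pick a flat kernel, fall back to a strided one, or return an error.

```rust
/// Returns Some((start, end)) if dims/strides/start_offset describe a dense
/// row-major block of storage, i.e. a kernel could treat it as a flat slice.
pub fn contiguous_range(
    dims: &[usize],
    strides: &[usize],
    start_offset: usize,
) -> Option<(usize, usize)> {
    // Stride a dense row-major layout would have for the current dim.
    let mut expected = 1usize;
    for (&d, &s) in dims.iter().zip(strides).rev() {
        // Size-1 dims never move the offset, so their stride does not matter.
        if d != 1 && s != expected {
            return None;
        }
        expected *= d;
    }
    // After the loop, `expected` is the total element count.
    Some((start_offset, start_offset + expected))
}
```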
fn bwd(
    &self,
    _arg: &Tensor,
    _res: &Tensor,
    _grad_res: &Tensor,
) -> Result<Option<Tensor>>
This function takes as arguments the argument arg used in the forward pass, the result res produced by the forward operation, and the gradient of the result grad_res. It should return the gradient with respect to the argument.
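For an element-wise op y = f(x), the chain rule gives grad_arg = grad_res · f'(x), element by element. The sketch below works on plain f32 slices rather than Tensors and shows two cases: square needs arg, while exp can reuse res because f'(x) = e^x = y, which is one reason the forward result is worth passing in. The function names are hypothetical, and the Option return only mirrors the shape of bwd's signature.

```rust
/// Backward pass of an element-wise square: d(x^2)/dx = 2x,
/// so grad_arg = 2 * arg * grad_res.
pub fn square_bwd(arg: &[f32], _res: &[f32], grad_res: &[f32]) -> Option<Vec<f32>> {
    Some(arg.iter().zip(grad_res).map(|(&x, &g)| 2.0 * x * g).collect())
}

/// Backward pass of an element-wise exp: d(e^x)/dx = e^x, which is the forward
/// result itself, so grad_arg = res * grad_res and `arg` is not needed.
pub fn exp_bwd(_arg: &[f32], res: &[f32], grad_res: &[f32]) -> Option<Vec<f32>> {
    Some(res.iter().zip(grad_res).map(|(&y, &g)| y * g).collect())
}
```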