ATen is a simple tensor library that exposes the Tensor operations in Torch and PyTorch directly in C++17. This library provides a generated wrapper around the ATen API that makes these functions available in Caffe2 as an operator. It also makes it accessible using the ToffeeIR.
First identify a function in ATen you want to call in Functions.h, Tensor.h, or Type.h.
We will call the pow operator:

    static inline Tensor pow(const Tensor & self, Scalar exponent);
Now create a Caffe2 operator to call this op. The name of the operator is always "ATen", and there is always a string attribute operator that defines which ATen function to call:

    import numpy as np
    from caffe2.python import core, workspace

    # create the Caffe2 Op:
    op = core.CreateOperator(
        "ATen",
        ["MyInput"],
        ["MyOutput"],
        operator="pow", exponent=2.0)

Each Tensor input becomes a Caffe2 input blob, and each output becomes a Caffe2 output blob. Non-tensor inputs such as Scalar exponent become Caffe2 arg attributes. In the case of Scalar, the attribute can be either an integer or a floating-point number.
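
The mapping rule above can be sketched in plain Python. The helper below, split_aten_call, is hypothetical and is not part of Caffe2 or ATen; it only illustrates how the parameters of an ATen signature might be divided into input blobs and arg attributes, assuming the signature is written out by hand as (name, type) pairs.

```python
# Hypothetical helper (not part of Caffe2 or ATen): splits the parameters of an
# ATen function into Caffe2 input blob names and Caffe2 arg attributes.
def split_aten_call(operator, signature, values):
    """signature: list of (name, aten_type); values: dict of name -> blob name or literal."""
    inputs = []
    attrs = {"operator": operator}  # the string attribute that selects the ATen function
    for name, aten_type in signature:
        value = values[name]
        if aten_type == "Tensor":
            # Tensor parameters are passed as input blobs, referenced by name.
            inputs.append(value)
        elif aten_type == "Scalar":
            # Scalar parameters may be an integer or a floating-point number.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise TypeError(f"{name} must be an int or float, got {type(value).__name__}")
            attrs[name] = value
        else:
            # Any other non-tensor parameter is forwarded as an attribute unchanged.
            attrs[name] = value
    return inputs, attrs


# Hand-written description of: Tensor pow(const Tensor & self, Scalar exponent)
pow_signature = [("self", "Tensor"), ("exponent", "Scalar")]
inputs, attrs = split_aten_call("pow", pow_signature, {"self": "MyInput", "exponent": 2.0})
print(inputs)  # ['MyInput']
print(attrs)   # {'operator': 'pow', 'exponent': 2.0}
```

The two results line up with the arguments of the call shown earlier: inputs corresponds to the input blob list, and attrs to the keyword arguments given to core.CreateOperator.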
The op can now be run like any other Caffe2 operator:

    workspace.FeedBlob("MyInput", np.random.randn(2, 3).astype(np.float32))
    workspace.RunOperatorOnce(op)
    print(workspace.FetchBlob("MyOutput"))

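
To check the result, it can help to compute a reference value with NumPy alone. The sketch below assumes that pow with exponent=2.0 is an element-wise square, and it does not call Caffe2; the comparison against the fetched blob is shown only as a comment.

```python
import numpy as np

# Same shape and dtype as the array fed into "MyInput".
x = np.random.randn(2, 3).astype(np.float32)

# Reference result for operator="pow" with exponent=2.0: element-wise square.
expected = np.power(x, 2.0)

print(expected.shape)  # (2, 3)

# With Caffe2 available, feed this same x as "MyInput", run the op, and compare:
#   np.testing.assert_allclose(workspace.FetchBlob("MyOutput"), expected, rtol=1e-5)
```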
For methods, the first input is always the this Tensor in C++. To call methods of ATen's Type objects, you provide an additional string attribute that determines the type:

    # create a 2x4 tensor filled with floating point ones
    op = core.CreateOperator(
        "ATen",
        [],
        ["MyOutput"],
        operator="ones", type="Float", size=[2, 4])

Generally ATen operators are polymorphic across input types, and work on both the CPU and CUDA.
The ATen operator can also be used to define symbolic definitions for PyTorch when an operator is being exported to ONNX. In this case, the definition of the operator looks the same but is defined using PyTorch's ONNX API:

    import torch

    class Add(torch.autograd.Function):

        @staticmethod
        def symbolic(g, a, b):
            # emit an ATen op node for "add" in the exported graph
            return g.at("add", a, b)

        @staticmethod
        def forward(ctx, a, b):
            return a + b