Add an aten_op to contrib.
Summary:
This operator allows the use of Torch's underlying TH libraries (TH, THC, THNN, and THCUNN)
through the ATen tensor library. Use of the operator is described in the README.
The operator itself is generated from ATen's Declarations.yaml file which describes its public API.
Closes https://github.com/caffe2/caffe2/pull/1235
Reviewed By: dzhulgakov
Differential Revision: D5876944
Pulled By: zdevito
fbshipit-source-id: b558e8563a5e82a0e6278705a4a359bd7df4e70a
diff --git a/.gitmodules b/.gitmodules
index e7d357f..aa0a896 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -46,3 +46,6 @@
[submodule "third_party/NNPACK_deps/psimd"]
path = third_party/NNPACK_deps/psimd
url = https://github.com/Maratyszcza/psimd.git
+[submodule "third_party/aten"]
+ path = third_party/aten
+ url = https://github.com/zdevito/aten
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 06c3eb3..9b6f2da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@
option(BUILD_PYTHON "Build Python binaries" ON)
option(BUILD_SHARED_LIBS "Build libcaffe2.so" ON)
option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" ON)
+option(USE_ATEN "Use ATen" OFF)
option(USE_CUDA "Use Cuda" ON)
option(USE_FFMPEG "Use ffmpeg" OFF)
option(USE_GFLAGS "Use GFLAGS" ON)
diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt
index a6e028c..bcaccba 100644
--- a/caffe2/contrib/CMakeLists.txt
+++ b/caffe2/contrib/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(aten)
add_subdirectory(gloo)
add_subdirectory(nccl)
add_subdirectory(nnpack)
diff --git a/caffe2/contrib/aten/CMakeLists.txt b/caffe2/contrib/aten/CMakeLists.txt
new file mode 100644
index 0000000..d400587
--- /dev/null
+++ b/caffe2/contrib/aten/CMakeLists.txt
@@ -0,0 +1,26 @@
+if(USE_ATEN)
+ if(NOT USE_CUDA)
+ set(NO_CUDA ON)
+ endif()
+ set(TORCH_CUDA_ARCH_LIST "3.5 5.2 6.0 6.1+PTX")
+ set(TORCH_NVCC_FLAGS "-Xfatbin -compress-all")
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+ set(TH_LINK_STYLE STATIC)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden")
+ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/aten aten EXCLUDE_FROM_ALL)
+
+ add_custom_command(OUTPUT aten_op.h
+ COMMAND
+ python ${CMAKE_CURRENT_SOURCE_DIR}/gen_op.py ${PROJECT_SOURCE_DIR}
+ DEPENDS
+ ${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/ATen/Declarations.yaml
+ ${CMAKE_CURRENT_SOURCE_DIR}/gen_op.py
+ ${CMAKE_CURRENT_SOURCE_DIR}/aten_op_template.h)
+
+ add_custom_target(aten_build
+ DEPENDS aten_op.h)
+
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/aten_op.cc" PARENT_SCOPE)
+ set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} "${CMAKE_CURRENT_SOURCE_DIR}/aten_op_cuda.cc" PARENT_SCOPE)
+endif()
diff --git a/caffe2/contrib/aten/README.md b/caffe2/contrib/aten/README.md
new file mode 100644
index 0000000..d3046aa
--- /dev/null
+++ b/caffe2/contrib/aten/README.md
@@ -0,0 +1,80 @@
+# An ATen operator for Caffe2
+
+[ATen](https://github.com/zdevito/aten) is a simple tensor library thats exposes the Tensor operations in Torch
+and PyTorch directly in C++11. This library provides a generated wrapper around the ATen API
+that makes these functions available in Caffe2 as an operator. It also makes it accessible using the
+ToffeeIR.
+
+
+### Example Usage in Caffe2
+
+First identify a function in ATen you want to call in [Functions.h](https://github.com/zdevito/ATen/blob/master/doc/Functions.h),
+[Tensor.h](https://github.com/zdevito/ATen/blob/master/doc/Tensor.h), or [Type.h](https://github.com/zdevito/ATen/blob/master/doc/Type.h).
+
+We will call the `pow` operator:
+
+```
+static inline Tensor pow(const Tensor & self, Scalar exponent);
+```
+
+Now create a Caffe2 operator to call this op. The name of the operator is always `"ATen"`,
+and there is always a string attribute `operator` that defines which ATen function to call:
+
+
+```
+import numpy as np
+from caffe2.python import core, workspace
+
+
+# create the Caffe2 Op:
+op = core.CreateOperator(
+ "ATen",
+ ["MyInput"],
+ ["MyOutput"],
+ operator="pow", exponent=2.0)
+
+```
+
+Each `Tensor` input becomes an Caffe2 input Blob, and each output becomes a Caffe2 output blob.
+Non-tensor inputs such as `Scalar exponent` become Caffe2 `arg` attributes.
+In the case of `Scalar` the attributes can be either an integers or floating point numbers.
+
+The op can now be run like any other Caffe2 operator:
+
+```
+workspace.FeedBlob("MyInput",np.random.randn(2,3).astype(np.float32))
+workspace.RunOperatorOnce(op)
+print(workspace.FetchBlob("MyOutput")
+```
+
+For methods, the first input is always the `this` Tensor in C++.
+To call methods of ATen's `Type` objects, you provide an additional string attribute
+that determines the type:
+
+```
+# create a 2x4 tensor filled with floating point ones
+op = core.CreateOperator(
+ "ATen",
+ [],
+ ["MyOutput"],
+ operator="ones", type="Float", size={2,4})
+```
+
+Generally ATen operators are polymorphic across input types, and work on both the CPU and CUDA.
+
+### Example Usage via PyTorch Symbolic
+
+The ATen operator can also be used to define `symbolic` definitions for PyTorch when an operator is being exported
+to ONNX. In this case, the definition of the operator looks the same but is defined using PyTorch's ONNX API:
+
+```
+class Add(torch.autograd.Function):
+
+ @staticmethod
+ def symbolic(g, a, b):
+ return g.op("ATen", a, b, operator_s = "add")
+
+ @staticmethod
+ def forward(ctx, a, b):
+ return a + b
+```
diff --git a/caffe2/contrib/aten/aten_op.cc b/caffe2/contrib/aten/aten_op.cc
new file mode 100644
index 0000000..47d06b7
--- /dev/null
+++ b/caffe2/contrib/aten/aten_op.cc
@@ -0,0 +1,21 @@
+#include "aten_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(ATen, ATenOp<CPUContext>);
+template<>
+at::Backend ATenOp<CPUContext>::backend() const {
+ return at::kCPU;
+}
+
+OPERATOR_SCHEMA(ATen);
+CAFFE_KNOWN_TYPE(at::Half);
+
+namespace math {
+template<>
+void Set<at::Half,CPUContext>(TIndex N, at::Half h, at::Half* v, CPUContext * c) {
+ Set(0, h.x, (uint16_t*) v, c);
+}
+}
+
+}
diff --git a/caffe2/contrib/aten/aten_op_cuda.cc b/caffe2/contrib/aten/aten_op_cuda.cc
new file mode 100644
index 0000000..4d9fcfc
--- /dev/null
+++ b/caffe2/contrib/aten/aten_op_cuda.cc
@@ -0,0 +1,19 @@
+#include "aten_op.h"
+#include "caffe2/core/context_gpu.h"
+
+namespace caffe2 {
+
+REGISTER_CUDA_OPERATOR(ATen, ATenOp<CUDAContext>);
+template<>
+at::Backend ATenOp<CUDAContext>::backend() const {
+ return at::kCUDA;
+}
+
+namespace math {
+template<>
+void Set<at::Half,CUDAContext>(TIndex N, at::Half h, at::Half* v, CUDAContext * c) {
+ Set(0, h.x, (uint16_t*) v, c);
+}
+}
+
+}
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
new file mode 100644
index 0000000..3030d70
--- /dev/null
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -0,0 +1,186 @@
+#pragma once
+#include <unordered_map>
+#include <string>
+#include <ATen/ATen.h>
+#include <caffe2/core/context.h>
+#include <caffe2/core/operator.h>
+#include <google/protobuf/text_format.h>
+#include <iostream>
+
+// a map from descriptor strings (see [DESCRIPTORS])
+// to the key in the switch statement that implements them
+static std::unordered_map<std::string, int> op_to_key = {
+ ${mappings}
+};
+
+namespace caffe2 {
+
+using at::Half; // for AT_FORALL_SCALAR_TYPES
+
+std::function<void(void*)> deleterFor(at::Tensor t) {
+ // return a closure that holds a handle to t until it is called
+ // to keep the aten memory alive
+ return [t](void * ptr) mutable {
+ t.reset();
+ };
+}
+
+template <class Context>
+class ATenOp : public Operator<Context> {
+ public:
+ ATenOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<Context>(operator_def, ws) {
+ VLOG(2) << "ATen OpDef: " << ProtoDebugString(operator_def) << "\n";
+ switch(findImplementation(operator_def)) {
+ ${implementations}
+ default:
+ CAFFE_THROW("Unexpected key value for aten operator");
+ }
+ }
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+ bool RunOnDevice() override {
+ return run_op();
+ }
+private:
+ // actual operator implementation is initialized in ctor.
+ std::function<bool()> run_op;
+ at::Backend backend() const;
+
+ TypeMeta typeMetaFor(const at::Tensor & t) {
+ return typeMetaFor(t.type().scalarType());
+ }
+ TypeMeta typeMetaFor(at::ScalarType st) {
+ #define DEFINE_CASE(ctype,aten_name,_) \
+ case at::k##aten_name: \
+ return TypeMeta::Make<ctype>();
+ switch(st) {
+ AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+ default:
+ CAFFE_THROW("Unknown ATen Type");
+ }
+ #undef DEFINE_CASE
+ }
+
+ at::Type & typeFor(const Tensor<Context> & ten) {
+ return at::getType(backend(), atScalarTypeFor(ten.meta()));
+ }
+ at::Tensor tensorWrapping(Tensor<Context> & ten) {
+ return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.dims());
+ }
+ at::ScalarType atScalarTypeFor(const TypeMeta & meta) {
+ #define DEFINE_IF(ctype,aten_name,_) \
+ if(meta.Match<ctype>()) { \
+ return at::k##aten_name; \
+ }
+ AT_FORALL_SCALAR_TYPES(DEFINE_IF)
+ #undef DEFINE_IF
+ CAFFE_THROW("Unknown type meta"); // TODO: improve error message...
+ }
+ void assignTo(Tensor<Context> * dst, const at::Tensor & src_) {
+ at::Tensor src = src_.contiguous();
+ auto at_sizes = src.sizes();
+ std::vector<int64_t> dims(at_sizes.begin(),at_sizes.end());
+ dst->Resize(dims);
+ dst->ShareExternalPointer(src.data_ptr(), typeMetaFor(src), 0, deleterFor(src));
+ }
+
+ // the AT_FORALL_SCALAR_TYPES macro just gives a 'i' or 'd' argument
+ // for each type to specify if it is stored as a integer or a double.
+ // We need this workaround here to extract the value in the scalar losslessly
+ // because in some cases like 'sum' Torch promotes float to double
+ // and will complain if we downcast it with toFloat, causing it
+ // to lose precision
+ double extract_d(const at::Scalar & s) {
+ return s.toDouble();
+ }
+ int64_t extract_i(const at::Scalar & s) {
+ return s.toLong();
+ }
+
+ void assignTo(Tensor<Context> * dst, at::Type & inferred_type, at::Scalar scalar) {
+ switch(inferred_type.scalarType()) {
+ #define DEFINE_CASE(ctype,aten_name,native) \
+ case at::k##aten_name: { \
+ auto value = extract_##native(scalar); \
+ assignToValue<ctype>(dst, at::convert<ctype,decltype(value)>(value)); \
+ } break;
+ AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+ #undef DEFINE_CASE
+ default:
+ CAFFE_THROW("Unknown ATen Type");
+ }
+ }
+ template<typename T>
+ void assignToValue(Tensor<Context> * dst, T v) {
+ dst->Resize(std::vector<TIndex>());
+ math::Set(1, v, dst->template mutable_data<T>(), &context_);
+ }
+ int findImplementation(const OperatorDef& operator_def) {
+ CAFFE_ENFORCE(HasArgument("operator"));
+ std::string op = OperatorBase::GetSingleArgument<std::string>("operator", "");
+ // construct descriptor string ([DESCRIPTORS]) given the attributes
+ // and inputs of this operator_def, and look up the implementation key
+ // for this variant
+ std::stringstream descriptor;
+ descriptor << op << "-" << InputSize();
+ std::vector<std::string> attrs;
+ for(size_t i = 0; i < operator_def.arg_size(); i++) {
+ auto & attr = operator_def.arg(i);
+ if(attr.name() == "operator" || attr.name() == "type" )
+ continue;
+ attrs.push_back(attr.name());
+ }
+ std::sort(attrs.begin(), attrs.end());
+ for(auto & a : attrs)
+ descriptor << "-" << a;
+ std::string descriptor_s = descriptor.str();
+ if(op_to_key.count(descriptor_s) == 0) {
+ std::stringstream ss;
+ ss << "Attempting to run unknown ATen operator configuration: "
+ << descriptor_s;
+ CAFFE_THROW(ss.str());
+ }
+ return op_to_key.at(descriptor_s);
+ }
+ at::Scalar readScalarAttribute(const std::string & name) {
+ if(OperatorBase::HasSingleArgumentOfType<int64_t>(name)) {
+ return OperatorBase::GetSingleArgument<int64_t>(name, 0);
+ } else {
+ CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<float>(name));
+ return OperatorBase::GetSingleArgument<float>(name, 0);
+ }
+ }
+ template<typename T>
+ T readAttribute(const std::string & name) {
+ CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<int64_t>(name));
+ return OperatorBase::GetSingleArgument<T>(name, 0);
+ }
+ std::vector<int64_t> readIntList(const std::string & name) {
+ CAFFE_ENFORCE(OperatorBase::HasArgument(name));
+ return OperatorBase::GetRepeatedArgument<int64_t>(name, {});
+ }
+ at::ScalarType stringToScalarType(const std::string & name) {
+ #define DEFINE_IF(type,aten) \
+ if(#type == name) \
+ return at::k##aten;
+ DEFINE_IF(float16, Half)
+ DEFINE_IF(float, Float)
+ DEFINE_IF(double, Double)
+ DEFINE_IF(uint8, Byte)
+ DEFINE_IF(int8, Char)
+ DEFINE_IF(int16, Short)
+ DEFINE_IF(int32, Int)
+ DEFINE_IF(int64, Long)
+ CAFFE_THROW("unsupported type annotation: ", name);
+ }
+ at::Type & stringToType(const std::string & name) {
+ return at::getType(backend(), stringToScalarType(name));
+ }
+ at::Type * readTypeAttribute(const std::string & name) {
+ CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<std::string>(name));
+ return &stringToType(OperatorBase::GetSingleArgument<std::string>(name, ""));
+ }
+};
+
+}
diff --git a/caffe2/contrib/aten/aten_test.py b/caffe2/contrib/aten/aten_test.py
new file mode 100644
index 0000000..5646cdd
--- /dev/null
+++ b/caffe2/contrib/aten/aten_test.py
@@ -0,0 +1,83 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core
+from hypothesis import given
+
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+
+
+class TestATen(hu.HypothesisTestCase):
+
+ @given(inputs=hu.tensors(n=2), **hu.gcs)
+ def test_add(self, inputs, gc, dc):
+ op = core.CreateOperator(
+ "ATen",
+ ["X", "Y"],
+ ["Z"],
+ operator="add")
+
+ def ref(X, Y):
+ return [X + Y]
+ self.assertReferenceChecks(gc, op, inputs, ref)
+
+ @given(inputs=hu.tensors(n=1), **hu.gcs)
+ def test_pow(self, inputs, gc, dc):
+ op = core.CreateOperator(
+ "ATen",
+ ["S"],
+ ["Z"],
+ operator="pow", exponent=2.0)
+
+ def ref(X):
+ return [np.square(X)]
+
+ self.assertReferenceChecks(gc, op, inputs, ref)
+
+ @given(x=st.integers(min_value=2, max_value=8), **hu.gcs)
+ def test_sort(self, x, gc, dc):
+ inputs = [np.random.permutation(x)]
+ op = core.CreateOperator(
+ "ATen",
+ ["S"],
+ ["Z", "I"],
+ operator="sort")
+
+ def ref(X):
+ return [np.sort(X), np.argsort(X)]
+ self.assertReferenceChecks(gc, op, inputs, ref)
+
+ @given(inputs=hu.tensors(n=1), **hu.gcs)
+ def test_sum(self, inputs, gc, dc):
+ op = core.CreateOperator(
+ "ATen",
+ ["S"],
+ ["Z"],
+ operator="sum")
+
+ def ref(X):
+ return [np.sum(X)]
+
+ self.assertReferenceChecks(gc, op, inputs, ref)
+
+ @given(**hu.gcs)
+ def test_ones(self, gc, dc):
+ op = core.CreateOperator(
+ "ATen",
+ [],
+ ["Z"],
+ operator="ones", type="float", size={2, 4})
+
+ def ref():
+ return [np.ones([2, 4])]
+
+ self.assertReferenceChecks(gc, op, [], ref)
+
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py
new file mode 100644
index 0000000..ba7a8a8
--- /dev/null
+++ b/caffe2/contrib/aten/gen_op.py
@@ -0,0 +1,201 @@
+import sys
+import yaml
+project_root = sys.argv[1]
+sys.path.append(project_root + "/third_party/aten/src/ATen")
+from code_template import CodeTemplate as CT
+
+try:
+ # use faster C loader if available
+ from yaml import CLoader as Loader
+except ImportError:
+ from yaml import Loader
+
+OP_TEMPLATE = CT.from_file(project_root+'/caffe2/contrib/aten/aten_op_template.h')
+
+
+def write(filename, s):
+ with open(filename, "w") as f:
+ f.write(s)
+
+
+def read(filename):
+ with open(filename, "r") as f:
+ return f.read()
+
+
+decls = yaml.load(read('aten/src/ATen/ATen/Declarations.yaml'), Loader=Loader)
+
+top_env = {
+ 'mappings': [],
+ 'implementations': [],
+}
+
+
+def is_tensor_type(t):
+ return "Tensor" in t
+
+
+def value_is_tensor_type(v):
+ return is_tensor_type(v['dynamic_type'])
+
+# for each aten type, how do we handle a return value of that type?
+RETURN_MAP = {
+ 'Tensor': 'assignTo(Output(${offset}),${output});',
+ 'Scalar': 'assignTo(Output(${offset}),*inferred_type, ${output});',
+ 'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
+ 'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
+}
+
+# for each non-Tensor aten argument, how to we read it from caffe2's
+# attribute list. Most of these call runtime functions defined in the
+# template class.
+ARGUMENT_MAP = {
+ 'Scalar': 'at::Scalar ${arg} = readScalarAttribute("${arg}");',
+ 'bool': 'bool ${arg} = readAttribute<int64_t>("${arg}");',
+ 'int': 'int ${arg} = readAttribute<int64_t>("${arg}");',
+ 'int64_t': 'int64_t ${arg} = readAttribute<int64_t>("${arg}");',
+ 'IntList': 'auto ${arg} = readIntList("${arg}");',
+}
+
+
+# filter the list of declarations removing things we cannot support
+def supports(o):
+
+ # skip all in-place operators for now since aten cannot Resize
+ # caffe2 memory inside an operator
+ if o['inplace']:
+ return False
+
+ # _out variants also work in-place on arguments taken as destinations
+ # we also cannot handle these because aten cannot resize caffe2 Tensors
+ if "_out" in o['name']:
+ return False
+
+ # skip return types we cannot handle
+ for ret in o['returns']:
+ if not value_is_tensor_type(ret) and ret['type'] not in RETURN_MAP:
+ print("Skipping {} Because of Ret: {} ({})".format(o['name'], ret['type'], ret['dynamic_type']))
+ return False
+
+ # skip arguments we cannot handle
+ for arg in o['arguments']:
+ if not value_is_tensor_type(arg) and arg['type'] not in ARGUMENT_MAP:
+ print("Skipping {} Because of Arg: {} ({}) ".format(o['name'], arg['type'], arg['dynamic_type']))
+ return False
+ return True
+
+
+filtered = [o for o in decls if supports(o)]
+
+# template for each potential operator.
+# each operator has an integer 'key' associated with it, and
+# a lambda that defines the operator
+# non-tensor attributes are created in ${initialization}
+# and then saved as arguments to the lambda
+# Inputs/Outputs are read inside the lambda
+OPTION_TEMPLATE = CT("""\
+case ${key}: { // ${name}
+ ${initialization}
+ run_op = [=] {
+ ${statements}
+ auto the_result = ${invocation};
+ ${assignments}
+ return true;
+ };
+} break;
+""")
+
+
+def get_output(o, i):
+ if len(o['returns']) == 1:
+ return 'the_result'
+ else:
+ return 'std::get<{}>(the_result)'.format(i)
+
+
+def attribute_names(o):
+ return sorted([a['name'] for a in o['arguments'] if not value_is_tensor_type(a)])
+
+
+def self_as_first_argument(arguments):
+ return ([a for a in arguments if a['name'] == 'self'] +
+ [a for a in arguments if a['name'] != 'self'])
+
+seen = set()
+key = 0
+for o in filtered:
+ # [DESCRIPTORS]
+ # each option is associated with a descriptor string that is used
+ # to figure out which version of an op is being used:
+ # The format is:
+ # opname-num_inputs-attribute_1-attribute2
+ # Example:
+ # lerp-2-weight
+ # the operator lerp takes 2 arguments and has the attribute weight
+ attr_names = attribute_names(o)
+ num_inputs = len(o['arguments']) - len(attr_names)
+ descriptor = '-'.join([o['name'], str(num_inputs)] + attr_names)
+
+ if descriptor in seen:
+ continue
+ seen.add(descriptor)
+
+ # map from descriptor string to the integer key in the switch statements
+ # that initializes the operators
+ top_env['mappings'].append('{{ "{}", {} }},'.format(descriptor, key))
+ env = {
+ 'name': o['name'],
+ 'statements': [],
+ 'arguments': [],
+ 'assignments': [],
+ 'initialization': [],
+ 'key': str(key),
+ }
+ defined_inferred_type = False
+
+ if 'Tensor' in o['method_of']:
+ # make sure 'self' is the first argument. currently Declarations.yaml
+ # does not always do this. Instead it keeps the argument list the same order
+ # as the Type method.
+ o['arguments'] = self_as_first_argument(o['arguments'])
+ elif 'namespace' not in o['method_of']:
+ # methods on type like 'ones' or 'zeros' always take a
+ # string attribute that is translated into the at::Type object
+ # e.g. "Float" is at::kFloat
+ assert('Type' in o['method_of'])
+ defined_inferred_type = True
+ env['initialization'].append('auto inferred_type = readTypeAttribute("type");')
+
+ i = 0
+ for arg in o['arguments']:
+ env['arguments'].append(arg['name'])
+ if value_is_tensor_type(arg):
+ # load tensor inputs from Caffe2
+ env['statements'].append("auto {}_ = Input({});".format(arg['name'], i))
+ i += 1
+ env['statements'].append(CT(
+ "auto ${name} = tensorWrapping(${name}_);").substitute(arg))
+ if arg['dynamic_type'] == 'Tensor' and not defined_inferred_type:
+ # first tensor input is used to define the output type.
+ defined_inferred_type = True
+ env['statements'].append('auto inferred_type = &({}.type());'.format(arg['name']))
+ else:
+ init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
+ env['initialization'].append(init)
+
+ for i, r in enumerate(o['returns']):
+ t = RETURN_MAP[r['type'] if not value_is_tensor_type(r) else 'Tensor']
+ assignment = CT(t).substitute(env, offset=i, output=get_output(o, i))
+ env['assignments'].append(assignment)
+
+ if 'Tensor' in o['method_of']:
+ env['invocation'] = "self.{}({})".format(o['name'], ', '.join(env['arguments'][1:]))
+ elif 'namespace' in o['method_of']:
+ env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
+ else:
+ assert('Type' in o['method_of'])
+ env['invocation'] = CT('inferred_type->${name}(${arguments})').substitute(env)
+
+ top_env['implementations'].append(OPTION_TEMPLATE.substitute(env))
+ key += 1
+write("aten_op.h", OP_TEMPLATE.substitute(top_env))
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 05300a9..e5cefb4 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -407,3 +407,11 @@
set(USE_METAL OFF)
endif()
endif()
+
+if (USE_ATEN)
+ list(APPEND Caffe2_EXTERNAL_DEPENDENCIES aten_build)
+ list(APPEND Caffe2_DEPENDENCY_LIBS ATen)
+ caffe2_include_directories(${PROJECT_BINARY_DIR}/caffe2/contrib/aten/aten/src/ATen)
+ caffe2_include_directories(${PROJECT_SOURCE_DIR}/third_party/aten/src)
+ caffe2_include_directories(${PROJECT_BINARY_DIR}/caffe2/contrib/aten)
+endif()
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index b7423a6..0a528dc 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -35,6 +35,7 @@
message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}")
message(STATUS " BUILD_TEST : ${BUILD_TEST}")
+ message(STATUS " USE_ATEN : ${USE_ATEN}")
message(STATUS " USE_CUDA : ${USE_CUDA}")
if(${USE_CUDA})
message(STATUS " CUDA version : ${CUDA_VERSION}")
diff --git a/third_party/aten b/third_party/aten
new file mode 120000
index 0000000..5ab2da8
--- /dev/null
+++ b/third_party/aten
@@ -0,0 +1 @@
+Subproject commit dd3333d8697590d1eb01b6b5ef1e7bc5aaaa7967