Support wrap_dim specifications from cwrap.
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index b581d7c..895a505 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -82,6 +82,22 @@
return static_cast<Type*>(this)->${method_prefix}${api_name}(${sparse_actuals});
}""")
+WRAP_DIM_GEN = CodeTemplate("""\
+auto ${target}${dim}_ = std::max<${type}>(${target_dim_expr}, 0);
+if (${target}${dim}_ <= 0) {
+ std::ostringstream oss;
+ oss << "dimension specified as " << ${dim} << " but tensor has no dimensions";
+ throw std::runtime_error(oss.str());
+}
+if (${dim} < -(${target}${dim}_) || ${dim} >= (${target}${dim}_)) {
+ std::ostringstream oss;
+ oss << "dimension out of range (expected to be in range of [" << -(${target}${dim}_)
+ << ", " << (${target}${dim}_)-1 << "], but got " << ${dim} << ")",
+ throw std::runtime_error(oss.str());
+}
+if (${dim} < 0) ${dim} += ${target}${dim}_;
+""")
+
class NYIError(Exception):
"""Indicates we don't support this declaration yet"""
@@ -571,6 +587,18 @@
if arg['type'] == 'THSize*':
scalar_check_is_from_size = True
scalar_check = '{}.size() == 0'.format(arg['name'])
+
+ wrap_dim_arg = arg.get('wrap_dim', None)
+ if wrap_dim_arg is not None:
+ # wrap_dim specification can have (add) expressions, e.g. self+1
+ wrap_dim_params = wrap_dim_arg.split("+")
+ wrap_dim_target = wrap_dim_params[0]
+ wrap_dim_params[0] = "{}.dim()".format(wrap_dim_target)
+ wrap_dim_expr = "+".join(wrap_dim_params)
+ body.append(WRAP_DIM_GEN.substitute(
+ target=wrap_dim_target, dim=arg['name'], target_dim_expr=wrap_dim_expr,
+ type=DYNAMIC_TYPE[arg['type']]))
+
# only generated checked casts the first time we see it
if not arg['name'] in seen_names and requires_checked_cast(arg):
seen_names.add(arg['name'])
diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp
index 9a48ef7..c4ba1e4 100644
--- a/aten/src/ATen/templates/TypeDerived.cpp
+++ b/aten/src/ATen/templates/TypeDerived.cpp
@@ -9,6 +9,7 @@
#include "ATen/Utils.h"
#include "ATen/THLongStorageView.h"
#include <iostream>
+#include <sstream>
namespace at {
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index c6f1aa6..ed4ee90 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -9,3 +9,6 @@
add_executable(broadcast_test broadcast_test.cpp)
target_link_libraries(broadcast_test ATen)
+
+add_executable(wrapdim_test wrapdim_test.cpp)
+target_link_libraries(wrapdim_test ATen)
diff --git a/aten/src/ATen/test/wrapdim_test.cpp b/aten/src/ATen/test/wrapdim_test.cpp
new file mode 100644
index 0000000..221fa0d
--- /dev/null
+++ b/aten/src/ATen/test/wrapdim_test.cpp
@@ -0,0 +1,42 @@
+#include "ATen/ATen.h"
+
+using namespace at;
+
+int main() {
+ Type & T = CPU(kFloat);
+
+ // test simple case
+ {
+ auto a = T.randn({2, 3, 4, 5});
+ assert(a.prod(-4).equal(a.prod(0)));
+ assert(a.prod(3).equal(a.prod(-1)));
+ }
+
+ // test case with expression specification
+ {
+ auto a = T.randn({2, 3, 4, 5});
+ assert(a.unsqueeze(-5).equal(a.unsqueeze(0)));
+ assert(a.unsqueeze(4).equal(a.unsqueeze(-1)));
+ }
+
+ // test case with empty tensor
+ {
+ auto a = T.randn(0);
+ try {
+ a.prod(0);
+ assert(false);
+ } catch (std::runtime_error &e) {}
+ }
+
+ // test case with scalar vs 1-dim, 1-size
+ {
+ auto a = T.randn(1);
+ assert(a.prod(0).equal(a.prod(-1)));
+ a.get()->maybeScalar(true);
+ assert(a.get()->isScalar());
+ try {
+ a.prod(0);
+ assert(false);
+ } catch (std::runtime_error &e) {}
+ }
+}
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index dec88cb..ac54839 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -6,4 +6,5 @@
$BUILD_ROOT/src/ATen/test/atest
$BUILD_ROOT/src/ATen/test/scalar_test
$BUILD_ROOT/src/ATen/test/broadcast_test
+$BUILD_ROOT/src/ATen/test/wrapdim_test
valgrind --suppressions=`dirname $0`/valgrind.sup --error-exitcode=1 $BUILD_ROOT/src/ATen/test/basic -n