Implement dim_arange operator (#8266)
* Implement arange_like operator
* add ONNX symbolic
* lint
* change name
* Comment the hack
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index ceed9a3..aba2312 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -36,6 +36,10 @@
return at::_arange_out(result, end);
}
+Tensor _dim_arange(const Tensor& like, int64_t dim) {
+ return like.type().toScalarType(at::kLong)._arange(like.size(dim));
+}
+
Tensor empty(const Type& dtype, IntList size) {
return dtype.tensor(size);
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6601cbb..0888c77 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -110,6 +110,14 @@
- func: arange_out(Tensor result, Scalar end) -> Tensor
variants: function
+# This function is a temporary hack to allow tracing of arange like constructs with dynamic
+# bounds on arange. Normal arange is not traceable because it does not take any tensor inputs;
+# if the range you need is based on another tensor, calling this function directly will
+# preserve tracing. Get rid of this when arange can directly take tensors for bounds
+# (so that it can be traced directly).
+- func: _dim_arange(Tensor like, int64_t dim) -> Tensor
+ variants: function
+
# `argmin` and `argmax` are exposed in C++ but not in Python, where we only
# expose `_argmin` and `_argmax` (which call the first versions). In Python, we
# then define our own `argmax` and `argmin` that handle passing `dim=None`,
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index e7cd77b..ec8dd60 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -998,3 +998,7 @@
return prev_output, h_outs, c_outs
return symbolic
+
+
+def _dim_arange(g, like, dim):
+ return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')