Implement tensor.size(Dimname), tensor.stride(Dimname)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22989
Test Plan: Imported from OSS
Differential Revision: D16364437
Pulled By: zou3519
fbshipit-source-id: 393a93fecac27b5d3b1a7f7692590d8fd5e95a5d
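A minimal usage sketch of the new overloads, mirroring the cases exercised in test/test_namedtensor.py below (named tensors require a build with BUILD_NAMEDTENSOR defined):

    import torch

    t = torch.empty(2, 3, 5, names=('N', None, 'C'))
    t.size('N')    # -> 2, same as t.size(0)
    t.size('C')    # -> 5, same as t.size(2)
    t.stride('N')  # -> 15 (= 3 * 5) for this contiguous tensor
    t.stride('C')  # -> 1
    # Passing None or a name that does not appear in t.names raises a RuntimeError,
    # as covered by test_size / test_stride below.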
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index f223d10..1280888 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -516,6 +516,9 @@
Tensor detach() const;
Tensor & detach_();
int64_t size(int64_t dim) const;
+ #ifdef BUILD_NAMEDTENSOR
+ int64_t size(Dimname dim) const;
+ #endif
Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const;
std::tuple<Tensor,Tensor> slogdet() const;
Tensor smm(const Tensor & mat2) const;
@@ -529,6 +532,9 @@
Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
int64_t stride(int64_t dim) const;
+ #ifdef BUILD_NAMEDTENSOR
+ int64_t stride(Dimname dim) const;
+ #endif
Tensor sum(c10::optional<ScalarType> dtype=c10::nullopt) const;
Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
Tensor sum_to_size(IntArrayRef size) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index a62c25e..e56fd14 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -715,6 +715,12 @@
static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, int dim) -> int");
return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::size(Dimname dim) const {
+ static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, Dimname dim) -> int");
+ return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step) const {
static auto table = globalATenDispatch().getOpTable("aten::slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)");
return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, start, end, step);
@@ -767,6 +773,12 @@
static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, int dim) -> int");
return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
}
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::stride(Dimname dim) const {
+ static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, Dimname dim) -> int");
+ return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
inline Tensor Tensor::sum(c10::optional<ScalarType> dtype) const {
static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor");
return table->getOp<Tensor (const Tensor &, c10::optional<ScalarType>)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dtype);
diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp
index 33d2fe3..de1e8c6 100644
--- a/aten/src/ATen/native/TensorProperties.cpp
+++ b/aten/src/ATen/native/TensorProperties.cpp
@@ -2,6 +2,9 @@
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
+#ifdef BUILD_NAMEDTENSOR
+#include <ATen/NamedTensorUtils.h>
+#endif
#include <ATen/Config.h>
namespace at {
@@ -23,6 +26,18 @@
return self.strides()[dim];
}
+#ifdef BUILD_NAMEDTENSOR
+int64_t size(const Tensor& self, Dimname dim) {
+ size_t pos_dim = dimname_to_position(self, dim);
+ return self.sizes()[pos_dim];
+}
+
+int64_t stride(const Tensor& self, Dimname dim) {
+ size_t pos_dim = dimname_to_position(self, dim);
+ return self.strides()[pos_dim];
+}
+#endif
+
bool cudnn_is_acceptable(const Tensor& self) {
if (!globalContext().userEnabledCuDNN()) return false;
if (!self.is_cuda()) return false;
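The native implementations above just translate the name to a positional index via dimname_to_position (from ATen/NamedTensorUtils.h) and then index sizes()/strides(). A rough Python model of that lookup, written only for illustration and inferred from the error cases the tests in this diff exercise (not the actual ATen helper), might look like:

    def dimname_to_position(names, dim):
        # Hypothetical Python model of ATen's dimname_to_position, for illustration only.
        # `names` is the tensor's name tuple, e.g. ('N', None, 'C'); `dim` is a string or None.
        if dim is None:
            raise RuntimeError('Please look up dimensions by name; None is not a name')
        try:
            return names.index(dim)
        except ValueError:
            raise RuntimeError("Name '{}' not found in {}".format(dim, names))

    # size(t, 'C') is then t.sizes()[dimname_to_position(t.names, 'C')],
    # and stride(t, 'C') is t.strides()[dimname_to_position(t.names, 'C')].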
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 66afc84..82ce28c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1854,6 +1854,11 @@
device_guard: False
named_guard: False
+- func: size(Tensor self, Dimname dim) -> int
+ variants: function, method
+ device_guard: False
+ named_guard: False
+
- func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
variants: function, method
device_guard: False
@@ -1965,6 +1970,12 @@
- func: stride(Tensor self, int dim) -> int
variants: function, method
device_guard: False
+ named_guard: False
+
+- func: stride(Tensor self, Dimname dim) -> int
+ variants: function, method
+ device_guard: False
+ named_guard: False
- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
variants: function, method
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index c814059..564dd40 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -71,6 +71,28 @@
def test_empty_cuda(self):
self._test_factory(torch.empty, 'cuda')
+ def test_size(self):
+ t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+ self.assertEqual(t.size('N'), 2)
+ self.assertEqual(t.size('C'), 5)
+ with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
+ t.size(None)
+ with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+ t.size('channels')
+ with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+ torch.empty(2, 3, 4).size('N')
+
+ def test_stride(self):
+ t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+ self.assertEqual(t.stride('N'), 3 * 5)
+ self.assertEqual(t.stride('C'), 1)
+ with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
+ t.stride(None)
+ with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+ t.stride('channels')
+ with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+ torch.empty(2, 3, 4).stride('N')
+
def test_info_smoke(self):
# Smoke test for info functions / methods / attributes on named tensors.
tensor = torch.empty(1, 1, names=('N', 'D'))
@@ -97,10 +119,12 @@
tensor.nelement()
tensor.shape
tensor.size()
+ tensor.size(1)
tensor.storage()
tensor.storage_offset()
tensor.storage_type()
tensor.stride()
+ tensor.stride(1)
tensor.data
tensor.data_ptr()
tensor.ndim
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 6fd6b70..fda0886 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -77,6 +77,9 @@
static PythonArgParser parser({
"size(int64_t dim)",
"size()",
+#ifdef BUILD_NAMEDTENSOR
+ "size(Dimname dim)",
+#endif
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
ParsedArgs<3> parsed_args;
@@ -92,6 +95,14 @@
// torch.Size and tuple in python.
return THPSize_New(self_);
}
+#ifdef BUILD_NAMEDTENSOR
+ else if (r.idx == 2) {
+ if (jit::tracer::isTracing()) {
+ TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
+ }
+ return wrap(self_.size(r.dimname(0)));
+ }
+#endif
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@@ -102,6 +113,9 @@
static PythonArgParser parser({
"stride(int64_t dim)",
"stride()",
+#ifdef BUILD_NAMEDTENSOR
+ "stride(Dimname dim)",
+#endif
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
ParsedArgs<3> parsed_args;
@@ -115,6 +129,11 @@
// torch.Size and tuple in python
return THPUtils_packInt64Array(strides.size(), strides.data());
}
+#ifdef BUILD_NAMEDTENSOR
+ else if (r.idx == 2) {
+ return wrap(self_.stride(r.dimname(0)));
+ }
+#endif
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}