Implement tensor.size(Dimname), tensor.stride(Dimname)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22989

Test Plan: Imported from OSS

Differential Revision: D16364437

Pulled By: zou3519

fbshipit-source-id: 393a93fecac27b5d3b1a7f7692590d8fd5e95a5d
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index f223d10..1280888 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -516,6 +516,9 @@
   Tensor detach() const;
   Tensor & detach_();
   int64_t size(int64_t dim) const;
+  #ifdef BUILD_NAMEDTENSOR
+  int64_t size(Dimname dim) const;
+  #endif
   Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const;
   std::tuple<Tensor,Tensor> slogdet() const;
   Tensor smm(const Tensor & mat2) const;
@@ -529,6 +532,9 @@
   Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
   Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
   int64_t stride(int64_t dim) const;
+  #ifdef BUILD_NAMEDTENSOR
+  int64_t stride(Dimname dim) const;
+  #endif
   Tensor sum(c10::optional<ScalarType> dtype=c10::nullopt) const;
   Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
   Tensor sum_to_size(IntArrayRef size) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index a62c25e..e56fd14 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -715,6 +715,12 @@
     static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, int dim) -> int");
     return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
 }
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::size(Dimname dim) const {
+    static auto table = globalATenDispatch().getOpTable("aten::size(Tensor self, Dimname dim) -> int");
+    return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
 inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step) const {
     static auto table = globalATenDispatch().getOpTable("aten::slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)");
     return table->getOp<Tensor (const Tensor &, int64_t, int64_t, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim, start, end, step);
@@ -767,6 +773,12 @@
     static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, int dim) -> int");
     return table->getOp<int64_t (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
 }
+#ifdef BUILD_NAMEDTENSOR
+inline int64_t Tensor::stride(Dimname dim) const {
+    static auto table = globalATenDispatch().getOpTable("aten::stride(Tensor self, Dimname dim) -> int");
+    return table->getOp<int64_t (const Tensor &, Dimname)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dim);
+}
+#endif
 inline Tensor Tensor::sum(c10::optional<ScalarType> dtype) const {
     static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor");
     return table->getOp<Tensor (const Tensor &, c10::optional<ScalarType>)>(tensorTypeIdToBackend(type_id()), is_variable())(*this, dtype);
diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp
index 33d2fe3..de1e8c6 100644
--- a/aten/src/ATen/native/TensorProperties.cpp
+++ b/aten/src/ATen/native/TensorProperties.cpp
@@ -2,6 +2,9 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/detail/CUDAHooksInterface.h>
+#ifdef BUILD_NAMEDTENSOR
+#include <ATen/NamedTensorUtils.h>
+#endif
 
 #include <ATen/Config.h>
 namespace at {
@@ -23,6 +26,18 @@
   return self.strides()[dim];
 }
 
+#ifdef BUILD_NAMEDTENSOR
+int64_t size(const Tensor& self, Dimname dim) {
+  size_t pos_dim = dimname_to_position(self, dim);
+  return self.sizes()[pos_dim];
+}
+
+int64_t stride(const Tensor& self, Dimname dim) {
+  size_t pos_dim = dimname_to_position(self, dim);
+  return self.strides()[pos_dim];
+}
+#endif
+
 bool cudnn_is_acceptable(const Tensor& self) {
   if (!globalContext().userEnabledCuDNN()) return false;
   if (!self.is_cuda()) return false;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 66afc84..82ce28c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1854,6 +1854,11 @@
   device_guard: False
   named_guard: False
 
+- func: size(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_guard: False
+  named_guard: False
+
 - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
   variants: function, method
   device_guard: False
@@ -1965,6 +1970,12 @@
 - func: stride(Tensor self, int dim) -> int
   variants: function, method
   device_guard: False
+  named_guard: False
+
+- func: stride(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_guard: False
+  named_guard: False
 
 - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
   variants: function, method
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index c814059..564dd40 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -71,6 +71,28 @@
     def test_empty_cuda(self):
         self._test_factory(torch.empty, 'cuda')
 
+    def test_size(self):
+        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+        self.assertEqual(t.size('N'), 2)
+        self.assertEqual(t.size('C'), 5)
+        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name*'):
+            t.size(None)
+        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+            t.size('channels')
+        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+            torch.empty(2, 3, 4).size('N')
+
+    def test_stride(self):
+        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
+        self.assertEqual(t.stride('N'), 3 * 5)
+        self.assertEqual(t.stride('C'), 1)
+        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
+            t.stride(None)
+        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
+            t.stride('channels')
+        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
+            torch.empty(2, 3, 4).stride('N')
+
     def test_info_smoke(self):
         # Smoke test for info functions / methods / attributes on named tensors.
         tensor = torch.empty(1, 1, names=('N', 'D'))
@@ -97,10 +119,12 @@
         tensor.nelement()
         tensor.shape
         tensor.size()
+        tensor.size(1)
         tensor.storage()
         tensor.storage_offset()
         tensor.storage_type()
         tensor.stride()
+        tensor.stride(1)
         tensor.data
         tensor.data_ptr()
         tensor.ndim
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 6fd6b70..fda0886 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -77,6 +77,9 @@
   static PythonArgParser parser({
     "size(int64_t dim)",
     "size()",
+#ifdef BUILD_NAMEDTENSOR
+    "size(Dimname dim)",
+#endif
   });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   ParsedArgs<3> parsed_args;
@@ -92,6 +95,14 @@
     // torch.Size and tuple in python.
     return THPSize_New(self_);
   }
+#ifdef BUILD_NAMEDTENSOR
+  else if (r.idx == 2) {
+    if (jit::tracer::isTracing()) {
+      TORCH_INTERNAL_ASSERT("NYI: Named tensors w/ JIT");
+    }
+    return wrap(self_.size(r.dimname(0)));
+  }
+#endif
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
@@ -102,6 +113,9 @@
   static PythonArgParser parser({
     "stride(int64_t dim)",
     "stride()",
+#ifdef BUILD_NAMEDTENSOR
+    "stride(Dimname dim)",
+#endif
   });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   ParsedArgs<3> parsed_args;
@@ -115,6 +129,11 @@
     // torch.Size and tuple in python
     return THPUtils_packInt64Array(strides.size(), strides.data());
   }
+#ifdef BUILD_NAMEDTENSOR
+  else if (r.idx == 2) {
+    return wrap(self_.stride(r.dimname(0)));
+  }
+#endif
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }