More named inference rules for pointwise unary ops

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22308

Test Plan:
- `python test/test_namedtensor.py -v` [namedtensor ci]

gh-metadata: pytorch pytorch 22308 gh/zou3519/65/head

Imported from OSS

Differential Revision: D16053441

Pulled By: zou3519

fbshipit-source-id: 2e8d4cc11d7a711d2b789752a316a11fffc0996e
diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp
index 57259a8..a0ab16b 100644
--- a/aten/src/ATen/NamedTensorUtils.cpp
+++ b/aten/src/ATen/NamedTensorUtils.cpp
@@ -155,6 +155,11 @@
   at::internal_set_names_inplace(result, src.names());
 }
 
+void propagate_names(TensorImpl* result, TensorImpl* src) {
+  const auto names = at::impl::internal_get_names(src);
+  at::impl::internal_set_names_inplace(result, names);
+}
+
 } // namespace namedinference
 } // namespace at
 #endif
diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h
index b91c1e7..8014825 100644
--- a/aten/src/ATen/NamedTensorUtils.h
+++ b/aten/src/ATen/NamedTensorUtils.h
@@ -34,6 +34,7 @@
 
 optional<std::vector<Dimname>> erase_name(optional<DimnameList> self_names, int64_t dim);
 void propagate_names(Tensor& result, const Tensor& src);
+void propagate_names(TensorImpl* result, /*const */TensorImpl* src);
 
 } // namespace namedinference
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 385e5dc..cc4ecaa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1664,6 +1664,7 @@
     MkldnnCPU: mkldnn_relu
 
 - func: relu_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: function, method
   dispatch:
     CPU: relu_
@@ -1742,6 +1743,7 @@
 - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
 
 - func: sigmoid(Tensor self) -> Tensor
+  named_guard: False
   variants: function, method
   dispatch:
     CPU: sigmoid
@@ -1749,6 +1751,7 @@
     MkldnnCPU: mkldnn_sigmoid
 
 - func: sigmoid_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: function, method
   dispatch:
     CPU: _sigmoid__cpu
@@ -1756,6 +1759,7 @@
     MkldnnCPU: mkldnn_sigmoid_
 
 - func: sigmoid(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: _sigmoid_out_cpu
     CUDA: _sigmoid_out_cuda
@@ -1990,15 +1994,18 @@
     CUDA: _tan_out_cuda
 
 - func: tanh(Tensor self) -> Tensor
+  named_guard: False
   variants: function, method
 
 - func: tanh_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: function, method
   dispatch:
     CPU: _tanh__cpu
     CUDA: _tanh__cuda
 
 - func: tanh(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: _tanh_out_cpu
     CUDA: _tanh_out_cuda
@@ -2312,6 +2319,7 @@
     SparseCUDA: pow_sparse_scalar
 
 - func: zero_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: method, function
   dispatch:
     CPU: legacy::cpu::_th_zero_
@@ -3165,6 +3173,7 @@
     CUDA: legacy::cuda::_th_irshift_
 
 - func: lgamma_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: method
   dispatch:
     CPU: legacy::cpu::_th_lgamma_
@@ -3189,18 +3198,21 @@
     CUDA: triu_cuda_
 
 - func: digamma_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: method
   dispatch:
     CPU: legacy::cpu::_th_digamma_
     CUDA: legacy::cuda::_th_digamma_
 
 - func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
+  named_guard: False
   variants: method
   dispatch:
     CPU: legacy::cpu::_th_polygamma_
     CUDA: legacy::cuda::_th_polygamma_
 
 - func: erfinv_(Tensor(a!) self) -> Tensor(a!)
+  named_guard: False
   variants: method
   dispatch:
     CPU: legacy::cpu::_th_erfinv_
@@ -3802,44 +3814,52 @@
     CUDA: legacy::cuda::_th_multinomial_alias_draw
 
 - func: lgamma(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: legacy::cpu::_th_lgamma_out
     CUDA: legacy::cuda::_th_lgamma_out
 
 - func: lgamma(Tensor self) -> Tensor
+  named_guard: False
   variants: method, function
   dispatch:
     CPU: legacy::cpu::_th_lgamma
     CUDA: legacy::cuda::_th_lgamma
 
 - func: digamma(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: legacy::cpu::_th_digamma_out
     CUDA: legacy::cuda::_th_digamma_out
 
 - func: digamma(Tensor self) -> Tensor
+  named_guard: False
   variants: method, function
   dispatch:
     CPU: legacy::cpu::_th_digamma
     CUDA: legacy::cuda::_th_digamma
 
 - func: polygamma(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: legacy::cpu::_th_polygamma_out
     CUDA: legacy::cuda::_th_polygamma_out
 
 - func: polygamma(int n, Tensor self) -> Tensor
+  named_guard: False
   variants: method, function
   dispatch:
     CPU: legacy::cpu::_th_polygamma
     CUDA: legacy::cuda::_th_polygamma
 
 - func: erfinv(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: legacy::cpu::_th_erfinv_out
     CUDA: legacy::cuda::_th_erfinv_out
 
 - func: erfinv(Tensor self) -> Tensor
+  named_guard: False
   variants: method, function
   dispatch:
     CPU: legacy::cpu::_th_erfinv
diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp
index 9ca8630..414b625 100644
--- a/aten/src/TH/generic/THTensor.cpp
+++ b/aten/src/TH/generic/THTensor.cpp
@@ -4,6 +4,9 @@
 
 #include <ATen/InferSize.h>
 #include <new>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
 
 /**** access methods ****/
 THStorage *THTensor_(storage)(const THTensor *self)
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 2aaf392..636c51b 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -5,6 +5,9 @@
 #include <TH/generic/THTensorApply.hpp>
 #include <ATen/CPUGenerator.h>
 #include <ATen/Utils.h>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
 
 #define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
   void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
@@ -1025,6 +1028,12 @@
   }
 }
 
+static void THTensor_(propagate_names_if_named_tensor_enabled)(THTensor* result, THTensor* src) {
+#ifdef NAMEDTENSOR_ENABLED
+  at::namedinference::propagate_names(result, src);
+#endif
+}
+
 #define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, THRESHOLD) \
   void THTensor_(NAME)(THTensor *r_, THTensor *t) \
   { \
@@ -1033,6 +1042,7 @@
     int r_Contig = THTensor_(isContiguous)(r_); \
     int tContig = THTensor_(isContiguous)(t); \
     TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = CFUNC(*t_data);, THRESHOLD); \
+    THTensor_(propagate_names_if_named_tensor_enabled)(r_, t); \
   }
 
 #define LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) \
@@ -1050,6 +1060,7 @@
     } else { \
       TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = CFUNC(*t_data);, THRESHOLD); \
     } \
+    THTensor_(propagate_names_if_named_tensor_enabled)(r_, t); \
   }
 
 #define LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) \
@@ -1130,10 +1141,13 @@
 
 void THTensor_(polygamma)(THTensor *r_, int64_t n, THTensor *t) {
   switch (n) {
-    case 0: THTensor_(digamma)(r_, t); return;
-    case 1: THTensor_(trigamma)(r_, t); return;
+    case 0: THTensor_(digamma)(r_, t); break;
+    case 1: THTensor_(trigamma)(r_, t); break;
     default: THError("polygamma(n,x) is not implemented for n>=2");
   }
+#ifdef NAMEDTENSOR_ENABLED
+  at::namedinference::propagate_names(r_, t);
+#endif
 }
 
 void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim)
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 88c67ee..e7a119d 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -175,10 +175,9 @@
 
 #if !defined(THC_REAL_IS_BOOL)
 
-static void propagate_names(THCTensor* result, THCTensor* src) {
+static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor* src) {
 #ifdef NAMEDTENSOR_ENABLED
-  const auto names = at::impl::internal_get_names(src);
-  at::impl::internal_set_names_inplace(result, names);
+  at::namedinference::propagate_names(result, src);
 #endif
 }
 
@@ -209,7 +208,7 @@
     }                                                                   \
                                                                         \
     THCudaCheck(cudaGetLastError());                                    \
-    propagate_names(self_, src);                                        \
+    propagate_names_if_named_tensor_enabled(self_, src);                \
   }
 
 #define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC, REAL) \
@@ -322,6 +321,9 @@
   }
 
   THCudaCheck(cudaGetLastError());
+#ifdef NAMEDTENSOR_ENABLED
+  at::namedinference::propagate_names(self_, src);
+#endif
 }
 
 void THCTensor_(digamma)(THCState* state, THCTensor* self_, THCTensor* src) {
@@ -334,6 +336,9 @@
   }
 
   THCudaCheck(cudaGetLastError());
+#ifdef NAMEDTENSOR_ENABLED
+  at::namedinference::propagate_names(self_, src);
+#endif
 }
 
 void THCTensor_(polygamma)(THCState* state, THCTensor* self_, int64_t n, THCTensor* src) {
@@ -357,6 +362,9 @@
   }
 
   THCudaCheck(cudaGetLastError());
+#ifdef NAMEDTENSOR_ENABLED
+  at::namedinference::propagate_names(self_, src);
+#endif
 }
 
 #endif
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index ec2fd8e..e930580 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -115,25 +115,33 @@
             fn_method_and_inplace('clamp_max', 2),
             fn_method_and_inplace('cos'),
             fn_method_and_inplace('cosh'),
+            fn_method_and_inplace('digamma'),
             fn_method_and_inplace('erf'),
             fn_method_and_inplace('erfc'),
+            fn_method_and_inplace('erfinv'),
             fn_method_and_inplace('exp'),
             fn_method_and_inplace('expm1'),
             fn_method_and_inplace('floor'),
             fn_method_and_inplace('frac'),
+            fn_method_and_inplace('lgamma'),
             fn_method_and_inplace('log'),
             fn_method_and_inplace('log10'),
             fn_method_and_inplace('log1p'),
             fn_method_and_inplace('log2'),
             fn_method_and_inplace('neg'),
+            [TestCase('polygamma', lambda t: torch.polygamma(1, t))],
+            method('polygamma_', 1),
             fn_method_and_inplace('reciprocal'),
             fn_method_and_inplace('round'),
             fn_method_and_inplace('rsqrt'),
+            fn_method_and_inplace('sigmoid'),
             fn_method_and_inplace('sin'),
             fn_method_and_inplace('sinh'),
             fn_method_and_inplace('sqrt'),
             fn_method_and_inplace('tan'),
+            fn_method_and_inplace('tanh'),
             fn_method_and_inplace('trunc'),
+            method('zero_'),
             method('fill_', 1),
             method('fill_', torch.tensor(3.14)),
         ]