Named inference rule for `abs`. (#22151)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22151
ghimport-source-id: 54c1726b578ac162af817f78df6f540b764e46e3
Test Plan:
- `python test/test_namedtensor.py` [namedtensor ci]
Imported from OSS
Differential Revision: D15970326
Pulled By: zou3519
fbshipit-source-id: 4ea25f0a73bbc24b604d3ded2027eeb4ce800de0
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 39cd039..19dee13 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -603,6 +603,10 @@
def named_guard(option, tensors, tensorlists):
if not option.get('named_guard', True) or (len(tensors) + len(tensorlists) == 0):
return ''
+    # Override: named_guard = False for _th_ functions. This is because:
+    # There is always some at:: function that calls the _th_ function.
+ if option['name'].startswith('_th_'):
+ return ''
named_conditions = []
for tensor in tensors:
named_conditions.append('{}.is_named()'.format(tensor))
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 86715a7..6627a5a 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -16,6 +16,9 @@
#include <ATen/Parallel.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/TensorIterator.h>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
#include <algorithm>
#include <cmath>
@@ -137,6 +140,14 @@
return result;
}
+static void propagate_names(Tensor& result, const Tensor& src) {
+#ifdef NAMEDTENSOR_ENABLED
+ if (src.is_named()) {
+ at::internal_set_names_inplace(result, src.names());
+ }
+#endif
+}
+
// NB: If you use this macro, you may also need to add a CUDA forwarding
// stub in CUDAUnaryOps
@@ -154,6 +165,7 @@
assert_no_internal_overlap(result, #op); \
auto iter = TensorIterator::unary_op(result, self); \
op##_stub(iter->device_type(), *iter); \
+ propagate_names(result, self); \
return result; \
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5c6848f..8d4007f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -92,14 +92,17 @@
- func: abs(Tensor self) -> Tensor
variants: function, method
+ named_guard: False
- func: abs_(Tensor(a!) self) -> Tensor(a!)
variants: function, method
+ named_guard: False
dispatch:
CPU: _abs__cpu
CUDA: _abs__cuda
- func: abs(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+ named_guard: False
dispatch:
CPU: _abs_out_cpu
CUDA: _abs_out_cuda
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 2c510db..ce3f23f 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -3,6 +3,9 @@
#else
#include <ATen/MemoryOverlap.h>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
void THCTensor_(cbitand)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
{
@@ -172,6 +175,15 @@
#if !defined(THC_REAL_IS_BOOL)
+static void propagate_names(THCTensor* result, THCTensor* src) {
+#ifdef NAMEDTENSOR_ENABLED
+ if (at::impl::internal_is_named(src)) {
+ const auto names = at::impl::internal_get_names(src);
+ at::impl::internal_set_names_inplace(result, names);
+ }
+#endif
+}
+
#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL) \
struct Tensor_##NAME##_##REAL##_Op { \
__device__ __forceinline__ void operator()(scalar_t* out, scalar_t* in) const { \
@@ -199,6 +211,7 @@
} \
\
THCudaCheck(cudaGetLastError()); \
+ propagate_names(self_, src); \
}
#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC, REAL) \
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index 5149500..8dcf555 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -1,6 +1,7 @@
import unittest
from common_utils import TestCase, run_tests
from common_cuda import TEST_CUDA
+import itertools
import torch
import sys
@@ -69,6 +70,45 @@
def test_empty_cuda(self):
self._test_factory(torch.empty, 'cuda')
+ def test_unary_fns(self):
+ def _test(lambd, names=('N', 'D'), device='cpu'):
+ sizes = [2] * len(names)
+ tensor = torch.empty(sizes, names=names, device=device)
+ out = lambd(tensor)
+ self.assertEqual(out.names, tensor.names)
+
+ def method(name, *args, **kwargs):
+ return [lambda t: getattr(t, name)(*args, **kwargs)]
+
+ def out_function(name, *args, **kwargs):
+ out_fn = getattr(torch, name)
+
+ def fn(tensor):
+ result = tensor.new_empty([0])
+ out_fn(tensor, *args, out=result, **kwargs)
+ return result
+
+ return [fn]
+
+ def fn_method_and_inplace(name, *args, **kwargs):
+ return (
+ method(name, *args, **kwargs) +
+ method(name + '_', *args, **kwargs) +
+ out_function(name, *args, **kwargs)
+ )
+
+ def flatten(lst):
+ return [item for sublist in lst for item in sublist]
+
+ tests = [
+ fn_method_and_inplace('abs'),
+ ]
+ tests = flatten(tests)
+
+ for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
+ _test(testcase, device=device)
+
+
def test_using_seen_interned_string_doesnt_bump_refcount(self):
def see_name():
seen_name = 'N'