Add complex support for `torch.{acosh, asinh, atanh}` (#50387)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50387

Enables complex dtypes in the CPU and CUDA kernels for `torch.acosh`, `torch.asinh`, and `torch.atanh`, adds `atanh` to the complex-autograd whitelist in `gen_variable_type.py`, and extends the corresponding `OpInfo` entries with complex dtypes plus the test skips they currently require.
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D25947496
Pulled By: anjali411
fbshipit-source-id: c70886a73378501421ff94cdc0dc737f1738bf6f
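
The user-visible effect of the kernel and OpInfo changes below is that the inverse hyperbolic functions now accept complex tensors. A minimal usage sketch (not part of the patch), checked against the same NumPy references the OpInfo entries declare:

    import numpy as np
    import torch

    # Complex inputs chosen away from the branch cuts on the real axis.
    z = torch.tensor([0.5 + 0.5j, -1.2 + 2.0j, 3.0 - 0.25j], dtype=torch.cdouble)

    for torch_fn, np_fn in ((torch.acosh, np.arccosh),
                            (torch.asinh, np.arcsinh),
                            (torch.atanh, np.arctanh)):
        expected = torch.from_numpy(np_fn(z.numpy()))
        # Away from the branch cuts the results should broadly agree with NumPy.
        print(torch_fn.__name__, torch.allclose(torch_fn(z), expected))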
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 32ebaf7..6ed4b3a 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -336,7 +336,7 @@
}

static void acosh_kernel(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "acosh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "acosh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::acosh(a); });
@@ -344,7 +344,7 @@
}

static void asinh_kernel(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "asinh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "asinh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::asinh(a); });
@@ -352,7 +352,7 @@
}

static void atanh_kernel(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "atanh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "atanh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::atanh(a); });
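
The CPU change is confined to the dispatch macro: `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES` instantiates the same `std::acosh`/`std::asinh`/`std::atanh` lambdas for `c10::complex<float>` and `c10::complex<double>` in addition to the floating-point types. A hedged sanity check (not part of the patch): for real inputs inside each function's real domain, the new complex path should match the existing floating-point path up to rounding.

    import torch

    x = torch.tensor([1.5, 2.0, 4.0], dtype=torch.double)    # inside acosh's domain [1, inf)
    t = torch.tensor([-0.9, 0.0, 0.9], dtype=torch.double)   # inside atanh's domain (-1, 1)

    # Compare the real part of the complex result with the real-dtype result.
    print(torch.allclose(torch.acosh(x.to(torch.cdouble)).real, torch.acosh(x)))
    print(torch.allclose(torch.asinh(x.to(torch.cdouble)).real, torch.asinh(x)))
    print(torch.allclose(torch.atanh(t.to(torch.cdouble)).real, torch.atanh(t)))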
diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
index 8678552..bac3a05 100644
--- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu
@@ -75,7 +75,7 @@
}

void acosh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "acosh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::acosh(a);
});
@@ -83,7 +83,7 @@
}

void asinh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "asinh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::asinh(a);
});
@@ -91,7 +91,7 @@
}

void atanh_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "atanh_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::atanh(a);
});
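
The CUDA kernels receive the analogous change; the `_AND2` variant keeps `Half` and `BFloat16` in the dispatch list while adding the complex types. A hypothetical CPU/CUDA parity check (not part of the patch), run only when a CUDA device is available:

    import torch

    if torch.cuda.is_available():
        # Random complex inputs; imaginary parts are almost surely nonzero,
        # which keeps the values off the branch cuts on the real axis.
        z = torch.randn(64, dtype=torch.cdouble)
        for fn in (torch.acosh, torch.asinh, torch.atanh):
            cpu_out = fn(z)
            cuda_out = fn(z.cuda()).cpu()
            print(fn.__name__, torch.allclose(cpu_out, cuda_out))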
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 50de4b5..3e48c05 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -72,15 +72,15 @@
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
't', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose',
- 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
- 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'eq_',
- 'ne_', 'add', '__radd__', 'sum', '_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh',
- 'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
+ 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril',
+ 'triu', 'chunk', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum',
+ '_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh', 'cosh', '__rmul__',
+ 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
- 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
+ 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'atanh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'linalg_qr', '_svd_helper', '_fft_c2c', '_fft_r2c',
'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
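
The substantive change in `gen_variable_type.py` is adding `'atanh'` to `GRADIENT_IMPLEMENTED_FOR_COMPLEX`, allowing its backward formula to run on complex inputs (`'asinh'` and `'acosh'` were already listed); the surrounding reflow also drops duplicate entries (`'split'`, `'split_with_sizes'`, `'repeat'`, `'expand'`) that already appear elsewhere in the set. A minimal gradcheck sketch (not part of the patch; the input values are illustrative):

    import torch
    from torch.autograd import gradcheck

    # Points strictly inside the unit disk, away from atanh's branch cuts.
    z = torch.tensor([0.3 + 0.4j, -0.2 - 0.7j], dtype=torch.cdouble, requires_grad=True)
    print(gradcheck(torch.atanh, (z,)))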
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5ddce86..0e7163c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -859,9 +859,9 @@
UnaryUfuncInfo('acosh',
ref=np.arccosh,
domain=(1, float('inf')),
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
test_inplace_grad=False,
@@ -869,6 +869,16 @@
# RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
device_type='cuda', dtypes=[torch.bfloat16]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
+ # Reference: https://github.com/pytorch/pytorch/issues/50692
+ SkipInfo('TestGradients', 'test_fn_grad',
+ device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+ SkipInfo('TestGradients', 'test_method_grad',
+ device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
)),
OpInfo('addmm',
dtypes=floating_types(),
@@ -903,9 +913,9 @@
# NOTE: derivative for inplace asinh is not implemented
UnaryUfuncInfo('asinh',
ref=np.arcsinh,
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
test_inplace_grad=False,
@@ -913,6 +923,11 @@
# RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
device_type='cuda', dtypes=[torch.bfloat16]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
)),
UnaryUfuncInfo('atan',
ref=np.arctan,
@@ -933,12 +948,19 @@
UnaryUfuncInfo('atanh',
ref=np.arctanh,
domain=(-1, 1),
- dtypes=all_types_and(torch.bool),
- dtypesIfCPU=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool),
+ dtypesIfCPU=all_types_and_complex_and(torch.bool),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
promotes_integers_to_float=True,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
- test_inplace_grad=False),
+ test_inplace_grad=False,
+ skips=(
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+ SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+ device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+ active_if=IS_WINDOWS),
+ )),
OpInfo('broadcast_to',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_tensor_out=False,
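
The OpInfo updates widen the advertised dtypes via `all_types_and_complex_and` and record the skips these ops currently need: the complex reference-numerics comparison on CPU (and on Windows CUDA builds), plus Windows-only complex gradient checks tracked in https://github.com/pytorch/pytorch/issues/50692. A hypothetical dtype-coverage sketch (not part of the patch) mirroring what the widened entries advertise, including the `promotes_integers_to_float` behavior:

    import torch

    print(torch.asinh(torch.tensor([1, 2, 3])).dtype)                        # torch.float32 (integers promote to float)
    print(torch.asinh(torch.tensor([1 + 1j], dtype=torch.cfloat)).dtype)     # torch.complex64
    print(torch.acosh(torch.tensor([2 + 0.5j], dtype=torch.cdouble)).dtype)  # torch.complex128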