Add support for torch.complex in functorch (#96032)

Registers a batch rule for `aten::complex` via the `BINARY_POINTWISE` macro in `BatchRulesBinaryOps.cpp`, and removes the now-passing `xfail('complex')` and `xfail('polar')` expected-failure entries from the functorch operator tests.

Fixes #91175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96032
Approved by: https://github.com/Skylion007, https://github.com/kshitij12345, https://github.com/zou3519
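
For context, a minimal usage sketch of what this change enables. It is not taken from the PR's test changes; the shapes and variable names below are illustrative. With a batch rule registered for `aten::complex`, `vmap` can batch `torch.complex` through a dedicated rule rather than relying on the generic vmap fallback. A rough conceptual sketch of what a binary pointwise batch rule does follows the diff below.

```python
import torch
from torch.func import vmap  # functorch's vmap, exposed as torch.func.vmap in recent releases

# Illustrative batch of real and imaginary parts (shapes are assumptions).
real = torch.randn(8, 3)
imag = torch.randn(8, 3)

# With BINARY_POINTWISE(complex) registered, vmap maps torch.complex over
# dim 0 of both inputs and returns a batched complex tensor.
out = vmap(torch.complex)(real, imag)
assert out.shape == (8, 3) and out.is_complex()
```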
diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
index 5a00f7d..706e035 100644
--- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
+++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
@@ -384,6 +384,7 @@
BINARY_POINTWISE2(clamp_max, Tensor);
UNARY_POINTWISE(clamp_max);
POINTWISE_BOXED(clamp_max_);
+ BINARY_POINTWISE(complex);
VARIADIC_BDIMS_BOXED(_euclidean_dist);
// Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars,
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 38fc695..7fab7d1 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1049,7 +1049,6 @@
xfail('cumprod'),
xfail('masked_fill'),
xfail('copysign'),
- xfail('complex'),
xfail('fill'),
skip('masked.mean'), # ???
xfail('masked_scatter'),
@@ -1066,7 +1065,6 @@
xfail('special.log_ndtr', ''),
xfail('fft.ihfft2'), # conj_physical fallback
xfail('fft.ihfftn'), # conj_physical fallback
- xfail('polar'), # complex fallback
xfail('nn.functional.max_unpool3d', 'grad'),
xfail('nn.functional.smooth_l1_loss', ''),
xfail('nn.functional.max_unpool2d', 'grad'),
@@ -1117,7 +1115,6 @@
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('view_as_complex'),
- xfail('complex'),
xfail('copysign'),
xfail('cummax'),
xfail('cummin'),
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 2ce55fb..ccee632 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3601,7 +3601,6 @@
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
xfail('as_strided', 'partial_views'),
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
- xfail('complex'),
xfail('copysign'),
xfail('fill'),
# Batch norm got a batched tensor as input while the running_mean or running_var,
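
For readers unfamiliar with the `BINARY_POINTWISE` macro used in the C++ hunk above: conceptually, a binary pointwise batch rule moves each operand's vmapped dimension to the front and then calls the underlying op once on the physically batched tensors. The Python sketch below is a rough, hypothetical illustration of that idea, not the actual C++ implementation; the helper names are made up.

```python
import torch

def _move_bdim_to_front(x, bdim):
    # Hypothetical helper: put the vmapped dim first; None means "not batched".
    return x if bdim is None else x.movedim(bdim, 0)

def binary_pointwise_batch_rule(op, x, x_bdim, y, y_bdim):
    # Rough sketch of the idea behind a binary pointwise batch rule:
    # align each operand's batch dim at the front, insert a size-1 leading dim
    # for an unbatched operand (relying on the op's broadcasting), then call
    # the op once. The output's batch dimension is reported as dim 0.
    x = _move_bdim_to_front(x, x_bdim)
    y = _move_bdim_to_front(y, y_bdim)
    if x_bdim is None:
        x = x.unsqueeze(0)
    if y_bdim is None:
        y = y.unsqueeze(0)
    return op(x, y), 0

# Example: both operands are batched, along different dims.
real = torch.randn(4, 5)  # batch dim 0
imag = torch.randn(5, 4)  # batch dim 1
out, out_bdim = binary_pointwise_batch_rule(torch.complex, real, 0, imag, 1)
assert out.shape == (4, 5) and out_bdim == 0 and out.is_complex()
```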