[inductor] Fix angle decomposition return type (#115700)
The current decomposition always returns float32 when the input isn't complex.
Instead, we should do proper type promotion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115700
Approved by: https://github.com/lezcano
ghstack dependencies: #115677, #115699
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index a83edb1..3f9d5e0 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -20,7 +20,11 @@
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._higher_order_ops.out_dtype import out_dtype
-from torch._prims_common import type_to_dtype
+from torch._prims_common import (
+ elementwise_dtypes,
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
+ type_to_dtype,
+)
from . import config, inductor_prims
@@ -260,14 +264,18 @@
return torch.where(
torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
)
- else:
- # when x is real number
- # if x >= 0, return 0
- # if x < 0, return pi
- # if x is nan, return nan
- ret = torch.where(x < 0, math.pi, 0.0)
- nan = torch.where(torch.isnan(x), float("nan"), 0.0)
- return ret + nan
+
+ # when x is real number
+ # if x >= 0, return 0
+ # if x < 0, return pi
+ # if x is nan, return nan
+ _, dtype = elementwise_dtypes(
+ x,
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ )
+ pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
+ ret = torch.where(x < 0, pi, 0.0)
+ return torch.where(torch.isnan(x), float("nan"), ret)
@register_decomposition([aten.add])