[chalf] update type promotion table (#76893)
Reference #74537
TODO:
* [x] Add tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76893
Approved by: https://github.com/anjali411
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index eb4ffe3..0728e67 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -416,28 +416,28 @@
toString(b));
}
- // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
- // so that's why we have to add undefined as we are not sure what is the
- // corrent values for the type promotions in complex type cases.
+ // this matrix has to be consistent with
+ // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we
+ // are not sure about the correct value for type promotion.
static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
- /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
- /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
- /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
- /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
- /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
- /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
- /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
- /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
- /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
+ /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
+ /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
+ /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
+ /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
+ /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
+ /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
+ /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
+ /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
+ /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
- /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
+ /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
- /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
+ /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index 26d12c1..8c82b43 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -11,7 +11,7 @@
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
dtypes, onlyCPU, expectedFailureMeta, skipMeta)
from torch.testing._internal.common_dtype import (
- all_types_and_complex_and, get_all_math_dtypes, floating_types
+ all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes
)
import numpy as np
@@ -189,6 +189,8 @@
if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble):
# Handles bfloat16 x float16 -> float32 promotion
expected_dtype = dtype if dtype != torch.half else torch.float32
+ elif dtype is torch.chalf:
+ expected_dtype = torch.cfloat
elif dtype in (torch.bool, torch.uint8,
torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16):
expected_dtype = torch.bfloat16
@@ -199,6 +201,39 @@
self.assertEqual(torch.promote_types(torch.bfloat16, dtype), expected_dtype)
self.assertEqual((bf + t).dtype, expected_dtype)
+ @onlyNativeDeviceTypes
+ def test_complex_half(self, device):
+ # with scalar
+ chalf = torch.tensor(5.5, dtype=torch.chalf, device=device)
+ for scalar in (2.2, 5, 100000): # chalf + 100000 is inf
+ self.assertEqual((chalf * scalar).dtype, torch.chalf)
+ self.assertEqual(scalar * chalf, chalf * scalar)
+
+ for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)):
+ self.assertEqual((chalf * scalar).dtype, torch.chalf)
+ self.assertEqual(chalf * scalar, scalar * chalf)
+
+ # with tensor
+ dtypes = all_types_and_complex_and(torch.chalf, torch.half, torch.bfloat16, torch.bool)
+ for dtype in dtypes:
+ t = torch.tensor(1, dtype=dtype, device=device)
+ self.assertEqual(chalf * t, t * chalf)
+ if dtype in (torch.float16, torch.chalf):
+ expected_dtype = torch.chalf
+ elif dtype in (torch.float, torch.double, torch.bfloat16):
+ expected_dtype = torch.cdouble if dtype is torch.double else torch.cfloat
+ elif dtype in (torch.cfloat, torch.cdouble):
+ expected_dtype = dtype
+ elif dtype in (torch.bool, torch.uint8,
+ torch.int8, torch.int16, torch.int32, torch.int64):
+ expected_dtype = torch.chalf
+ else:
+ raise AssertionError(f'Missing dtype {dtype} not tested.')
+
+ self.assertEqual(torch.promote_types(dtype, torch.chalf), expected_dtype)
+ self.assertEqual(torch.promote_types(torch.chalf, dtype), expected_dtype)
+ self.assertEqual((chalf * t).dtype, expected_dtype)
+
@float_double_default_dtype
def test_alternate_result(self, device):
f = torch.tensor([1, 1, 1, 1], dtype=torch.float, device=device)
@@ -520,12 +555,16 @@
dict(name="ne", compare_op=lambda x, y: x != y, ),
]
for op in comparison_ops:
- for dt1 in get_all_math_dtypes(device):
- for dt2 in get_all_math_dtypes(device):
- if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
- u = torch.tensor([1], dtype=dt1, device=device)
- v = torch.tensor([2], dtype=dt2, device=device)
- self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool))
+ is_cuda = torch.device(device).type == 'cuda'
+ dtypes = get_all_dtypes(include_half=is_cuda,
+ include_bfloat16=False, include_bool=False,
+ include_complex32=True)
+
+ for dt1, dt2 in itertools.product(dtypes, dtypes):
+ if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
+ u = torch.tensor([1], dtype=dt1, device=device)
+ v = torch.tensor([2], dtype=dt2, device=device)
+ self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool))
@float_double_default_dtype
def test_lt_with_type_promotion(self, device):
@@ -562,7 +601,7 @@
@float_double_default_dtype
def test_promote_self(self, device):
- for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+ for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf, torch.bool):
self.assertEqual(torch.promote_types(dtype, dtype), dtype)
@expectedFailureMeta
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 9dd7795..4fabb6a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18378,9 +18378,6 @@
# torch function issue:
# ValueError: Callable cat has no meta function!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions'),
- # eager torch.cat incorrectly throws an error for a chalf x double x half type promotion case
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_consistency',
- dtypes=(torch.chalf,)),
)
),
PythonRefInfo(