[chalf] update type promotion table (#76893)

Reference #74537

TODO:
* [x] Add tests
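
With this change, `chalf` promotes with the other non-quantized dtypes instead of hitting an undefined entry in the lookup table. A rough sketch of the resulting behaviour, with results read off the updated table (`torch.chalf` is an alias for `torch.complex32`):

```python
import torch

# Boolean, integer, and half inputs now promote to chalf rather than being undefined.
torch.promote_types(torch.chalf, torch.bool)      # torch.complex32
torch.promote_types(torch.chalf, torch.int64)     # torch.complex32
torch.promote_types(torch.chalf, torch.half)      # torch.complex32

# Wider real floating types pull the result up to a wider complex type.
torch.promote_types(torch.chalf, torch.bfloat16)  # torch.complex64
torch.promote_types(torch.chalf, torch.float)     # torch.complex64
torch.promote_types(torch.chalf, torch.double)    # torch.complex128

# Binary ops follow the same rules.
a = torch.zeros(3, dtype=torch.chalf)
b = torch.zeros(3, dtype=torch.float)
(a * b).dtype                                     # torch.complex64
```
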
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76893
Approved by: https://github.com/anjali411
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index eb4ffe3..0728e67 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -416,28 +416,28 @@
         toString(b));
   }
 
-  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
-  // so that's why we have to add undefined as we are not sure what is the
-  // corrent values for the type promotions in complex type cases.
+  // this matrix has to be consistent with
+  // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS; undefined is used where we
+  // are not sure about the correct value for type promotion.
   static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
       ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
       /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
-      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
-      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
-      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
-      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
-      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
-      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
-      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
-      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
-      /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
+      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
+      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
+      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
+      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
+      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
+      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
+      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
+      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
+      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
       /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
       /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
-      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
+      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
       /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
       /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
       /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
+      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
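
As a quick cross-check (not part of this patch, just a sketch): the non-quantized entries of the table above are exposed through `torch.promote_types`, which should be symmetric for every pair it defines and be the identity when a dtype is promoted with itself.

```python
import itertools
import torch

# Non-quantized dtypes covered by the table above (q1/q2/q3 stay undefined).
dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
          torch.half, torch.float, torch.double,
          torch.chalf, torch.cfloat, torch.cdouble,
          torch.bool, torch.bfloat16]

for a, b in itertools.product(dtypes, dtypes):
    # Promotion is commutative, and promoting a dtype with itself is a no-op.
    assert torch.promote_types(a, b) == torch.promote_types(b, a)
    assert torch.promote_types(a, a) == a
```
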
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index 26d12c1..8c82b43 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -11,7 +11,7 @@
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
                                                         dtypes, onlyCPU, expectedFailureMeta, skipMeta)
 from torch.testing._internal.common_dtype import (
-    all_types_and_complex_and, get_all_math_dtypes, floating_types
+    all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes
 )
 
 import numpy as np
@@ -189,6 +189,8 @@
             if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble):
                 # Handles bfloat16 x float16 -> float32 promotion
                 expected_dtype = dtype if dtype != torch.half else torch.float32
+            elif dtype is torch.chalf:
+                expected_dtype = torch.cfloat
             elif dtype in (torch.bool, torch.uint8,
                            torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16):
                 expected_dtype = torch.bfloat16
@@ -199,6 +201,39 @@
             self.assertEqual(torch.promote_types(torch.bfloat16, dtype), expected_dtype)
             self.assertEqual((bf + t).dtype, expected_dtype)
 
+    @onlyNativeDeviceTypes
+    def test_complex_half(self, device):
+        # with scalar
+        chalf = torch.tensor(5.5, dtype=torch.chalf, device=device)
+        for scalar in (2.2, 5, 100000):   # chalf * 100000 overflows to inf
+            self.assertEqual((chalf * scalar).dtype, torch.chalf)
+            self.assertEqual(scalar * chalf, chalf * scalar)
+
+        for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)):
+            self.assertEqual((chalf * scalar).dtype, torch.chalf)
+            self.assertEqual(chalf * scalar, scalar * chalf)
+
+        # with tensor
+        dtypes = all_types_and_complex_and(torch.chalf, torch.half, torch.bfloat16, torch.bool)
+        for dtype in dtypes:
+            t = torch.tensor(1, dtype=dtype, device=device)
+            self.assertEqual(chalf * t, t * chalf)
+            if dtype in (torch.float16, torch.chalf):
+                expected_dtype = torch.chalf
+            elif dtype in (torch.float, torch.double, torch.bfloat16):
+                expected_dtype = torch.cdouble if dtype is torch.double else torch.cfloat
+            elif dtype in (torch.cfloat, torch.cdouble):
+                expected_dtype = dtype
+            elif dtype in (torch.bool, torch.uint8,
+                           torch.int8, torch.int16, torch.int32, torch.int64):
+                expected_dtype = torch.chalf
+            else:
+                raise AssertionError(f'dtype {dtype} is not covered by this test.')
+
+            self.assertEqual(torch.promote_types(dtype, torch.chalf), expected_dtype)
+            self.assertEqual(torch.promote_types(torch.chalf, dtype), expected_dtype)
+            self.assertEqual((chalf * t).dtype, expected_dtype)
+
     @float_double_default_dtype
     def test_alternate_result(self, device):
         f = torch.tensor([1, 1, 1, 1], dtype=torch.float, device=device)
@@ -520,12 +555,16 @@
             dict(name="ne", compare_op=lambda x, y: x != y, ),
         ]
         for op in comparison_ops:
-            for dt1 in get_all_math_dtypes(device):
-                for dt2 in get_all_math_dtypes(device):
-                    if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
-                        u = torch.tensor([1], dtype=dt1, device=device)
-                        v = torch.tensor([2], dtype=dt2, device=device)
-                        self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool))
+            is_cuda = torch.device(device).type == 'cuda'
+            dtypes = get_all_dtypes(include_half=is_cuda,
+                                    include_bfloat16=False, include_bool=False,
+                                    include_complex32=True)
+
+            for dt1, dt2 in itertools.product(dtypes, dtypes):
+                if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
+                    u = torch.tensor([1], dtype=dt1, device=device)
+                    v = torch.tensor([2], dtype=dt2, device=device)
+                    self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool))
 
     @float_double_default_dtype
     def test_lt_with_type_promotion(self, device):
@@ -562,7 +601,7 @@
 
     @float_double_default_dtype
     def test_promote_self(self, device):
-        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf, torch.bool):
             self.assertEqual(torch.promote_types(dtype, dtype), dtype)
 
     @expectedFailureMeta
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 9dd7795..4fabb6a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18378,9 +18378,6 @@
             # torch function issue:
             # ValueError: Callable cat has no meta function!
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions'),
-            # eager torch.cat incorrectly throws an error for a chalf x double x half type promotion case
-            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_consistency',
-                         dtypes=(torch.chalf,)),
         )
     ),
     PythonRefInfo(