[primTorch] Adds any, all, equal, item references (#78072)

This PR adds the `item`, `equal`, `any`, and `all` references.

While doing this I found the following issues:
- https://github.com/pytorch/pytorch/issues/78070
- https://github.com/pytorch/pytorch/issues/78071

It also fixes a bug where the `convert_element_type` prim could not convert tensors that require grad to dtypes that cannot require gradients (such as integer dtypes).
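
For reference, here is a minimal sketch of the intended behavior after the fix. This is illustrative only and not part of this PR's tests; it assumes `utils.is_grad_dtype` and the `convert_element_type` prim can be called directly in eager mode:

```python
import torch
import torch._prims as prims
from torch._prims import utils

# Only floating point and complex dtypes can carry gradients
assert utils.is_grad_dtype(torch.float32)
assert utils.is_grad_dtype(torch.complex64)
assert not utils.is_grad_dtype(torch.int64)

# With the fix, converting a tensor that requires grad to an integer dtype
# no longer tries (and fails) to propagate requires_grad
a = torch.randn(3, requires_grad=True)
b = prims.convert_element_type(a, torch.int64)
assert b.requires_grad is False
```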

Creating the `item` reference required adding `item` as a prim, but per @ngimel's suggestion I removed the `any` and `all` prims and implemented them as references instead, so this PR removes one prim on net.
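
A rough sketch of how the `item` reference behaves, including the explicit bool conversion tracked in the issue above (illustrative, not taken from this PR's tests):

```python
import torch
import torch._refs as refs

# refs.item mirrors Tensor.item, but explicitly converts the result to the
# Python type matching the tensor's dtype, so bool tensors yield a bool
assert type(refs.item(torch.tensor(True))) is bool
assert type(refs.item(torch.tensor(2.5))) is float

# Tensors with more than one element are rejected
try:
    refs.item(torch.ones(2))
except ValueError as e:
    print(e)  # Can't convert a tensor with 2 elements to a number!
```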

Reference OpInfos are added for `any` and `all`; `item` and `equal` do not yet have regular OpInfos, so their reference OpInfos are left as TODOs.
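
And a short sketch of the new references used together (again illustrative, not copied from the new OpInfo tests):

```python
import torch
import torch._refs as refs

a = torch.tensor([[True, False], [True, True]])

# any and all are now written as references on top of sum and comparison prims
assert refs.equal(refs.all(a, dim=1), torch.tensor([False, True]))
assert refs.equal(refs.any(a, dim=0), torch.tensor([True, True]))

# equal short-circuits on shape mismatches and on empty tensors
assert not refs.equal(torch.zeros(2), torch.zeros(3))
assert refs.equal(torch.empty(0), torch.empty(0))
```
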
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78072
Approved by: https://github.com/ngimel
diff --git a/test/test_ops.py b/test/test_ops.py
index c8de562..8578293 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -452,8 +452,8 @@
                 torch_distance = torch_distance + _distance(a, b)
 
             # TODO: consider adding some tolerance to this comparison
-            msg = f"Reference result was farther ({ref_distance}) from the precise \
-                    computation than the torch result was ({torch_distance})!"
+            msg = f"Reference result was farther ({ref_distance}) from the precise " \
+                  f"computation than the torch result was ({torch_distance})!"
             self.assertTrue(ref_distance <= torch_distance, msg=msg)
 
         # Reports numerical accuracy discrepancies
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 0b47d48..6c53b82 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -134,6 +134,7 @@
     "clone",
     "convert_element_type",
     "device_put",
+    "item",
     "to_dtype",
     #
     # Inplace prims
@@ -144,10 +145,8 @@
     #
     # Reduction prims
     #
-    "all",
     "amax",
     "amin",
-    "any",
     "prod",
     "sum",
     "var",
@@ -1709,11 +1708,16 @@
 
 
 def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
-    # TODO: update meta objects so this can be acquired directly
-    try:
-        requires_grad = a.requires_grad
-    except Exception as e:
+
+    # Propagates requires grad when possible
+    if not utils.is_grad_dtype(dtype):
         requires_grad = False
+    else:
+        # TODO: update meta objects so this can be acquired directly
+        try:
+            requires_grad = a.requires_grad
+        except Exception as e:
+            requires_grad = False
 
     result = torch.empty_like(
         a, device=a.device, dtype=dtype, requires_grad=requires_grad
@@ -1766,6 +1770,28 @@
     doc=_device_put_doc,
 )
 
+# NOTE: need to model meta scalars
+# See https://github.com/pytorch/pytorch/issues/78070
+def _item_meta(a: TensorLikeType) -> TensorMeta:
+    number_type = utils.dtype_to_type(a.dtype)
+    return TensorMeta(number_type(-1))
+
+
+_item_doc = """
+    Converts a tensor with one element to a Python number.
+"""
+
+# TODO: create a new return type for scalars?
+# FIXME: currently returns integers for boolean tensors
+# https://github.com/pytorch/pytorch/issues/78071
+item = _make_prim(
+    schema="item(Tensor a) -> Scalar",
+    meta=_item_meta,
+    impl_aten=torch.Tensor.item,
+    return_type=RETURN_TYPE.NEW,
+    doc=_item_doc,
+)
+
 # TODO: FIXME: strides are incorrect
 def _to_dtype_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
     strides = utils.make_contiguous_strides_for(a.shape)
@@ -1875,10 +1901,6 @@
     )
 
 
-def _bool_return_reduction_meta(inp, dims):
-    return _reduction_meta(inp, dims, output_dtype=torch.bool)
-
-
 def _var_reduction_meta(inp, dims, *, correction):
     if utils.is_complex_dtype(inp.dtype):
         output_dtype = utils.corresponding_real_dtype(inp.dtype)
@@ -1930,17 +1952,6 @@
     )
 
 
-def _make_bool_reduction_prim(name: str, impl_aten, doc):
-    """Creates a reduction prim that reduces to bool."""
-    return _make_prim(
-        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
-        meta=_bool_return_reduction_meta,
-        impl_aten=impl_aten,
-        return_type=RETURN_TYPE.NEW,
-        doc=doc,
-    )
-
-
 sum = _make_reduction_prim(
     name="sum",
     impl_aten=torch.sum,
@@ -1971,18 +1982,6 @@
     doc=_amin_doc,
 )
 
-all = _make_bool_reduction_prim(
-    name="all",
-    impl_aten=torch.all,
-    doc="",
-)
-
-any = _make_bool_reduction_prim(
-    name="any",
-    impl_aten=torch.any,
-    doc="",
-)
-
 # TODO: layout, pin_memory, memory_format
 # TODO: model requires_grad on TensorMeta
 def _empty_meta(
diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py
index f870610..3e9bb6b 100644
--- a/torch/_prims/utils.py
+++ b/torch/_prims/utils.py
@@ -609,6 +609,13 @@
     return dtype in _complex_dtypes
 
 
+def is_grad_dtype(dtype: torch.dtype) -> bool:
+    """
+    Checks if the dtype can require a gradient.
+    """
+    return is_float_dtype(dtype) or is_complex_dtype(dtype)
+
+
 _complex_to_real_dtype_map = {
     torch.complex128: torch.float64,
     torch.complex64: torch.float32,
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index d360fde..cb2a43f 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -136,22 +136,25 @@
     #
     # Conditional references
     #
-    "where",  # TODO: add opinfo
+    "where",
     #
     # Data conversion and movement references
     #
     "clone",
-    "copy_to",  # TODO: add opinfo
+    "copy_to",  # TODO: add OpInfo (or implement .to)
+    "item",  # TODO: add OpInfo
     #
     # Reduction ops
     #
-    "sum",
+    "all",
     "amax",
     "amin",
-    "var",
+    "any",
     "mean",
     "std_mean",
     "std_var",
+    "sum",
+    "var",
     #
     # View & Shape Ops
     #
@@ -188,7 +191,11 @@
     #
     # Randomness References
     #
-    "uniform",
+    "uniform",  # TODO: add OpInfo -- and testing for randomness?
+    #
+    # Test-related functions
+    #
+    "equal",  # TODO: add OpInfo
 ]
 
 Tensor = torch.Tensor
@@ -1006,6 +1013,17 @@
     return prims.copy_to(a, b)
 
 
+def item(a: TensorLikeType) -> NumberType:
+    if a.numel() != 1:
+        msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
+        raise ValueError(msg)
+
+    # NOTE: explicit conversion is necessary for bool!
+    # See https://github.com/pytorch/pytorch/issues/78071
+    number_type = utils.dtype_to_type(a.dtype)
+    return number_type(prims.item(a))
+
+
 #
 # Reduction references
 #
@@ -1046,7 +1064,7 @@
         dims = (dims,)  # type: ignore[assignment]
     dims = utils.reduction_dims(a.shape, dims)
     if not has_identity:
-        valid_shape = a.ndim == 0 or all(a.shape[i] for i in dims)
+        valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
         if not valid_shape:
             raise RuntimeError(
                 "reducing over zero-size dimension for reduction operation without identity"
@@ -1076,6 +1094,48 @@
     return result
 
 
+# Saves the Python builtin all, which the all reference below shadows
+py_all = all
+
+
+@out_wrapper
+def all(
+    a: TensorLikeType,
+    dim: Optional[DimsType] = None,
+    keepdim: bool = False,
+) -> TensorLikeType:
+    # Computes nelem
+    if isinstance(dim, int):
+        dim = (dim,)  # type: ignore[assignment]
+    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
+    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
+
+    a_ = _maybe_convert_to_dtype(a, torch.bool)
+    result = eq(sum(a_, dim=dim, keepdim=keepdim), nelem)  # type: ignore[arg-type]
+
+    # Preserves uint8 -- probably a legacy mask thing
+    if a.dtype is torch.uint8:
+        return prims.convert_element_type(result, torch.uint8)
+
+    return result
+
+
+@out_wrapper
+def any(
+    a: TensorLikeType,
+    dim: Optional[DimsType] = None,
+    keepdim: bool = False,
+) -> TensorLikeType:
+    a_ = _maybe_convert_to_dtype(a, torch.bool)
+    result = ne(sum(a_, dim=dim, keepdim=keepdim), False)  # type: ignore[arg-type]
+
+    # Preserves uint8 -- probably a legacy mask thing
+    if a.dtype is torch.uint8:
+        return prims.convert_element_type(result, torch.uint8)
+
+    return result
+
+
 # TODO: register decomp after stride logic is fixed
 def sum(
     a: TensorLikeType,
@@ -1789,3 +1849,23 @@
     device = utils.canonicalize_device(device)
 
     return prims.uniform(shape, low=low, high=high, dtype=dtype, device=device)
+
+
+# TODO: add OpInfo for torch.equal and refs.equal
+def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
+    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
+    utils.check_same_dtype(a, b)
+
+    # Shape check
+    if a.ndim != b.ndim:
+        return False
+
+    for x, y in zip(a.shape, b.shape):
+        if x != y:
+            return False
+
+    # Short-circuits if there are no elements to validate
+    if a.numel() == 0:
+        return True
+
+    return item(all(eq(a, b)))  # type: ignore[return-value]
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 758784b..cc910f4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19269,27 +19269,25 @@
     # Reduction Reference OpInfos
     #
     ReductionPythonRefInfo(
-        "_refs.sum",
-        torch_opinfo_name="sum",
-        supports_out=True,
-    ),
-    ReductionPythonRefInfo(
-        "_refs.mean",
-        torch_opinfo_name="mean",
-        supports_out=True,
-    ),
-    ReductionPythonRefInfo(
-        "_refs.amin",
-        torch_opinfo_name="amin",
+        "_refs.all",
+        torch_opinfo_name="all",
     ),
     ReductionPythonRefInfo(
         "_refs.amax",
         torch_opinfo_name="amax",
     ),
     ReductionPythonRefInfo(
-        "_refs.var",
-        torch_opinfo_name="var",
-        supports_out=True
+        "_refs.amin",
+        torch_opinfo_name="amin",
+    ),
+    ReductionPythonRefInfo(
+        "_refs.any",
+        torch_opinfo_name="any",
+    ),
+    ReductionPythonRefInfo(
+        "_refs.mean",
+        torch_opinfo_name="mean",
+        supports_out=True,
     ),
     ReductionPythonRefInfo(
         "_refs.std",
@@ -19302,6 +19300,16 @@
         torch_opinfo_name="std_mean",
         validate_view_consistency=False
     ),
+    ReductionPythonRefInfo(
+        "_refs.sum",
+        torch_opinfo_name="sum",
+        supports_out=True,
+    ),
+    ReductionPythonRefInfo(
+        "_refs.var",
+        torch_opinfo_name="var",
+        supports_out=True
+    ),
     PythonRefInfo(
         "_refs.var_mean",
         torch_opinfo_name="var_mean",