Fix handling of device (#78615)

Removes an unnecessary auxiliary function (we had already implemented
it), uses DeviceLikeType to denote str or dtype, and adds `is_cpu` and
`is_cuda` helper functions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78615
Approved by: https://github.com/mruberry
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 6609ff3..137116e 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -1956,7 +1956,7 @@
 
 
 _clone_doc = """
-    Creates a copy of a tensors.
+    Creates a copy of a tensor.
 """
 
 clone = _make_prim(
@@ -2022,7 +2022,7 @@
     assert isinstance(a, TensorLike)
     assert isinstance(device, (str, torch.device))
 
-    return TensorMeta(a, device=utils.wrap_device(device))
+    return TensorMeta(a, device=utils.canonicalize_device(device))
 
 
 def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py
index a097fc2..83f972b 100644
--- a/torch/_prims/utils.py
+++ b/torch/_prims/utils.py
@@ -477,7 +477,7 @@
             raise RuntimeError(msg)
 
 
-def canonicalize_device(device: Union[str, torch.device]) -> torch.device:
+def canonicalize_device(device: DeviceLikeType) -> torch.device:
     if isinstance(device, torch.device):
         return device
 
@@ -1099,20 +1099,6 @@
     return computation_dtype, result_dtype
 
 
-def wrap_device(d: Union[str, torch.device]) -> torch.device:
-    """
-    Wraps strings into torch.device objects.
-
-    Given torch.device objects are returned unmodified.
-    """
-
-    assert isinstance(d, (str, torch.device))
-    if isinstance(d, str):
-        return torch.device(d)
-
-    return d
-
-
 def make_contiguous_strides_for(shape: ShapeType) -> Tuple[int, ...]:
     validate_shape(shape)
     if not shape:
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 90facc7..f3415db 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2556,7 +2556,7 @@
 
     # If indices_or_sections is a tensor, it must be a CPU Long tensor
     if isinstance(indices_or_sections, TensorLike):
-        if indices_or_sections.device != torch.device("cpu"):
+        if not indices_or_sections.device.type == "cpu":
             msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
                 indices_or_sections.device
             )