Replace `_prims_common.check` with `torch._check*` (#103240)

This relands most of the changes from #102219, which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a deprecation warning and a comment noting that the function will be removed in the future and that `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot be removed yet because of some internal usage.
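
Roughly, the deprecation shim looks like the following. This is a minimal sketch, not the landed code; the exact message and docstring in `torch/_prims_common/__init__.py` may differ. It delegates to `torch._check_with`, the internal helper whose annotations are tightened in the `torch/__init__.py` hunk below:

```python
import warnings
from typing import Callable, Type

import torch


def check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    # Deprecated: emit a DeprecationWarning, then delegate to the same
    # machinery that backs torch._check*, so behavior is unchanged.
    warnings.warn(
        DeprecationWarning(
            "'torch._prims_common.check' will be removed in the future. "
            "Please use 'torch._check*' functions instead"
        )
    )
    torch._check_with(exc_type, b, s)
```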

Part of #72948
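
For reference, the mechanical replacement applied throughout the diff below maps the old `exc_type` argument onto the `torch._check*` family: plain `torch._check` raises `RuntimeError`, while variants like `torch._check_index` encode the error type in the name. A minimal illustration (not code from this PR):

```python
import torch

x = torch.randn(3, 4)
i = 2

# Before: utils.check(cond, lambda: msg)             -> RuntimeError on failure
torch._check(x.dim() == 2, lambda: f"expected a 2D tensor, got {x.dim()}D")

# Before: utils.check(cond, lambda: msg, IndexError) -> IndexError on failure
torch._check_index(
    -x.size(0) <= i < x.size(0),
    lambda: f"index {i} out of range for tensor of size {x.size()}",
)
```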

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
diff --git a/test/test_prims.py b/test/test_prims.py
index da1f5a1..ad1e63f 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -1161,6 +1161,10 @@
     def test_mul_complex(self):
         prims.mul(torch.randn(2), 1 + 1j)
 
+    def test_check_deprecation_warning(self):
+        with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
+            torch._prims_common.check(True, lambda: 'message')
+
 
 instantiate_device_type_tests(TestPrims, globals())
 
diff --git a/torch/__init__.py b/torch/__init__.py
index 148336c..b4c45f7 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -936,7 +936,7 @@
 # These error checking functions must be kept consistent with their C++
 # equivalents. Their C++ equivalents are mentioned where applicable.
 
-def _check_with(error_type, cond, message):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
     if not isinstance(cond, (builtins.bool, torch.SymBool)):
         raise TypeError(f'cond must be a bool, but got {type(cond)}')
 
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 0c307c8..8715b1f 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -149,7 +149,7 @@
 
 @register_decomposition([aten.fill.Tensor])
 def fill_tensor(self, value: Tensor):
-    utils.check(
+    torch._check(
         value.dim() == 0,
         lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
     )
@@ -785,14 +785,14 @@
     padding: List[int],
     stride: List[int],
 ) -> Tensor:
-    utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
-    utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
-    utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
-    utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
+    torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
+    torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
+    torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
+    torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
 
     def check_positive(param, param_name, strict=True):
         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
-        utils.check(
+        torch._check(
             cond, lambda: f"{param_name} should be greater {'than' if strict else 'than or equal to'} zero, but got {param}"
         )
 
@@ -803,7 +803,7 @@
 
     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
         lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
         f"and non-zero dimensions, but got: {tuple(shape)}",
@@ -814,7 +814,7 @@
             shape[-2:], padding, dilation, kernel_size, stride
         )
     )
-    utils.check(
+    torch._check(
         all(c > 0 for c in output_size),
         lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
         f"kernel_size={kernel_size}, dilation={dilation}, "
@@ -869,15 +869,15 @@
     padding: List[int],
     stride: List[int],
 ) -> Tensor:
-    utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
-    utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
-    utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
-    utils.check(len(padding) == 2, lambda: "only 2D padding supported")
-    utils.check(len(stride) == 2, lambda: "only 2D stride supported")
+    torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
+    torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
+    torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
+    torch._check(len(padding) == 2, lambda: "only 2D padding supported")
+    torch._check(len(stride) == 2, lambda: "only 2D stride supported")
 
     def check_positive(param, param_name, strict=True):
         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
-        utils.check(
+        torch._check(
             cond, lambda: f"{param_name} should be greater than zero, but got {param}"
         )
 
@@ -889,13 +889,13 @@
 
     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
         lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
         f"and non-zero dimensions, but got: {tuple(shape)}",
     )
     prod_kernel_size = kernel_size[0] * kernel_size[1]
-    utils.check(
+    torch._check(
         shape[-2] % prod_kernel_size == 0,
         lambda: "Expected size of input's first non-batch dimension to be divisible by the "
         f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
@@ -908,13 +908,13 @@
         )
     ]
     L = col[0] * col[1]
-    utils.check(
+    torch._check(
         shape[-1] == L,
         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
         f"dilation={dilation}, padding={padding}, stride={stride}, "
         f"expected input.size(-1) to be {L} but got {shape[-1]}.",
     )
-    utils.check(
+    torch._check(
         L > 0,
         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
         f"dilation={dilation}, padding={padding}, stride={stride}, "
@@ -961,7 +961,7 @@
 def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
     # According to the CUDA kernel implementation we should have this test;
     # but it seems to fail tests!
-    # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
+    # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
 
     # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
     # This is different from TensorIterator's behavior
@@ -1221,21 +1221,21 @@
     )
     utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
     utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
-    utils.check(
+    torch._check(
         input.numel() == N * C * HxW,
         lambda: f"Expect input to have { N * C * HxW} elements",
     )
-    utils.check(
+    torch._check(
         mean.shape == (N, group),
         lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
     )
-    utils.check(
+    torch._check(
         gamma is None or gamma.numel() == C,
         lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
     )
 
     cpg, _rem = divmod(C, group)
-    utils.check(
+    torch._check(
         _rem == 0,
         lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
     )
@@ -1834,12 +1834,12 @@
     device = input.device
     shape = input.shape
     ndim = len(shape)
-    utils.check(
+    torch._check(
         ndim in (3, 4),
         lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
     )
     for d in input.shape[-2:]:
-        utils.check(
+        torch._check(
             d != 0,
             lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
             f"non-batch dimensions, but input has shape {tuple(shape)}.",
@@ -1966,13 +1966,13 @@
     alpha: NumberType = 1,
 ):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
     if alpha != 1:
         python_type = utils.dtype_to_type(x.dtype)
-        utils.check(
+        torch._check(
             python_type == bool
             or utils.is_weakly_lesser_type(type(alpha), python_type),
             lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
@@ -2005,7 +2005,7 @@
     x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
 ):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
@@ -2060,19 +2060,19 @@
 def upsample_compute_output_size(input_size, output_size, scale_factors):
     spatial_dimensions = len(input_size) - 2
     if output_size is not None:
-        utils.check(
+        torch._check(
             scale_factors is None,
             lambda: "Must specify exactly one of output_size and scale_factors",
         )
-        utils.check(len(output_size) == spatial_dimensions, lambda: "")
+        torch._check(len(output_size) == spatial_dimensions, lambda: "")
         return output_size
     if scale_factors is not None:
         # NB: this isn't necessary lol
-        utils.check(
+        torch._check(
             output_size is None,
             lambda: "Must specify exactly one of output_size and scale_factors",
         )
-        utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
+        torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
         output_size = []
         for i, s in enumerate(scale_factors):
             if int(s) == s:
@@ -2080,7 +2080,7 @@
             else:
                 output_size.append(sym_int(input_size[i + 2] * s))
         return output_size
-    utils.check(
+    torch._check(
         False, lambda: "Must specify exactly one of output_size and scale_factors"
     )
 
@@ -2969,11 +2969,11 @@
     padding_mode: int = 0,
     align_corners: bool = False,
 ) -> Tensor:
-    utils.check(
+    torch._check(
         interpolation_mode in (0, 1, 2),
         lambda: f"Invalid interpolation mode {interpolation_mode}",
     )
-    utils.check(
+    torch._check(
         padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
     )
 
@@ -3110,11 +3110,11 @@
 @out_wrapper()
 @pw_cast_for_opmath
 def mv(self, vec):
-    utils.check(
+    torch._check(
         self.dim() == 2 and vec.dim() == 1,
         lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
     )
-    utils.check(
+    torch._check(
         self.size(1) == vec.size(0),
         lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
     )
@@ -3134,11 +3134,11 @@
         elif other.is_conj():
             return torch.vdot(other.conj(), self)
 
-    utils.check(
+    torch._check(
         self.dim() == 1 and other.dim() == 1,
         lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
     )
-    utils.check(
+    torch._check(
         self.dtype == other.dtype,
         lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
     )
@@ -3149,7 +3149,7 @@
             f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
         )
 
-    utils.check(self.numel() == other.numel(), numel_error)
+    torch._check(self.numel() == other.numel(), numel_error)
 
     return (self * other).sum()
 
@@ -3296,7 +3296,7 @@
 
         return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
     else:
-        utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
+        torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
 
 
 @register_decomposition(aten.upsample_bicubic2d.default)
@@ -3373,7 +3373,7 @@
     align_corners: bool,
     scale_factors: Optional[Tuple[float, float]] = None,
 ) -> Tensor:
-    utils.check(
+    torch._check(
         bool(output_size) + bool(scale_factors) == 1,
         lambda: "Must specify exactly one of output_size and scale_factors.",
     )
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index e74fe63..fccdad3 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -72,7 +72,6 @@
     remove_unaligned_input_idxs,
     static_input,
 )
-from torch._prims_common import check
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.storage import UntypedStorage
 from torch.utils import _pytree as pytree
@@ -1071,7 +1070,7 @@
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue
 
-            check(
+            torch._check(
                 o.is_cuda,
                 lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
             )
@@ -1447,7 +1446,7 @@
         for idx in self.cudagraph_managed_idxs:
             inputs[idx] = None
 
-        check(
+        torch._check(
             self._check_liveness(
                 self.expected_dead_indices_after_graph, self.path_weakrefs
             ),
@@ -1522,7 +1521,7 @@
 
             addr += block["size"]
 
-    check(
+    torch._check(
         len(unique_storages) == 0,
         lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
     )
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e3c4f1b..c346045 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -13,7 +13,6 @@
 from torch._ops import OpOverload
 from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
 from torch._prims_common import (
-    check,
     corresponding_complex_dtype,
     corresponding_real_dtype,
     elementwise_dtypes,
@@ -63,7 +62,7 @@
 
 def check_inplace_broadcast(self_shape, *args_shape):
     broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
-    check(
+    torch._check(
         broadcasted_shape == self_shape,
         lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
     )
@@ -73,15 +72,14 @@
 @out_wrapper()
 def meta_take(self, index):
     # Type and device checks
-    check(
+    torch._check(
         index.dtype == torch.long,
         lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
     )
     # Index checks
-    check(
+    torch._check_index(
         not (self.numel() == 0 and index.numel() != 0),
         lambda: "take(): tried to take from an empty tensor",
-        IndexError,
     )
     return self.new_empty(index.shape)
 
@@ -91,11 +89,11 @@
 def linalg_cross(self, other, *, dim=-1):
     x_d = self.ndim
     y_d = other.ndim
-    check(
+    torch._check(
         x_d == y_d,
         lambda: "linalg.cross: inputs must have the same number of dimensions.",
     )
-    check(
+    torch._check(
         self.size(dim) == 3 and other.size(dim) == 3,
         lambda: (
             f"linalg.cross: inputs dimension {dim} must have length 3. "
@@ -334,7 +332,7 @@
     A: Tensor,
     name: str,
 ):
-    check(
+    torch._check(
         self.device == A.device,
         lambda: (
             f"Expected b and A to be on the same device, but found b on "
@@ -342,7 +340,7 @@
         ),
     )
 
-    check(
+    torch._check(
         self.dtype == A.dtype,
         lambda: (
             f"Expected b and A to have the same dtype, but found b of type "
@@ -350,7 +348,7 @@
         ),
     )
 
-    check(
+    torch._check(
         A.size(-1) == A.size(-2),
         lambda: (
             f"A must be batches of square matrices, "
@@ -358,7 +356,7 @@
         ),
     )
 
-    check(
+    torch._check(
         A.size(-1) == self.size(-2),
         lambda: (
             f"Incompatible matrix sizes for {name}: each A "
@@ -373,12 +371,12 @@
     t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
 ):
     dtype = t.dtype
-    check(
+    torch._check(
         t.is_floating_point() or t.is_complex(),
         lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
     )
     if not allow_low_precision_dtypes:
-        check(
+        torch._check(
             dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
             lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
         )
@@ -386,7 +384,7 @@
 
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
-    check(
+    torch._check(
         A.dim() >= 2,
         lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
     )
@@ -400,7 +398,7 @@
 ):
     squareCheckInputs(A, f_name)
     checkIsMatrix(B, f_name)
-    check(
+    torch._check(
         A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
         lambda: (
             f"{f_name}: Incompatible shapes of A and B for the equation "
@@ -413,7 +411,7 @@
 def checkSameDevice(
     fn_name: str, result: Tensor, input: Tensor, result_name: str = "result"
 ):
-    check(
+    torch._check(
         result.device == input.device,
         lambda: (
             f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
@@ -424,7 +422,7 @@
 
 def checkUplo(UPLO: str):
     UPLO_uppercase = UPLO.upper()
-    check(
+    torch._check(
         len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
         lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
     )
@@ -477,20 +475,20 @@
 )
 @out_wrapper()
 def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
-    check(
+    torch._check(
         input.ndim >= 2,
         lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
     )
-    check(
+    torch._check(
         input.size(-2) >= input.size(-1),
         lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
     )
-    check(
+    torch._check(
         input.size(-1) >= tau.size(-1),
         lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
     )
 
-    check(
+    torch._check(
         input.ndim - tau.ndim == 1,
         lambda: (
             f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
@@ -500,7 +498,7 @@
     if input.ndim > 2:
         expected_batch_tau_shape = input.shape[:-2]
         actual_batch_tau_shape = tau.shape[:-1]
-        check(
+        torch._check(
             actual_batch_tau_shape == expected_batch_tau_shape,
             lambda: (
                 f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
@@ -508,7 +506,7 @@
             ),
         )
 
-    check(
+    torch._check(
         tau.dtype == input.dtype,
         lambda: (
             f"torch.linalg.householder_product: tau dtype {tau.dtype}"
@@ -567,7 +565,7 @@
     squareCheckInputs(LD, "torch.linalg.ldl_solve")
     checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
     linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
-    check(
+    torch._check(
         B.ndim >= 2,
         lambda: (
             f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
@@ -575,18 +573,18 @@
         ),
     )
     expected_pivots_shape = LD.shape[:-1]
-    check(
+    torch._check(
         expected_pivots_shape == pivots.shape,
         lambda: (
             f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
             f"but got pivots with shape {pivots.shape} instead"
         ),
     )
-    check(
+    torch._check(
         utils.is_integer_dtype(pivots.dtype),
         lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
     )
-    check(
+    torch._check(
         LD.dtype == B.dtype,
         lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
     )
@@ -602,7 +600,7 @@
 @register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
 @out_wrapper("P", "L", "U")
 def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
-    check(
+    torch._check(
         A.ndim >= 2,
         lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
     )
@@ -632,7 +630,7 @@
 def linalg_lu_factor_ex_meta(
     A: Tensor, *, pivot: bool = True, check_errors: bool = False
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    check(
+    torch._check(
         A.ndim >= 2,
         lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
     )
@@ -672,14 +670,14 @@
 ) -> Tensor:
     # dtype
     checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
-    check(
+    torch._check(
         LU.dtype == B.dtype,
         lambda: (
             f"linalg.lu_solve: Expected LU and B to have the same dtype, "
             f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
         ),
     )
-    check(
+    torch._check(
         pivots.dtype == torch.int,
         lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
     )
@@ -687,13 +685,13 @@
     # matrix shapes
     squareCheckInputs(LU, "torch.linalg.lu_solve")
     checkInputsSolver(LU, B, left, "linalg.lu_solve")
-    check(
+    torch._check(
         LU.size(-1) == pivots.size(-1),
         lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
     )
 
     # batches
-    check(
+    torch._check(
         LU.shape[:-1] == pivots.shape,
         lambda: (
             f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
@@ -770,7 +768,7 @@
         compute_q = False
         reduced = True  # this is actually irrelevant in this mode
     else:
-        check(
+        torch._check(
             False,
             lambda: (
                 f"qr received unrecognized mode '{mode}' "
@@ -1043,11 +1041,11 @@
     output_h = input_h + pad_t + pad_b
     output_w = input_w + pad_l + pad_r
 
-    check(
+    torch._check(
         output_w == grad_output.shape[dim_w],
         lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
     )
-    check(
+    torch._check(
         output_h == grad_output.shape[dim_h],
         lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
     )
@@ -1057,7 +1055,7 @@
 @register_meta(aten.reflection_pad2d.default)
 def meta_pad2d(self, padding):
     valid_dims = self.size(1) != 0 and self.size(2) != 0
-    check(
+    torch._check(
         (self.ndim == 3 and valid_dims)
         or (self.ndim == 4 and valid_dims and self.size(3) != 0),
         lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
@@ -1086,9 +1084,9 @@
     dim2 = batch1.size(1)
     dim3 = batch2.size(2)
     self = self.expand((dim1, dim2, dim3))
-    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
-    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
-    check(
+    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+    torch._check(
         self.dtype == batch1.dtype == batch2.dtype,
         lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
     )
@@ -1096,7 +1094,7 @@
     batch2_sizes = batch2.shape
     bs = batch1_sizes[0]
     contraction_size = batch1_sizes[2]
-    check(
+    torch._check(
         batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
         lambda: (
             f"Expected size for first two dimensions of batch2 tensor to be: "
@@ -1140,7 +1138,7 @@
     per_row_fake_quant=False,
     symmetric_quant=False,
 ):
-    check(
+    torch._check(
         ch_axis < self.dim(),
         lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
     )
@@ -1149,7 +1147,7 @@
 
 
 def dot_check(self, other):
-    check(
+    torch._check(
         self.dim() == 1 and other.dim() == 1,
         lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
     )
@@ -1163,11 +1161,11 @@
 
 @register_meta([aten.mm.default])
 def meta_mm(a, b):
-    check(a.dim() == 2, lambda: "a must be 2D")
-    check(b.dim() == 2, lambda: "b must be 2D")
+    torch._check(a.dim() == 2, lambda: "a must be 2D")
+    torch._check(b.dim() == 2, lambda: "b must be 2D")
     N, M1 = a.shape
     M2, P = b.shape
-    check(
+    torch._check(
         M1 == M2,
         lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
     )
@@ -1389,7 +1387,7 @@
 
 # from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
 def check_dim_size(tensor, dim, dim_size, size):
-    check(
+    torch._check(
         tensor.dim() == dim and tensor.shape[dim_size] == size,
         lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
         + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
@@ -1407,7 +1405,7 @@
     divisor_override=None,
 ):
     def unpack(name, val):
-        check(
+        torch._check(
             len(val) in [1, 2],
             lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
         )
@@ -1416,7 +1414,7 @@
         return H, W
 
     kH, kW = unpack("kernel_size", kernel_size)
-    check(
+    torch._check(
         len(stride) in [0, 1, 2],
         lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
     )
@@ -1429,7 +1427,7 @@
 
     padH, padW = unpack("padding", padding)
 
-    check(
+    torch._check(
         divisor_override is None or divisor_override != 0,
         lambda: "divisor must be not zero",
     )
@@ -1530,26 +1528,26 @@
     divisor_override,
 ):
     # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
-    check(
+    torch._check(
         len(kernel_size) == 1 or len(kernel_size) == 2,
         lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
     )
     kH = kernel_size[0]
     kW = kH if len(kernel_size) == 1 else kernel_size[1]
-    check(
+    torch._check(
         len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
         lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
     )
     dH = kH if len(stride) == 0 else stride[0]
     dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
-    check(
+    torch._check(
         len(padding) == 1 or len(padding) == 2,
         lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
     )
     padH = padding[0]
     padW = padH if len(padding) == 1 else padding[1]
 
-    check(
+    torch._check(
         divisor_override is None or divisor_override != 0,
         lambda: "divisor must be not zero",
     )
@@ -1602,7 +1600,7 @@
     count_include_pad=True,
     divisor_override=None,
 ):
-    check(
+    torch._check(
         len(kernel_size) in (1, 3),
         lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
     )
@@ -1610,7 +1608,7 @@
     kH = kT if len(kernel_size) == 1 else kernel_size[1]
     kW = kT if len(kernel_size) == 1 else kernel_size[2]
 
-    check(
+    torch._check(
         not stride or len(stride) in (1, 3),
         lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
     )
@@ -1618,7 +1616,7 @@
     dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
     dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
 
-    check(
+    torch._check(
         len(padding) in (1, 3),
         lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
     )
@@ -1626,12 +1624,12 @@
     padH = padT if len(padding) == 1 else padding[1]
     padW = padT if len(padding) == 1 else padding[2]
 
-    check(
+    torch._check(
         input.ndim in (4, 5),
         lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
     )
 
-    check(
+    torch._check(
         not divisor_override or divisor_override != 0,
         lambda: "divisor must be not zero",
     )
@@ -1689,7 +1687,7 @@
     count_include_pad,
     divisor_override,
 ):
-    check(
+    torch._check(
         len(kernel_size) in (1, 3),
         lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
     )
@@ -1697,7 +1695,7 @@
     kH = kT if len(kernel_size) == 1 else kernel_size[1]
     kW = kT if len(kernel_size) == 1 else kernel_size[2]
 
-    check(
+    torch._check(
         not stride or len(stride) in (1, 3),
         lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
     )
@@ -1705,7 +1703,7 @@
     dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
     dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
 
-    check(
+    torch._check(
         len(padding) in (1, 3),
         lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
     )
@@ -1713,12 +1711,12 @@
     padH = padT if len(padding) == 1 else padding[1]
     padW = padT if len(padding) == 1 else padding[2]
 
-    check(
+    torch._check(
         input.ndim in (4, 5),
         lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
     )
 
-    check(
+    torch._check(
         not divisor_override or divisor_override != 0,
         lambda: "divisor must be not zero",
     )
@@ -1759,7 +1757,7 @@
 
 @register_meta(aten._adaptive_avg_pool2d.default)
 def meta_adaptive_avg_pool2d(self, output_size):
-    check(
+    torch._check(
         self.ndim == 3 or self.ndim == 4,
         lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
     )
@@ -1777,7 +1775,7 @@
 
 @register_meta(aten._adaptive_avg_pool3d.default)
 def meta_adaptive_avg_pool3d(self, output_size):
-    check(
+    torch._check(
         self.ndim == 4 or self.ndim == 5,
         lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
     )
@@ -1788,16 +1786,16 @@
 def meta__adaptive_avg_pool2d_backward(grad_out, self):
     ndim = grad_out.ndim
     for i in range(1, ndim):
-        check(
+        torch._check(
             grad_out.size(i) > 0,
             lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
                       size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
         )
-    check(
+    torch._check(
         ndim == 3 or ndim == 4,
         lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
     )
-    check(
+    torch._check(
         self.dtype == grad_out.dtype,
         lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
     )
@@ -1852,30 +1850,28 @@
 
 @register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
 def meta_index_Tensor(self, indices):
-    check(indices, lambda: "at least one index must be provided")
+    torch._check(bool(indices), lambda: "at least one index must be provided")
     # aten::index is the internal advanced indexing implementation
     # checkIndexTensorTypes and expandTensors
     result: List[Optional[Tensor]] = []
     for i, index in enumerate(indices):
         if index is not None:
-            check(
+            torch._check(
                 index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                 lambda: "tensors used as indices must be long, int, byte or bool tensors",
             )
             if index.dtype in [torch.int8, torch.bool]:
                 nonzero = index.nonzero()
                 k = len(result)
-                check(
+                torch._check_index(
                     k + index.ndim <= self.ndim,
                     lambda: f"too many indices for tensor of dimension {self.ndim}",
-                    IndexError,
                 )
                 for j in range(index.ndim):
-                    check(
+                    torch._check_index(
                         index.shape[j] == self.shape[k + j],
                         lambda: f"The shape of the mask {index.shape} at index {i} "
                         f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
-                        IndexError,
                     )
                     result.append(nonzero.select(1, j))
             else:
@@ -1883,7 +1879,7 @@
         else:
             result.append(index)
     indices = result
-    check(
+    torch._check(
         len(indices) <= self.ndim,
         lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
     )
@@ -1988,20 +1984,20 @@
     dim1 = batch1.size(1)
     dim2 = batch2.size(2)
     self = self.expand((dim1, dim2))
-    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
-    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
-    check(
+    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+    torch._check(
         batch1.size(0) == batch2.size(0),
         lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
     )
-    check(
+    torch._check(
         batch1.size(2) == batch2.size(1),
         lambda: (
             f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
             f"and {batch2.size(1)}x{batch2.size(2)})"
         ),
     )
-    check(
+    torch._check(
         self.size(0) == dim1 and self.size(1) == dim2,
         lambda: "self tensor does not match matmul output shape",
     )
@@ -2015,7 +2011,7 @@
     ]
 )
 def meta__foreach_unaop_(self):
-    check(
+    torch._check(
         isinstance(self, List),
         lambda: f"Expect List[Tensor] but got {type(self)}",
     )
@@ -2029,7 +2025,7 @@
     ]
 )
 def meta__foreach_unaop(self):
-    check(
+    torch._check(
         isinstance(self, List),
         lambda: f"Expect List[Tensor] but got {type(self)}",
     )
@@ -2037,14 +2033,14 @@
 
 
 def _check_foreach_binop_tensor_lists(self, other):
-    check(
+    torch._check(
         isinstance(self, List) and isinstance(other, List),
         lambda: (
             "The first two arguments of must be List[Tensor], "
             f"but got {type(self)} and {type(other)}."
         ),
     )
-    check(
+    torch._check(
         len(self) > 0 and len(self) == len(other),
         lambda: (
             "self and other must be non-empty and match in length, "
@@ -2100,7 +2096,7 @@
     ]
 )
 def meta__foreach_binop__scalar(self, scalar=1):
-    check(
+    torch._check(
         isinstance(self, List),
         lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
     )
@@ -2115,7 +2111,7 @@
     ]
 )
 def meta__foreach_binop_scalar(self, scalar=1):
-    check(
+    torch._check(
         isinstance(self, List),
         lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
     )
@@ -2129,15 +2125,15 @@
     ]
 )
 def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
-    check(
+    torch._check(
         all(isinstance(l, List) for l in [self, tensor1, tensor2]),
         lambda: (
             "All arguments of _foreach_addc*_ must be List[Tensor], "
             f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
         ),
     )
-    check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    check(
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+    torch._check(
         len(self) == len(tensor1) and len(self) == len(tensor2),
         lambda: "All input tensor lists must have the same length",
     )
@@ -2150,15 +2146,15 @@
     ]
 )
 def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
-    check(
+    torch._check(
         all(isinstance(l, List) for l in [self, tensor1, tensor2]),
         lambda: (
             "All arguments must be List[Tensor], "
             f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
         ),
     )
-    check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    check(
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+    torch._check(
         len(self) == len(tensor1) and len(self) == len(tensor2),
         lambda: "All input tensor lists must have the same length",
     )
@@ -2168,7 +2164,7 @@
 
 @register_meta([aten._foreach_pow.ScalarAndTensor])
 def meta__foreach_pow_scalar_and_tensor(self, exponent):
-    check(
+    torch._check(
         isinstance(exponent, List),
         lambda: f"exponent must be a tensor list but got {type(exponent)}",
     )
@@ -2177,7 +2173,7 @@
 
 @register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
 def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
-    check(
+    torch._check(
         all(isinstance(l, List) for l in [self, tensor1, tensor2])
         and isinstance(scalars, torch.Tensor),
         lambda: (
@@ -2185,8 +2181,8 @@
             f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
         ),
     )
-    check(len(self) > 0, lambda: "input tensor list must not be empty.")
-    check(
+    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+    torch._check(
         len(self) == len(tensor1) and len(self) == len(tensor2),
         lambda: "All input tensor lists must have the same length",
     )
@@ -2212,7 +2208,7 @@
     found_inf=None,
 ):
     for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
-        check(
+        torch._check(
             isinstance(l, List),
             lambda: f"exponent must be a tensor list but got {type(l)}",
         )
@@ -2238,7 +2234,7 @@
     found_inf=None,
 ):
     for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
-        check(
+        torch._check(
             isinstance(l, List),
             lambda: f"exponent must be a tensor list but got {type(l)}",
         )
@@ -2258,17 +2254,17 @@
 @register_meta([aten._int_mm])
 @out_wrapper()
 def meta__int_mm(a, b):
-    check(a.dim() == 2, lambda: "a must be a 2D tensor")
-    check(b.dim() == 2, lambda: "b must be a 2D tensor")
-    check(
+    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
+    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
+    torch._check(
         a.dtype is torch.int8,
         lambda: f"expected self to be int8, got {a.dtype}",
     )
-    check(
+    torch._check(
         b.dtype is torch.int8,
         lambda: f"expected mat2 to be int8, got {b.dtype}",
     )
-    check(
+    torch._check(
         a.size(1) == b.size(0),
         lambda: (
             f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
@@ -2280,28 +2276,28 @@
 
 @register_meta(aten._cdist_forward.default)
 def meta_cdist_forward(x1, x2, p, compute_mode):
-    check(
+    torch._check(
         x1.dim() >= 2,
         lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
     )
-    check(
+    torch._check(
         x2.dim() >= 2,
         lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
     )
-    check(
+    torch._check(
         x1.size(-1) == x2.size(-1),
         lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
     )
-    check(
+    torch._check(
         utils.is_float_dtype(x1.dtype),
         lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
     )
-    check(
+    torch._check(
         utils.is_float_dtype(x2.dtype),
         lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
     )
-    check(p >= 0, lambda: "cdist only supports non-negative p values")
-    check(
+    torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
+    torch._check(
         compute_mode in (None, 1, 2),
         lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
     )
@@ -2326,22 +2322,22 @@
     include_last_offset=False,
     padding_idx=-1,
 ):
-    check(
+    torch._check(
         indices.dtype in (torch.long, torch.int),
         lambda: f"expected indices to be long or int, got {indices.dtype}",
     )
-    check(
+    torch._check(
         offsets.dtype in (torch.long, torch.int),
         lambda: f"expected offsets to be long or int, got {offsets.dtype}",
     )
-    check(
+    torch._check(
         utils.is_float_dtype(weight.dtype),
         lambda: f"expected weight to be floating point type, got {weight.dtype}",
     )
 
     num_bags = offsets.size(0)
     if include_last_offset:
-        check(
+        torch._check(
             num_bags >= 1,
             lambda: "include_last_offset: numBags should be at least 1",
         )
@@ -2351,19 +2347,19 @@
     MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
 
     if per_sample_weights is not None:
-        check(
+        torch._check(
             mode == MODE_SUM,
             lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
         )
-        check(
+        torch._check(
             per_sample_weights.dtype == weight.dtype,
             lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
         )
-        check(
+        torch._check(
             per_sample_weights.ndim == 1,
             lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
         )
-        check(
+        torch._check(
             per_sample_weights.numel() == indices.numel(),
             lambda: (
                 f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
@@ -2408,7 +2404,7 @@
         numBags = offsets.shape[0]
         if mode == MODE_MAX:
             if include_last_offset:
-                check(
+                torch._check(
                     numBags >= 1,
                     lambda: "include_last_offset: numBags should be at least 1",
                 )
@@ -2477,7 +2473,7 @@
 
 @register_meta(aten.repeat.default)
 def meta_repeat(self, repeats):
-    check(
+    torch._check(
         len(repeats) >= self.dim(),
         lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
     )
@@ -2534,17 +2530,17 @@
 
 
 def shift_dtype_check(fn_name, self, val):
-    check(
+    torch._check(
         utils.is_integer_dtype(self.dtype),
         lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
     )
     if isinstance(val, torch.Tensor):
-        check(
+        torch._check(
             utils.is_integer_dtype(val.dtype),
             lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
         )
     else:
-        check(
+        torch._check(
             isinstance(val, IntLike),
             lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
         )
@@ -2620,8 +2616,8 @@
 
 
 def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
-    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
-    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
 
     batch1_sizes = batch1.size()
     batch2_sizes = batch2.size()
@@ -2632,7 +2628,7 @@
     res_cols = batch2_sizes[2]
     output_size = (bs, res_rows, res_cols)
 
-    check(
+    torch._check(
         batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
         lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
         f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
@@ -2643,8 +2639,8 @@
     output = batch2.new_empty(output_size)
 
     if not is_bmm and self_baddbmm is not None:
-        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
-        check(
+        torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
+        torch._check(
             self_baddbmm.size() == output_size,
             lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
         )
@@ -2689,9 +2685,9 @@
 
 
 def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
-    check(stride != 0, lambda: "stride should not be zero")
-    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
-    check(
+    torch._check(stride != 0, lambda: "stride should not be zero")
+    torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
+    torch._check(
         pad <= kernelSize // 2,
         lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
     )
@@ -2720,15 +2716,15 @@
     ndim = input.dim()
     nOutputPlane = nInputPlane
 
-    check(
+    torch._check(
         kW > 0 and kH > 0,
         lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
     )
-    check(
+    torch._check(
         dW > 0 and dH > 0,
         lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
     )
-    check(
+    torch._check(
         dilationH > 0 and dilationW > 0,
         lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
     )
@@ -2736,25 +2732,25 @@
     valid_dims = input.size(1) != 0 and input.size(2) != 0
 
     if memory_format == torch.channels_last:
-        check(
+        torch._check(
             ndim == 4 and valid_dims and input.size(3) != 0,
             lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
             " with optional 0 dim batch size for input, but got: {input.size()}",
         )
     else:
-        check(
+        torch._check(
             (ndim == 3 and input.size(0) != 0 and valid_dims)
             or (ndim == 4 and valid_dims and input.size(3) != 0),
             lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
         )
 
-    check(
+    torch._check(
         kW // 2 >= padW and kH // 2 >= padH,
         lambda: "pad should be smaller than or equal to half of kernel size, but got "
         f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
     )
 
-    check(
+    torch._check(
         outputWidth >= 1 and outputHeight >= 1,
         lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
         f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
@@ -2788,21 +2784,21 @@
 ):
     ndim = input.ndim
 
-    check(
+    torch._check(
         kT > 0 and kW > 0 and kH > 0,
         lambda: (
             f"kernel size should be greater than zero, but got "
             f"kT: {kT}, kH: {kH}, kW: {kW}"
         ),
     )
-    check(
+    torch._check(
         dT > 0 and dW > 0 and dH > 0,
         lambda: (
             f"stride should be greater than zero, but got "
             f"dT: {dT}, dH: {dH}, dW: {dW}"
         ),
     )
-    check(
+    torch._check(
         dilationT > 0 and dilationW > 0 and dilationH > 0,
         lambda: (
             f"dilation should be greater than zero, but got "
@@ -2810,7 +2806,7 @@
         ),
     )
 
-    check(
+    torch._check(
         ndim in (4, 5),
         lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
     )
@@ -2819,7 +2815,7 @@
         if ndim == 5 and i == 0:
             # size of batch-dim can be 0.
             continue
-        check(
+        torch._check(
             input.size(i) > 0,
             lambda: (
                 f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
@@ -2829,7 +2825,7 @@
         )
 
     if check_input_size:  # AveragePool3d
-        check(
+        torch._check(
             itime >= kT and iheight >= kH and iwidth >= kW,
             lambda: (
                 f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
@@ -2837,7 +2833,7 @@
             ),
         )
 
-    check(
+    torch._check(
         kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
         lambda: (
             f"pad should be smaller than or equal to half of kernel size, but got "
@@ -2845,7 +2841,7 @@
         ),
     )
 
-    check(
+    torch._check(
         otime >= 1 and owidth >= 1 and oheight >= 1,
         lambda: (
             f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
@@ -2914,7 +2910,7 @@
 ):
     # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
     def unpack(name, val):
-        check(
+        torch._check(
             len(val) in [1, 2],
             lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
         )
@@ -2924,7 +2920,7 @@
 
     kH, kW = unpack("kernel_size", kernel_size)
 
-    check(
+    torch._check(
         len(stride) in [0, 1, 2],
         lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
     )
@@ -2941,17 +2937,17 @@
 
     memory_format = utils.suggest_memory_format(input)
     if memory_format == torch.channels_last:
-        check(
+        torch._check(
             input.dim() == 4,
             lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
         )
     elif memory_format == torch.contiguous_format:
-        check(
+        torch._check(
             input.dim() in [3, 4],
             lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
         )
     else:
-        check(
+        torch._check(
             False,
             lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
         )
@@ -2999,7 +2995,7 @@
         self, kernel_size, stride, padding, dilation, ceil_mode
     )
 
-    check(
+    torch._check(
         self.dtype == grad_output.dtype,
         lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
     )
@@ -3093,7 +3089,7 @@
     memory_format=None,
 ):
     if layout == torch.sparse_coo:
-        check(
+        torch._check(
             memory_format is None,
             lambda: "memory format option is only supported by strided tensors",
         )
@@ -3131,20 +3127,18 @@
 @register_meta(aten.select.int)
 def meta_select(self, dim, index):
     ndim = self.dim()
-    check(
+    torch._check_index(
         ndim != 0,
         lambda: "select() cannot be applied to a 0-dim tensor.",
-        IndexError,
     )
 
     dim = dim if dim >= 0 else dim + ndim
     size = self.size(dim)
 
-    check(
+    torch._check_index(
         not (-index > size or index >= size),
         lambda: f"select(): index {index} out of range for tensor of size "
         f"{self.size()} at dimension {dim}",
-        IndexError,
     )
 
     index = index if index >= 0 else index + size
@@ -3190,13 +3184,13 @@
 def gather_shape_check(self, dim, index):
     self_dims = max(self.dim(), 1)
     index_dims = max(index.dim(), 1)
-    check(
+    torch._check(
         self_dims == index_dims,
         lambda: "Index tensor must have the same number of dimensions as input tensor",
     )
     for i in range(self_dims):
         if i != dim:
-            check(
+            torch._check(
                 ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
                 lambda: f"Size does not match at dimension {i} expected index {index.shape}"
                 + f" to be smaller than self {self.shape} apart from dimension {dim}",
@@ -3208,7 +3202,7 @@
     wrapped_dim = maybe_wrap_dim(dim, self.dim())
     is_index_empty = index.numel() == 0
     if not is_index_empty:
-        check(
+        torch._check(
             index.dtype == torch.long,
             lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
         )
@@ -3229,7 +3223,7 @@
             return "REDUCE_MAXIMUM"
         elif reduce_ == "amin":
             return "REDUCE_MINIMUM"
-        check(
+        torch._check(
             False,
             lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
         )
@@ -3239,20 +3233,20 @@
             return "REDUCE_ADD"
         elif reduce_ == "multiply":
             return "REDUCE_MULTIPLY"
-        check(False, lambda: "reduce argument must be either add or multiply.")
+        torch._check(False, lambda: "reduce argument must be either add or multiply.")
         return
 
 
 # From aten/src/ATen/native/ScatterGatherChecks.h
 def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
     if index.numel() != 0:
-        check(
+        torch._check(
             index.dtype == torch.long,
             lambda: f"{method_name}(): Expected dtype int64 for index",
         )
 
     if src_opt is not None:
-        check(
+        torch._check(
             self.dtype == src_opt.dtype,
             lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
         )
@@ -3266,7 +3260,7 @@
 def scatter_shape_check(self, dim, index, src_opt=None):
     if index.numel() == 0:
         return
-    check(
+    torch._check(
         ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
         lambda: "Index tensor must have the same number of dimensions as self tensor",
     )
@@ -3292,17 +3286,17 @@
                 break
 
     if src_opt is not None:
-        check(
+        torch._check(
             ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
             lambda: "Index tensor must have the same number of dimensions as self tensor",
         )
-        check(
+        torch._check(
             not is_wrong_shape,
             lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
             + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
         )
     else:
-        check(
+        torch._check(
             not is_wrong_shape,
             lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
             + f" apart from dimension {dim}",
@@ -3588,7 +3582,7 @@
 @register_meta([aten.multinomial.default, aten.multinomial.out])
 @out_wrapper()
 def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
-    check(
+    torch._check(
         0 < input.dim() <= 2,
         lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
     )
@@ -3607,17 +3601,17 @@
 
 
 def upsample_common_check(input_size, output_size, num_spatial_dims):
-    check(
+    torch._check(
         len(output_size) == num_spatial_dims,
         lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
     )
     expected_input_dims = num_spatial_dims + 2  # N, C, ...
-    check(
+    torch._check(
         len(input_size) == expected_input_dims,
         lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
     )
 
-    check(
+    torch._check(
         all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
         lambda: f"Input and output sizes should be greater than 0, but got "
         f"input size {input_size} and output size {output_size}",
@@ -3629,7 +3623,7 @@
 
 @register_meta(aten.upsample_nearest1d.default)
 def upsample_nearest1d(input, output_size, scales=None):
-    check(
+    torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
         lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
     )
@@ -3643,7 +3637,7 @@
 
 @register_meta(aten.upsample_nearest2d.default)
 def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
-    check(
+    torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
         lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
     )
@@ -3676,12 +3670,12 @@
     full_output_size = upsample_common_check(
         input_size, output_size, num_spatial_dims=2
     )
-    check(
+    torch._check(
         grad_output.ndim == 4,
         lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
     )
     for i in range(4):
-        check(
+        torch._check(
             grad_output.size(i) == full_output_size[i],
             lambda: (
                 f"Expected grad_output to have the same shape as output;"
@@ -3697,7 +3691,7 @@
 
 @register_meta(aten.upsample_nearest3d.default)
 def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
-    check(
+    torch._check(
         input.numel() != 0 or multiply_integers(input.size()[1:]),
         lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
     )
@@ -3739,29 +3733,29 @@
 def rnn_cell_checkSizes(
     input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
 ):
-    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
-    check(
+    torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
+    torch._check(
         input_gates.shape == hidden_gates.shape,
         lambda: f"{input_gates.shape} != {hidden_gates.shape}",
     )
     gates_size = input_gates.size(1)
     if input_bias is not None:
-        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
-        check(
+        torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
+        torch._check(
             input_bias.numel() == gates_size,
             lambda: f"{input_bias.numel()} != {gates_size}",
         )
-        check(
+        torch._check(
             input_bias.shape == hidden_bias.shape,
             lambda: f"{input_bias.shape} != {hidden_bias.shape}",
         )
-    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
+    torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
     expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
-    check(
+    torch._check(
         prev_hidden.numel() == expected_prev_hidden_numel,
         lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
     )
-    check(
+    torch._check(
         all(
             x.device == input_gates.device
             for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
@@ -3879,16 +3873,14 @@
 
 def zero_numel_check_dims(self, dim, fn_name):
     if self.ndim == 0:
-        check(
+        torch._check_index(
             dim == 0 or dim == -1,
             lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
-            IndexError,
         )
     else:
-        check(
+        torch._check_index(
             self.size(dim) != 0,
             lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
-            IndexError,
         )
 
 
@@ -3898,7 +3890,7 @@
         dim = maybe_wrap_dim(dim, self.dim())
         zero_numel_check_dims(self, dim, name)
     else:
-        check(
+        torch._check(
             self.numel() != 0,
             lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
         )
@@ -3923,12 +3915,12 @@
 def topk_meta(self, k, dim=-1, largest=True, sorted=True):
     # From aten/src/ATen/native/Sorting.cpp
     dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
-    check(
+    torch._check(
         k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
         lambda: "selected index k out of range",
     )
     sliceSize = 1 if self.dim() == 0 else self.size(dim)
-    check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
+    torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
 
     topKSize = list(self.shape)
     if len(topKSize) > 0:
@@ -3942,16 +3934,16 @@
 # From aten/src/ATen/native/cuda/RNN.cu
 def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
     defined_grad = grad_hy if grad_hy is not None else grad_cy
-    check(defined_grad.dim() == 2, lambda: "")
+    torch._check(defined_grad.dim() == 2, lambda: "")
     exp_size = defined_grad.size()
     if grad_hy is not None:
-        check(grad_hy.size() == exp_size, lambda: "")
+        torch._check(grad_hy.size() == exp_size, lambda: "")
     if grad_cy is not None:
-        check(grad_cy.size() == exp_size, lambda: "")
-    check(cx.size() == exp_size, lambda: "")
-    check(cy.size() == exp_size, lambda: "")
-    check(workspace.dim() == 2, lambda: "")
-    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
+        torch._check(grad_cy.size() == exp_size, lambda: "")
+    torch._check(cx.size() == exp_size, lambda: "")
+    torch._check(cy.size() == exp_size, lambda: "")
+    torch._check(workspace.dim() == 2, lambda: "")
+    torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
 
 
 # From aten/src/ATen/native/cuda/RNN.cu
@@ -4048,7 +4040,7 @@
     full_output_size = upsample_common_check(
         input.size(), output_size, num_spatial_dims=2
     )
-    check(
+    torch._check(
         input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
         lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
     )
@@ -4060,13 +4052,17 @@
 # From aten/src/ATen/native/cuda/AmpKernels.cu
 @register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
 def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
-    check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
-    check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
-    check(
+    torch._check(
+        found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
+    )
+    torch._check(
+        inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
+    )
+    torch._check(
         found_inf.dtype.is_floating_point,
         lambda: "found_inf must be a float tensor.",
     )
-    check(
+    torch._check(
         inv_scale.dtype.is_floating_point,
         lambda: "inv_scale must be a float tensor.",
     )
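
Reviewer note: the hunks above are mechanical, following a fixed mapping from the old helper to the new `torch._check*` family. A minimal sketch of that mapping (the tensor and messages below are invented for illustration):

    import torch

    x = torch.randn(2, 3)

    # check(cond, msg) with the default RuntimeError becomes torch._check:
    torch._check(x.ndim == 2, lambda: f"expected a 2D tensor, got {x.ndim}D")

    # Explicit exception types map onto dedicated helpers:
    #   check(cond, msg, IndexError)          -> torch._check_index(cond, msg)
    #   check(cond, msg, ValueError)          -> torch._check_value(cond, msg)
    #   check(cond, msg, TypeError)           -> torch._check_type(cond, msg)
    #   check(cond, msg, NotImplementedError) -> torch._check_not_implemented(cond, msg)
    torch._check_index(0 <= 1 < x.size(0), lambda: "row index out of range")
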
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 29cb75c..65a4080 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -16,7 +16,6 @@
 from torch._prims.nvfuser_prims import register_nvprims
 from torch._prims.rng_prims import register_rng_prims
 from torch._prims_common import (
-    check,
     Dim,
     DimsSequenceType,
     DimsType,
@@ -422,7 +421,7 @@
 
 
 def _complex_only_elementwise_meta(*args, **kwargs):
-    utils.check(
+    torch._check(
         utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
     )
     return _elementwise_meta(*args, **kwargs)
@@ -581,7 +580,7 @@
 
 
 def _cbrt_aten(a: torch.Tensor) -> Tensor:
-    utils.check(
+    torch._check(
         not a.is_complex(),
         lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
     )
@@ -1293,10 +1292,9 @@
 
     # Verifies end is strictly greater than start
     # (Collapse requires a non-empty interval)
-    utils.check(
+    torch._check_value(
         end >= start,
         lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
-        ValueError,
     )
 
 
@@ -1823,7 +1821,7 @@
     utils.validate_strides(stride)
 
     required_size = utils.compute_required_storage_length(size, stride, storage_offset)
-    utils.check(
+    torch._check(
         input.numel() >= required_size,
         lambda: (
             f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
@@ -1832,7 +1830,7 @@
             f"for storage of size {input.numel() * input.element_size()}"
         ),
     )
-    utils.check(
+    torch._check(
         utils.is_same_shape(src.shape, size),
         lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
     )
@@ -2432,11 +2430,11 @@
     device: torch.device,
     requires_grad: bool,
 ) -> TensorLikeType:
-    utils.check(
+    torch._check(
         utils.is_integer_dtype(dtype),
         lambda: "prims.iota only supports integer dtypes",
     )
-    utils.check(step != 0, lambda: "step must be nonzero")
+    torch._check(step != 0, lambda: "step must be nonzero")
     return torch.empty(
         length,
         dtype=dtype,
@@ -2532,7 +2530,7 @@
 ) -> TensorLikeType:
     p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
     dim = len(shape)
-    utils.check(
+    torch._check(
         len(physical_layout) == dim,
         lambda: (
             "Number of dimensions in the tensor input does not match the "
@@ -2543,7 +2541,7 @@
     strides = [0] * len(shape)
     seen_dims = set()
     for p, l in enumerate(physical_layout):
-        utils.check(
+        torch._check(
             0 <= l < dim,
             lambda: (
                 f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
@@ -2551,7 +2549,7 @@
                 "not currently supported; file an issue if you want it."
             ),
         )
-        utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
+        torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
         strides[l] = p_strides[p]
         seen_dims.add(l)
     return TensorMeta(
@@ -2779,12 +2777,12 @@
     device: torch.device,
     requires_grad: bool,
 ) -> TensorLikeType:
-    utils.check(
+    torch._check(
         std >= 0.0,
         lambda: f"expected non-negative standard deviation, but got std={std}",
     )
 
-    utils.check(
+    torch._check(
         utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
         lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
     )
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index b0c8664..0807197 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -7,6 +7,7 @@
 import operator
 import sympy
 import weakref
+import warnings
 import torch
 from torch import sym_float, sym_int, sym_max
 
@@ -268,7 +269,7 @@
 
 
 def validate_memory_format(memory_format: torch.memory_format):
-    check(
+    torch._check(
         memory_format in _memory_formats,
         lambda: f"Received unknown memory format {memory_format}!",
     )
@@ -286,7 +287,7 @@
     if memory_format == torch.channels_last_3d:
         return is_channels_last_contiguous_3d(a)
 
-    check(
+    torch._check(
         False,
         lambda: f"is_contiguous received unsupported memory format {memory_format}",
     )
@@ -795,13 +796,13 @@
     newsize = 1
     for i, d in enumerate(shape):
         if d == -1:
-            check(dim is None, lambda: "only one dimension can be inferred")
+            torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
         elif d >= 0:
             newsize *= d
         else:
-            check(False, lambda: f"invalid shape dimension {d}")
-    check(
+            torch._check(False, lambda: f"invalid shape dimension {d}")
+    torch._check(
         numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
         lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
     )
@@ -809,7 +810,7 @@
         # Convert to list to produce a compatible error message with core
         # PyTorch, which prints sequences in square brackets.
         shape = list(shape)
-        check(
+        torch._check(
             newsize != 0,
             lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
                      f"unspecified dimension size -1 can be any value and is ambiguous"),
@@ -954,18 +955,18 @@
     Checks whether the input is floating point or complex.
     If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
     """
-    check(
+    torch._check(
         is_float_dtype(dtype) or is_complex_dtype(dtype),
         lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
     )
-    check(
+    torch._check(
         allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
         lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
     )
 
 
 def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
-    check(
+    torch._check(
         len(A.shape) >= 2,
         lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
     )
@@ -1060,11 +1061,11 @@
 
 
 def check_pin_memory(pin_memory: bool):
-    check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError)
+    torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory")
 
 
 def check_layout(layout: torch.layout):
-    check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError)
+    torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}")
 
 
 # TODO: maybe unify with can_cast_to?
@@ -1485,7 +1486,7 @@
 
 
 def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
-    check(
+    torch._check(
         len(shape) == 3,
         lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
     )
@@ -1503,7 +1504,7 @@
 
 def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
     # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
-    check(
+    torch._check(
         len(shape) == 4,
         lambda: "Only tensors of rank 4 can use the channels_last memory format",
     )
@@ -1520,7 +1521,7 @@
 
 
 def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
-    check(
+    torch._check(
         len(shape) == 5,
         lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
     )
@@ -1654,6 +1655,9 @@
         raise ValueError(msg)
 
 
+# NOTE: This function should ideally be removed, but some Meta internal models
+# packaged with `torch.package` are using it, so it will have to be removed
+# at some point in the future when those models no longer use this function.
 def check(
     b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
 ) -> None:
@@ -1662,9 +1666,14 @@
     Error message is a callable producing a string (to avoid wasting time
     string formatting in non-error case, and also to make it easier for torchdynamo
     to trace.)
+
+    .. note:: This function is planned for removal in the future. Please use
+        `torch._check*` functions instead.
     """
-    if not b:
-        raise exc_type(s())
+    warnings.warn(DeprecationWarning((
+        "'torch._prims_common.check' will be removed in the future. Please use "
+        "'torch._check*' functions instead")))
+    torch._check_with(exc_type, b, s)
 
 
 # This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
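
The shim above keeps `torch._prims_common.check` callable while steering users to `torch._check*`: it warns, then routes through `torch._check_with`, so the caller-supplied exception type is preserved. A quick sketch of the resulting behavior:

    import warnings
    import torch._prims_common as utils

    # A passing check now emits a DeprecationWarning and returns normally.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        utils.check(True, lambda: "unused message")
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # A failing check still raises with the requested exception type.
    try:
        utils.check(False, lambda: "bad value", ValueError)
    except ValueError:
        pass
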
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 5f10a2a..041fb76 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -176,13 +176,13 @@
 
     # Checks safe cast
     if exact_dtype:
-        utils.check(
+        torch._check(
             copy_from.dtype == copy_to.dtype,
             lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
             f"but got {copy_to.dtype} instead",
         )
     else:
-        utils.check(
+        torch._check(
             utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
             lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
             "but this can't be cast because it is not safe!",
@@ -255,10 +255,9 @@
                     _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                 else:
                     assert isinstance(out, Tuple)  # type: ignore[arg-type]
-                    utils.check(
+                    torch._check_type(
                         len(out) == len(result),
                         lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
-                        TypeError,
                     )
                     for r, o in zip(result, out):
                         # These two operations are done in-place
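
In the out= wrapper, the tuple-arity check keeps raising TypeError (now via `torch._check_type`) to match core PyTorch. A standalone illustration of the replaced call, with invented tensors:

    import torch

    result = (torch.zeros(2), torch.zeros(2))
    out = (torch.empty(2),)  # deliberately the wrong arity

    try:
        torch._check_type(
            len(out) == len(result),
            lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
        )
    except TypeError as e:
        assert "expected tuple of 2 elements" in str(e)
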
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index f1778eb..a1e66cb 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -16,7 +16,6 @@
 import torch._prims_common as utils
 from torch import sym_float, sym_int
 from torch._prims_common import (
-    check,
     DeviceLikeType,
     Dim,
     DimsSequenceType,
@@ -626,7 +625,7 @@
 # imag does not use _make_elementwise_unary_reference because it does not support out
 def imag(a: TensorLikeType) -> TensorLikeType:
     assert isinstance(a, TensorLike)
-    utils.check(
+    torch._check(
         utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
     )
     return prims.imag(a)
@@ -654,7 +653,7 @@
 
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
 def isposinf(a: TensorLikeType) -> TensorLikeType:
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(a.dtype),
         lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
     )
@@ -665,7 +664,7 @@
 
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
 def isneginf(a: TensorLikeType) -> TensorLikeType:
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(a.dtype),
         lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
     )
@@ -788,7 +787,7 @@
 
 
 def _neg_meta(a: TensorLikeType):
-    check(
+    torch._check(
         a.dtype is not torch.bool,
         lambda: (
             "Negation, the `-` operator, on a bool tensor is not supported. "
@@ -935,23 +934,20 @@
             a: Union[Tensor, NumberType],
             b: Union[Tensor, NumberType],
         ) -> Tensor:
-            check(
+            torch._check_value(
                 supports_lhs_python_scalar or not isinstance(a, Number),
                 lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                 "operation that does not accept lhs scalars!",
-                ValueError,
             )
-            check(
+            torch._check_value(
                 supports_rhs_python_scalar or not isinstance(b, Number),
                 lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                 "operation that does not accept rhs scalars!",
-                ValueError,
             )
-            check(
+            torch._check_value(
                 supports_two_python_scalars
                 or not (isinstance(a, Number) and isinstance(b, Number)),
                 lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
-                ValueError,
             )
             a, b = _maybe_broadcast(a, b)
             return prim(a, b)
@@ -1230,7 +1226,7 @@
     elif utils.is_integer_dtype(dtype):
         return _floor_divide_integer(a, b)
     else:
-        check(False, lambda: f"{dtype} not supported for floor_divide")
+        torch._check(False, lambda: f"{dtype} not supported for floor_divide")
 
 
 def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
@@ -1374,20 +1370,19 @@
     rtol: float,
     atol: float,
 ) -> None:
-    check(
+    torch._check_value(
         a.dtype == b.dtype,
         lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
             name, a.dtype, b.dtype
         ),
-        ValueError,
     )
-    check(
+    torch._check(
         rtol >= 0,
         lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
             name, rtol
         ),
     )
-    check(
+    torch._check(
         atol >= 0,
         lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
             name, atol
@@ -1678,7 +1673,7 @@
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
-    utils.check(
+    torch._check(
         isinstance(a, TensorLike) or isinstance(b, TensorLike),
         lambda: "Expected either argument a or b to be a Tensor",
     )
@@ -1736,12 +1731,11 @@
     if value is not None:
         dtype = self.dtype  # no scalars allowed, see add
         python_type = utils.dtype_to_type(dtype)
-        check(
+        torch._check_value(
             utils.is_weakly_lesser_type(type(value), python_type),
             lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                 type(value), python_type
             ),
-            exc_type=ValueError,
         )
 
     return self + value * tensor1 / tensor2
@@ -1766,12 +1760,11 @@
     if value is not None:
         dtype = self.dtype  # no scalars allowed, see add
         python_type = utils.dtype_to_type(dtype)
-        check(
+        torch._check_value(
             utils.is_weakly_lesser_type(type(value), python_type),
             lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                 type(value), python_type
             ),
-            exc_type=ValueError,
         )
 
     return self + value * tensor1 * tensor2
@@ -1851,7 +1844,7 @@
         raise NotImplementedError
 
     utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
-    check(
+    torch._check(
         pred.dtype is torch.bool,
         lambda: f"expected predicate to be bool, got {pred.dtype}",
     )
@@ -2229,7 +2222,7 @@
     *shape,
 ) -> Tensor:
     shape = utils.extract_shape_from_varargs(shape, validate=False)
-    utils.check(
+    torch._check(
         utils.is_expandable_to(shape, a.shape),
         lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
     )
@@ -2402,7 +2395,7 @@
     if dtype is None:
         dtype = a.dtype
     # can't use out wrapper because of this argument
-    check(
+    torch._check(
         out is None or out.dtype == dtype,
         lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
     )
@@ -2415,7 +2408,7 @@
         out=None,
         output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
     )
-    check(
+    torch._check(
         utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
         lambda: (
             f"mean(): could not infer output dtype. "
@@ -2491,22 +2484,22 @@
     beta: NumberType = 1,
     alpha: NumberType = 1,
 ) -> TensorLikeType:
-    check(
+    torch._check(
         vec1.ndim == 1,
         lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
     )
-    check(
+    torch._check(
         vec2.ndim == 1,
         lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
     )
     self = self.expand(vec1.shape[0], vec2.shape[0])
     if utils.is_boolean_dtype(self.dtype):
         # Integers are accepted for booleans
-        check(
+        torch._check(
             is_weakly_lesser_type(type(beta), int),
             lambda: f"expected bool/int beta but got {type(beta)}",
         )
-        check(
+        torch._check(
             is_weakly_lesser_type(type(alpha), int),
             lambda: f"expected bool/int alpha but got {type(beta)}",
         )
@@ -2518,11 +2511,11 @@
                 torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
             )
     else:
-        check(
+        torch._check(
             is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
             lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
         )
-        check(
+        torch._check(
             is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
             lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
         )
@@ -2712,7 +2705,7 @@
 def constant_pad_nd(
     input: TensorLikeType, pad: List[int], value: NumberType = 0
 ) -> TensorLikeType:
-    check(
+    torch._check(
         len(pad) % 2 == 0,
         lambda: f"Length of pad must be even but instead it equals {len(pad)}",
     )
@@ -2723,7 +2716,7 @@
     l_pad = len(pad) // 2
     l_diff = l_inp - l_pad
 
-    check(
+    torch._check(
         l_inp >= l_pad,
         lambda: "Length of pad should be no more than twice the number of "
         f"dimensions of the input. Pad length is {len(pad)} while the input has "
@@ -2748,7 +2741,7 @@
     for i in range(l_pad):
         pad_idx = len(pad) - ((i + 1) * 2)
         new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
-        check(
+        torch._check(
             new_dim > 0,
             lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
             f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
@@ -2787,7 +2780,7 @@
 def contiguous(
     a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
 ) -> Tensor:
-    check(
+    torch._check(
         memory_format != torch.preserve_format,
         lambda: "preserve memory format is unsupported by the contiguous operator",
     )
@@ -2800,7 +2793,7 @@
 
 @out_wrapper()
 def dstack(tensors: TensorSequenceType) -> TensorLikeType:
-    check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
+    torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
     aligned_tensors = atleast_3d(*tensors)
     return cat(aligned_tensors, 2)
 
@@ -2813,7 +2806,7 @@
     if len(shape) == 1 and isinstance(shape[0], Sequence):
         shape = tuple(shape[0])
 
-    check(
+    torch._check(
         len(shape) >= len(a.shape),
         lambda: "expand: the requested shape has too few dimensions!",
     )
@@ -2823,7 +2816,7 @@
     for idx, x in enumerate(a.shape):
         offset_idx = idx + offset
         requested_length = shape[offset_idx]
-        check(
+        torch._check(
             requested_length == x or x == 1 or requested_length == -1,
             lambda: f"expand: attempting to expand a dimension of length {x}!",
         )
@@ -2917,13 +2910,13 @@
     # Supports Tensor overload that was added for XLA:
     # https://github.com/pytorch/pytorch/issues/31558
     if isinstance(start, TensorLike):
-        check(
+        torch._check(
             start.dim() == 0 and utils.is_integer_dtype(start.dtype),
             lambda: "start must be an 0-dim integral Tensor.",
         )
         start = start.item()  # type: ignore[assignment]
-    check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
-    check(length >= 0, lambda: "narrow(): length must be non-negative.")
+    torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
+    torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
     dim = utils.canonicalize_dim(a.ndim, dim)
     dim_length = a.size(dim)
     # Start being the end is usually invalid since it's out of bounds. So it's
@@ -2934,7 +2927,7 @@
         # Note: a dimension isn't being canonicalized here, this reuses
         # canonicalize_dim because the semantics are similar.
         start = utils.canonicalize_dim(dim_length, start)  # type: ignore[arg-type]
-    check(
+    torch._check(
         start <= dim_length - length,  # type: ignore[arg-type]
         lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
     )
@@ -2993,11 +2986,11 @@
     num_groups: int,
     eps: float,
 ) -> Tuple[Tensor, Tensor, Tensor]:
-    utils.check(
+    torch._check(
         input.ndim >= 2,
         lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
     )
-    utils.check(
+    torch._check(
         num_channels % num_groups == 0,
         lambda: "Expected number of channels in input to be divisible by num_groups, "
         + f"but got input of shape {input.shape} and num_groups = {num_groups}",
@@ -3044,7 +3037,7 @@
     eps: float,
 ) -> Tuple[Tensor, Tensor, Tensor]:
     normalized_ndim = len(normalized_shape)
-    utils.check(
+    torch._check(
         normalized_ndim >= 1,
         lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
         + "containing at least one element, but got normalized_shape = "
@@ -3053,7 +3046,7 @@
     # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
     # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
     # therefore we use tuple(normalized_shape)
-    utils.check(
+    torch._check(
         weight is None or weight.shape == tuple(normalized_shape),
         lambda: "Expected weight to be of same shape as normalized_shape, but got "
         + "weight of shape "
@@ -3061,7 +3054,7 @@
         + " and normalized_shape = "
         + str(normalized_shape),
     )
-    utils.check(
+    torch._check(
         bias is None or bias.shape == tuple(normalized_shape),
         lambda: "Expected bias to be of same shape as normalized_shape, but got "
         + "bias of shape "
@@ -3069,7 +3062,7 @@
         + " and normalized_shape = "
         + str(normalized_shape),
     )
-    utils.check(
+    torch._check(
         input.ndim >= normalized_ndim
         and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
         lambda: "Given normalized_shape="
@@ -3123,12 +3116,12 @@
     max_size = 1 if a_ndim == 0 else a_shape[dim]
     last_stride = 1 if a_ndim == 0 else a_stride[dim]
 
-    utils.check(
+    torch._check(
         size <= max_size,
         lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
     )
 
-    utils.check(
+    torch._check(
         step > 0,
         lambda: f"Step is {step} but must be > 0",
     )
@@ -3146,7 +3139,7 @@
 @register_decomposition(aten.repeat)
 def repeat(a: Tensor, *repeat_shape) -> Tensor:
     repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
-    utils.check(
+    torch._check(
         len(repeat_shape) >= len(a.shape),
         lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
     )
@@ -3452,7 +3445,7 @@
 # CompositeImplicitAutograd - don't register decomp
 @out_wrapper()
 def hstack(tensors: TensorSequenceType) -> TensorLikeType:
-    check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
+    torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
     aligned_tensors = atleast_1d(*tensors)
     if aligned_tensors[0].ndim == 1:
         return cat(aligned_tensors, 0)
@@ -3462,7 +3455,7 @@
 # CompositeImplicitAutograd - don't register decomp
 @out_wrapper()
 def vstack(tensors: TensorSequenceType) -> TensorLikeType:
-    check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
+    torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
     aligned_tensors = atleast_2d(*tensors)
     return cat(aligned_tensors, 0)
 
@@ -3470,17 +3463,16 @@
 # CompositeImplicitAutograd - don't register decomp
 def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
     dim = utils.canonicalize_dim(a.ndim, dim)
-    utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
+    torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
     return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
 
 
 @register_decomposition(aten.unbind)
 def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
     dim = utils.canonicalize_dim(t.ndim, dim)
-    check(
+    torch._check_index(
         len(t.shape) > 0,
         lambda: "Dimension specified as 0 but tensor has no dimensions",
-        IndexError,
     )
     if t.shape[dim] == 0:
         return tuple()
@@ -3499,7 +3491,7 @@
 
 def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
@@ -3532,12 +3524,12 @@
     *,
     inplace: bool,
 ):
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
     if isinstance(value, TensorLike):
-        utils.check(
+        torch._check(
             value.ndim == 0,
             lambda: "Only supports 0-dimensional value tensor. "  # type: ignore[union-attr]
             f"Got a tensor with {value.ndim} dimensions.",
@@ -3589,7 +3581,7 @@
 @out_wrapper()
 def index_select(x: TensorLike, dim: int, index: TensorLike):
     dim = utils.canonicalize_dims(x.ndim, dim)
-    utils.check(
+    torch._check(
         index.ndim <= 1,
         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
     )
@@ -3713,7 +3705,7 @@
 def hsplit(
     a: TensorLikeType, indices_or_sections: DimsType
 ) -> Tuple[TensorLikeType, ...]:
-    check(
+    torch._check(
         a.ndim >= 1,
         lambda: (
             "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
@@ -3724,7 +3716,7 @@
     dim = 0 if a.ndim == 1 else 1
     if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
-        check(
+        torch._check(
             (split_size != 0 and a.shape[dim] % split_size == 0),
             lambda: (
                 "torch.hsplit attempted to split along dimension "
@@ -3738,14 +3730,13 @@
         )
         return tensor_split(a, split_size, dim)
 
-    check(
+    torch._check_type(
         isinstance(indices_or_sections, (list, tuple)),
         lambda: (
             "hsplit(): received an invalid combination of arguments. "
             "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
             f"but got type {type(indices_or_sections)}"
         ),
-        exc_type=TypeError,
     )
 
     split_sizes = indices_or_sections
@@ -3756,7 +3747,7 @@
 def vsplit(
     a: TensorLikeType, indices_or_sections: DimsType
 ) -> Tuple[TensorLikeType, ...]:
-    check(
+    torch._check(
         a.ndim >= 2,
         lambda: (
             "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
@@ -3766,7 +3757,7 @@
     )
     if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
-        check(
+        torch._check(
             (split_size != 0 and a.shape[0] % split_size == 0),
             lambda: (
                 f"torch.vsplit attempted to split along dimension 0"
@@ -3779,14 +3770,13 @@
         )
         return tensor_split(a, split_size, 0)
 
-    check(
+    torch._check_type(
         isinstance(indices_or_sections, (list, tuple)),
         lambda: (
             "vsplit(): received an invalid combination of arguments. "
             "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
             f"but got type {type(indices_or_sections)}"
         ),
-        exc_type=TypeError,
     )
 
     split_sizes = indices_or_sections
@@ -3800,7 +3790,7 @@
     offset: int = 0,
 ) -> TensorLikeType:
     ndim = self.dim()
-    utils.check(
+    torch._check(
         ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
     )
     if ndim == 1:
@@ -3820,7 +3810,7 @@
 ) -> TensorLikeType:
     out = utils.clone_preserve_strides(input)
     diag = out.diagonal(offset, dim1, dim2)
-    check(
+    torch._check(
         diag.shape == src.shape,
         lambda: "expected src to have a size equal to the diagonal of the input."
         f"Got {src.shape} for a diagonal of shape {diag.shape}",
@@ -3843,7 +3833,7 @@
     dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
     dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
 
-    check(
+    torch._check(
         dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
     )
 
@@ -3896,7 +3886,7 @@
     dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
     dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
 
-    check(
+    torch._check(
         dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
     )
 
@@ -3967,7 +3957,7 @@
 # CompositeImplicitAutograd - don't register decomp
 def T(a: TensorLikeType) -> TensorLikeType:
     # n != 2 && n != 0 is deprecated in regular PyTorch.
-    check(
+    torch._check(
         a.ndim in (0, 2),
         lambda: (
             "The use of `x.T` on tensors of dimension other than 0 or 2 "
@@ -4102,7 +4092,7 @@
     pin_memory: bool = False,
     memory_format: torch.memory_format = torch.contiguous_format,
 ) -> TensorLikeType:
-    check(
+    torch._check(
         memory_format != torch.preserve_format,
         lambda: "torch.empty: the Preserve memory format is not supported",
     )
@@ -4114,7 +4104,7 @@
     elif memory_format == torch.channels_last_3d:
         strides = utils.make_channels_last_3d_strides_for(shape)
     else:  # memory_format == torch.channels_last
-        check(
+        torch._check(
             memory_format == torch.channels_last,
             lambda: f"torch.empty: received an unknown memory format {memory_format}!",
         )
@@ -4398,8 +4388,8 @@
     if end is None:
         end = start
         start = 0
-    utils.check(step != 0, lambda: "step must be nonzero")
-    utils.check(
+    torch._check(step != 0, lambda: "step must be nonzero")
+    torch._check(
         (step > 0 and end >= start) or (step < 0 and end <= start),
         lambda: "upper bound and lower bound inconsistent with step sign",
     )
@@ -4407,11 +4397,11 @@
     def is_finite(x):
         return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
 
-    utils.check(
+    torch._check(
         is_finite(start) and is_finite(end),
         lambda: f"unsupported range: {start} -> {end}",
     )
-    utils.check(
+    torch._check(
         is_finite(step),
         lambda: f"step must be finite but got {step}",
     )
@@ -4514,7 +4504,7 @@
         if dtype is None:
             dtype = default_complex_dtype
         else:
-            check(
+            torch._check(
                 utils.is_complex_dtype(dtype),
                 lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
             )
@@ -4523,13 +4513,12 @@
     assert isinstance(dtype, torch.dtype)
 
     # steps does not participate in the computation of the dtype
-    check(
+    torch._check_type(
         isinstance(steps, IntLike),
         lambda: "steps must be int, not float",
-        exc_type=TypeError,
     )
     assert isinstance(steps, IntLike)  # for mypy
-    check(steps >= 0, lambda: "number of steps must be non-negative")
+    torch._check(steps >= 0, lambda: "number of steps must be non-negative")
 
     factory_kwargs = {
         "layout": layout,
@@ -4631,19 +4620,19 @@
         assert len(tensors) == 1
         tensors = tuple(tensors[0])
 
-    check(
+    torch._check(
         py_all(isinstance(a, TensorLike) for a in tensors),
         lambda: "meshgrid expects its inputs to be tensors",
     )
 
-    check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
+    torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
 
     for i in range(len(tensors) - 1):
-        check(
+        torch._check(
             tensors[i].dtype == tensors[i + 1].dtype,  # type: ignore[union-attr]
             lambda: "meshgrid expects all tensors to have the same dtype",
         )
-        check(
+        torch._check(
             tensors[i].device == tensors[i + 1].device,  # type: ignore[union-attr]
             lambda: "meshgrid expects all tensors to have the same device",
         )
@@ -4654,7 +4643,7 @@
         if swap_first_and_second_tensors:
             tensors = (tensors[1], tensors[0], *tensors[2:])
     else:
-        check(
+        torch._check(
             indexing == "ij",
             lambda: (
                 'torch.meshgrid: indexing must be one of "xy" or "ij", '
@@ -4665,7 +4654,7 @@
     result_shape: List[int] = []
     for t in tensors:
         assert isinstance(t, TensorLike)  # mypy
-        check(
+        torch._check(
             t.ndim == 0 or t.ndim == 1,
             lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
         )
@@ -4701,7 +4690,7 @@
 
     # Converts to list to produce a compatible error message with core PyTorch,
     # which prints sequences in square brackets.
-    utils.check(
+    torch._check(
         len(source) == len(destination),  # type: ignore[arg-type]
         lambda: (
             "movedim: Invalid source or destination dims: source "  # type: ignore[arg-type]
@@ -4718,11 +4707,11 @@
     dss = set(ds)
 
     # See above on why this converts to list in error messages.
-    utils.check(
+    torch._check(
         len(ss) == len(sss),
         lambda: f"movedim: repeated dim in `source` ({list(source)})",  # type: ignore[arg-type]
     )
-    utils.check(
+    torch._check(
         len(ds) == len(dss),
         lambda: f"movedim: repeated dim in `destination` ({list(destination)})",  # type: ignore[arg-type]
     )
@@ -4795,8 +4784,8 @@
     if m is None:
         m = n
 
-    check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
-    check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
+    torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
+    torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
 
     range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
     range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
@@ -4994,13 +4983,13 @@
         # NOTE: Could not use value = item(value) as it resulted in
         # RuntimeError: Cannot cast FakeTensor(cpu) to number
         value_ndim = value.ndim
-        check(
+        torch._check(
             value_ndim == 0,
             lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
         )
         # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
         is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu"
-        check(
+        torch._check(
             is_cpu_scalar or value.device == a.device,
             lambda: "Expected `value` to be on same device as `a`",
         )
@@ -5011,7 +5000,7 @@
         # We allow casting `value` to lower type for other case
         # Eg. float -> int.
         # Ref: https://github.com/pytorch/pytorch/issues/79195
-        check(
+        torch._check(
             utils.is_weakly_lesser_type(value_type, python_type),
             lambda: f"could not convert to type {python_type} without overflow",
         )
@@ -5101,7 +5090,7 @@
 
 @register_decomposition(aten.trace)
 def trace(self: TensorLikeType) -> TensorLikeType:
-    utils.check(
+    torch._check(
         self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
     )
     return torch.sum(torch.diag(self, 0))
@@ -5125,7 +5114,7 @@
 @register_decomposition(aten.triu)
 @out_wrapper()
 def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
-    utils.check(
+    torch._check(
         a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
     )
     h, w = a.shape[-2:]
@@ -5142,7 +5131,7 @@
 @register_decomposition(aten.tril)
 @out_wrapper()
 def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
-    utils.check(
+    torch._check(
         a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
     )
     h, w = a.shape[-2:]
@@ -5187,9 +5176,9 @@
     layout: torch.layout,
     pin_memory: bool,
 ):
-    check(row >= 0, lambda: f"row must be non-negative, got {row}")
-    check(col >= 0, lambda: f"col must be non-negative, got {col}")
-    check(
+    torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
+    torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
+    torch._check(
         dtype in (torch.int32, torch.int64),
         lambda: f"\"{name}\" not implemented for '{dtype}'",
     )
@@ -5306,7 +5295,7 @@
     out_int32: bool = False,
     right: bool = False,
 ):
-    utils.check(
+    torch._check(
         boundaries.dim() == 1,
         lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
     )
@@ -5364,14 +5353,14 @@
 )
 def cauchy(self, median=0, sigma=1, generator=None):
     assert generator is None
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype)
         and not utils.is_integer_dtype(self.dtype)
         and not utils.is_boolean_dtype(self.dtype),
         lambda: f"Cauchy distribution is a continuous probability distribution. \
         dtype must be a floating point but you specified {self.dtype}",
     )
-    utils.check(
+    torch._check(
         sigma > 0.0,
         lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
     )
@@ -5386,14 +5375,14 @@
 )
 def exponential(self, rate=1, generator=None):
     assert generator is None
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype)
         and not utils.is_integer_dtype(self.dtype)
         and not utils.is_boolean_dtype(self.dtype),
         lambda: f"Exponential distribution is a continuous probability distribution. \
         dtype must be a floating point but you specified {self.dtype}",
     )
-    utils.check(
+    torch._check(
         rate > 0.0,
         lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
     )
@@ -5409,12 +5398,12 @@
 def geometric(self, p, generator=None):
     assert generator is None
     # TODO: fix inductor rand_like for integer, bool dtypes
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype)
         and not utils.is_boolean_dtype(self.dtype),
         lambda: f"geometric not implemented for {self.dtype}",
     )
-    utils.check(
+    torch._check(
         0 < p and p < 1,
         lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
     )
@@ -5429,13 +5418,13 @@
 )
 def log_normal(self, mean=1, std=2, generator=None):
     assert generator is None
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype)
         and not utils.is_integer_dtype(self.dtype)
         and not utils.is_boolean_dtype(self.dtype),
         lambda: f"log_normal not implemented for {self.dtype}",
     )
-    utils.check(
+    torch._check(
         0 < std,
         lambda: f"log_normal_ expects std > 0.0, but found std={std}",
     )
@@ -5451,7 +5440,7 @@
 )
 def normal(self, mean=0, std=1, generator=None):
     assert generator is None
-    utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
+    torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
     normal_samples = prims.normal(
         self.shape,
         mean=0.0,
@@ -5465,7 +5454,7 @@
 
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
 def rad2deg(self: TensorLikeType):
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype),
         lambda: "rad2deg is not supported for complex tensors.",
     )
@@ -5475,7 +5464,7 @@
 
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
 def deg2rad(self: TensorLikeType):
-    utils.check(
+    torch._check(
         not utils.is_complex_dtype(self.dtype),
         lambda: "deg2rad is not supported for complex tensors.",
     )
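
As with the old helper, every message above stays a zero-argument lambda, so the (often f-string) formatting is deferred until a check actually fails; this keeps the happy path cheap and is easier for torchdynamo to trace. A small made-up sketch of why that matters:

    import torch

    def slow_summary(t):
        # Stand-in for expensive formatting; with a lambda message this
        # only runs when the check fails.
        return f"shape={tuple(t.shape)}, dtype={t.dtype}, device={t.device}"

    t = torch.randn(4, 4)
    torch._check(t.ndim == 2, lambda: f"expected 2D input ({slow_summary(t)})")
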
diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py
index d0607d8..fa1ca24 100644
--- a/torch/_refs/_conversions.py
+++ b/torch/_refs/_conversions.py
@@ -4,7 +4,7 @@
 # Utilities should come BEFORE this import
 from torch._decomp import register_decomposition
 
-from torch._prims_common import check, TensorLikeType
+from torch._prims_common import TensorLikeType
 from torch._prims_common.wrappers import out_wrapper
 from torch._refs import _broadcast_shapes
 
@@ -79,14 +79,14 @@
 @out_wrapper(exact_dtype=True)
 def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
     allowed_dtypes = (torch.float32, torch.float64, torch.float16)
-    check(
+    torch._check(
         real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
         lambda: (
             f"Expected both inputs to be Half, Float or Double tensors but got "
             f"{real.dtype} and {imag.dtype}"
         ),
     )
-    check(
+    torch._check(
         real.dtype == imag.dtype,
         lambda: (
             f"Expected object of scalar type {real.dtype} but got "
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index 54a98c2..4b22d86 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -6,7 +6,7 @@
 import torch._prims as prims
 import torch._prims_common as utils
 from torch._decomp import register_decomposition
-from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
+from torch._prims_common import DimsType, ShapeType, TensorLikeType
 from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
 
 __all__ = [
@@ -43,7 +43,7 @@
     x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
 ) -> TensorLikeType:
     """Apply normalization to the un-normalized FFT result"""
-    check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
+    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
 
     if norm == "ortho":
         return x * (1 / math.sqrt(signal_numel))
@@ -116,7 +116,9 @@
     input = _maybe_promote_tensor_fft(input, require_complex=True)
     dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
     last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
-    check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")
+    torch._check(
+        last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
+    )
 
     if n is not None:
         input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))
@@ -138,7 +140,7 @@
     onesided: bool,
 ) -> TensorLikeType:
     """Common code for performing any real to complex FFT (rfft or ihfft)"""
-    check(
+    torch._check(
         not input.dtype.is_complex,
         lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
     )
@@ -162,7 +164,7 @@
     forward: bool,
 ) -> TensorLikeType:
     """Common code for performing any complex to complex FFT (fft or ifft)"""
-    check(
+    torch._check(
         input.dtype.is_complex,
         lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
     )
@@ -265,20 +267,20 @@
         ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
 
         # Check dims are unique
-        check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
+        torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
 
     if shape is not None:
         if not isinstance(shape, Sequence):
             shape = (shape,)
 
         # Has shape, might have dim
-        check(
+        torch._check(
             dim is None or len(dim) == len(shape),
             lambda: "When given, dim and shape arguments must have the same length",
         )
         transform_ndim = len(shape)
 
-        check(
+        torch._check(
             transform_ndim <= input_dim,
             lambda: f"Got shape with {transform_ndim} values but input tensor "
             f"only has {input_dim} dimensions.",
@@ -301,7 +303,7 @@
         ret_shape = tuple(input_sizes[d] for d in ret_dims)
 
     for n in ret_shape:
-        check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
+        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
 
     return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
 
@@ -323,7 +325,7 @@
     forward: bool,
 ) -> TensorLikeType:
     """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
-    check(
+    torch._check(
         input.dtype.is_complex,
         lambda: f"{function_name} expects a complex input tensor, "
         f"but got {input.dtype}",
@@ -367,7 +369,7 @@
     dim: Optional[DimsType] = None,
     norm: NormType = None,
 ) -> TensorLikeType:
-    check(
+    torch._check(
         not input.dtype.is_complex,
         lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
     )
@@ -386,12 +388,12 @@
     dim: Optional[DimsType] = None,
     norm: NormType = None,
 ) -> TensorLikeType:
-    check(
+    torch._check(
         not input.dtype.is_complex,
         lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
     )
     shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
-    check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
+    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
     input = _maybe_promote_tensor_fft(input, require_complex=False)
     input = _resize_fft_input(input, dim, shape)
 
@@ -421,14 +423,14 @@
     """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
     as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
     (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
-    check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
+    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
 
     if s is None or s[-1] == -1:
         last_dim_size = 2 * (input.shape[dim[-1]] - 1)
     else:
         last_dim_size = shape[-1]
 
-    check(
+    torch._check(
         last_dim_size >= 1,
         lambda: f"Invalid number of data points ({last_dim_size}) specified",
     )
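
The fft refs only swap the helper; the enforced behavior is unchanged. For example, the real-input check in `_fft_r2c` mirrors the eager error (a rough sketch, exact message text aside):

    import torch

    # rfft is a real-to-complex transform, so complex input is rejected.
    try:
        torch.fft.rfft(torch.randn(4, dtype=torch.complex64))
    except RuntimeError:
        pass
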
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index f22926c..97fb1d9 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -11,7 +11,6 @@
 import torch._refs.linalg as linalg
 from torch import Tensor
 from torch._prims_common import (
-    check,
     check_fp_or_complex,
     check_is_matrix,
     Dim,
@@ -29,11 +28,11 @@
     Checks related to the dtype kwarg in `linalg.*norm` functions
     """
     if dtype is not None:
-        check(
+        torch._check(
             utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
             lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
         )
-        check(
+        torch._check(
             utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
             lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                 fn_name=fn_name,
@@ -41,7 +40,7 @@
                 dtype=dtype,
             ),
         )
-        check(
+        torch._check(
             utils.get_higher_dtype(dtype, x_dtype) == dtype,
             lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
             "without narrowing to the specified dtype ({dtype})",
@@ -79,7 +78,7 @@
         dim = [dim]  # type: ignore[assignment]
 
     if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
-        check(
+        torch._check(
             dim is not None and len(dim) != 0,
             lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
             "because the operation does not have an identity",
@@ -87,7 +86,7 @@
         shape = x.shape
         assert dim is not None  # mypy does not seem to be able to see through check?
         for d in dim:
-            check(
+            torch._check(
                 shape[d] != 0,
                 lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                 f"dimension {d} because this dimension is empty and the "
@@ -147,8 +146,10 @@
     dim = utils.canonicalize_dims(A.ndim, dim)
     if isinstance(dim, Dim):
         dim = (dim,)  # type: ignore[assignment]
-    check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
-    check(
+    torch._check(
+        len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
+    )
+    torch._check(
         dim[0] != dim[1],
         lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
     )
@@ -157,7 +158,7 @@
 
     if isinstance(ord, str):
         # ord
-        check(
+        torch._check(
             ord in ("fro", "nuc"),
             lambda: "linalg.matrix_norm: Order {ord} not supported.",
         )
@@ -180,7 +181,7 @@
     else:
         # ord
         abs_ord = abs(ord)
-        check(
+        torch._check(
             abs_ord in (2, 1, float("inf")),
             lambda: "linalg.matrix_norm: Order {ord} not supported.",
         )
@@ -224,12 +225,12 @@
     if dim is not None:
         if isinstance(dim, Dim):
             dim = (dim,)  # type: ignore[assignment]
-        check(
+        torch._check(
             len(dim) in (1, 2),
             lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
         )
     elif ord is not None:
-        check(
+        torch._check(
             A.ndim in (1, 2),
             lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
         )
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index be82d0a..eaa6618 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -8,7 +8,6 @@
 import torch._refs as refs
 from torch._decomp import register_decomposition
 from torch._prims_common import (
-    check,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     NumberType,
     ShapeType,
@@ -98,7 +97,7 @@
     if not training:
         return self
 
-    utils.check(
+    torch._check(
         p <= 1 and p >= 0,
         lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
     )
@@ -134,7 +133,7 @@
     @wraps(fn)
     def _fn(a, *args, inplace=False, **kwargs):
         if inplace:
-            check(
+            torch._check(
                 "out" not in kwargs,
                 lambda: "Cannot set inplace=True and pass out= at the same time",
             )
@@ -193,7 +192,7 @@
     if not training:
         return a
 
-    utils.check(
+    torch._check(
         p <= 1 and p >= 0,
         lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
     )
@@ -232,15 +231,15 @@
 
     # nb. This should be factored out into a can_cast aux function
     python_type = utils.dtype_to_type(a.dtype)
-    check(
+    torch._check(
         utils.is_weakly_lesser_type(type(input_scale), python_type),
         lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
     )
-    check(
+    torch._check(
         utils.is_weakly_lesser_type(type(scale), python_type),
         lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
     )
-    check(
+    torch._check(
         utils.is_weakly_lesser_type(type(alpha), python_type),
         lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
     )
@@ -276,14 +275,14 @@
     """
     Reference implementation of :func:`torch.nn.functional.group_norm`.
     """
-    utils.check(
+    torch._check(
         input.ndim >= 2,
         lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
     )
 
     batch_size = input.shape[0]
     num_channels = input.shape[1]
-    utils.check(
+    torch._check(
         num_channels % num_groups == 0,
         lambda: "Expected number of channels in input to be divisible by num_groups, "
         + f"but got input of shape {input.shape} and num_groups = {num_groups}",
@@ -394,7 +393,7 @@
     # deprecated.  For PrimTorch, it's fine to drop support for deprecated
     # behavior because it requires explicit opt in.  This error is to inform
     # users how to update their calls.
-    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
     return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
 
 
@@ -409,7 +408,7 @@
     # deprecated.  For PrimTorch, it's fine to drop support for deprecated
     # behavior because it requires explicit opt in.  This error is to inform
     # users how to update their calls.
-    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
     return torch.softmax(a=-a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
 
 
@@ -469,7 +468,7 @@
     # softshrink(x) = x - lambd if x > lambd
     #               = x + lambd if x < -lambd
     #               = 0 otherwise
-    check(
+    torch._check(
         lambd >= 0,
         lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
     )
@@ -596,7 +595,7 @@
     # deprecated.  For PrimTorch, it's fine to drop support for deprecated
     # behavior because it requires explicit opt in.  This error is to inform
     # users how to update their calls.
-    check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
     return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]
 
 
@@ -668,12 +667,12 @@
     reduction: str,
     ignore_index: int,
 ) -> TensorLikeType:
-    utils.check(
+    torch._check(
         input.ndim > 0 and input.ndim <= 3,
         lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
     )
 
-    utils.check(
+    torch._check(
         (input.ndim == 1) or (input.shape[0] == target.shape[0]),
         lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
     )
@@ -693,7 +692,7 @@
         (flat_target >= 0), (flat_target < num_classes)
     )
     class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
-    utils.check(
+    torch._check(
         isinstance(target, FakeTensor) or bool(class_check.item()),
         lambda: "A target class is out-of-bounds and not the ignore index.",
     )
@@ -758,7 +757,7 @@
     """
     Reference implementation of torch.nn.functional.nll_loss
     """
-    utils.check(
+    torch._check(
         input.ndim > 0,
         lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
     )
@@ -796,9 +795,13 @@
     # For ndim > 3, we reshape the input and target to 3-D case.
     # Input (N batch-size, C classes, k-dimensions)
     # Target (N batch-size, k-dimensions)
-    utils.check(
+    torch._check(
         input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
-        lambda: f"Expected target shape {out_size} but got {target.shape}",
+        lambda: (
+            "Expected input and target to both have ndim > 0 and "
+            "target.shape[1:] == input.shape[2:], but got "
+            f"target.shape {target.shape} and input.shape {input.shape}"
+        ),
     )
 
     batch_size = input.shape[0]
@@ -837,7 +840,7 @@
     if type(reduction) is int:
         reduction = _reduction_int_to_str(reduction)
     _check_reduction_value(reduction)  # type: ignore[arg-type]
-    check(
+    torch._check(
         delta > 0,
         lambda: "huber_loss does not support non-positive values for delta.",
     )
@@ -938,7 +941,7 @@
     a_dim = anchor.ndim
     p_dim = positive.ndim
     n_dim = negative.ndim
-    check(
+    torch._check(
         a_dim == p_dim and p_dim == n_dim,
         lambda: (
             f"The anchor, positive, and negative tensors are expected to have "
@@ -1075,25 +1078,25 @@
     """
     Reference implementation of torch.nn.functional.prelu
     """
-    check(
+    torch._check(
         isinstance(a, TensorLike),
         lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
     )
-    check(
+    torch._check(
         isinstance(weight, TensorLike),
         lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
     )
 
     if weight.numel() != 1:
-        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
+        torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
         channel_size = a.shape[1] if a.ndim >= 2 else 1
-        check(
+        torch._check(
             weight.numel() == channel_size,
             lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
             f" {weight.numel()} and channel size = {channel_size}.",
         )
 
-    check(
+    torch._check(
         weight.ndim == 0 or weight.ndim == 1,
         lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
         f"ndim = {weight.ndim}",
@@ -1132,7 +1135,7 @@
 )
 def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
     dim = utils.canonicalize_dims(a.ndim, dim)
-    check(
+    torch._check(
         a.shape[dim] % 2 == 0,
         lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
     )
@@ -1160,8 +1163,8 @@
     type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
 )
 def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
-    check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
-    check(p >= 0, lambda: "pdist only supports non-negative p values")
+    torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
+    torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
     # For p == 2 we can use an efficient implementation, but other values of p
     # require creating a much bigger tensor for an intermediate step
     if p == 2:
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index 4369265..048de83 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -148,7 +148,7 @@
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
-    utils.check(
+    torch._check(
         isinstance(a, TensorLike) or isinstance(b, TensorLike),
         lambda: 'Expected either argument a or b to be a Tensor"',
     )
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 3ba091a..fe1dd93 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -15,7 +15,6 @@
 from torch._guards import Source
 from torch._ops import OpOverload
 from torch._prims_common import (
-    check,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
     is_boolean_dtype,
@@ -1495,7 +1494,7 @@
                 ) = FakeTensor._find_common_device(func, args, kwargs)
 
             if isinstance(e, FakeTensor):
-                check(
+                torch._check(
                     e.device == common_device,
                     lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
                 )