Replace `_prims_common.check` with `torch._check*` (#103240)
This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage
Part of #72948
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
diff --git a/test/test_prims.py b/test/test_prims.py
index da1f5a1..ad1e63f 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -1161,6 +1161,10 @@
def test_mul_complex(self):
prims.mul(torch.randn(2), 1 + 1j)
+ def test_check_deprecation_warning(self):
+ with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
+ torch._prims_common.check(True, lambda: 'message')
+
instantiate_device_type_tests(TestPrims, globals())
diff --git a/torch/__init__.py b/torch/__init__.py
index 148336c..b4c45f7 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -936,7 +936,7 @@
# These error checking functions must be kept consistent with their C++
# equivalents. Their C++ equivalents are mentioned where applicable.
-def _check_with(error_type, cond, message):
+def _check_with(error_type, cond: Union[builtins.bool, SymBool], message: Callable[[], str]):
if not isinstance(cond, (builtins.bool, torch.SymBool)):
raise TypeError(f'cond must be a bool, but got {type(cond)}')
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 0c307c8..8715b1f 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -149,7 +149,7 @@
@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
- utils.check(
+ torch._check(
value.dim() == 0,
lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
)
@@ -785,14 +785,14 @@
padding: List[int],
stride: List[int],
) -> Tensor:
- utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
- utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
- utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
- utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
+ torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
+ torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
+ torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
+ torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
- utils.check(
+ torch._check(
cond, lambda: "{param_name} should be greater {'than' zero, but got {param}"
)
@@ -803,7 +803,7 @@
shape = input.shape
ndim = len(shape)
- utils.check(
+ torch._check(
ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
f"and non-zero dimensions, but got: {tuple(shape)}",
@@ -814,7 +814,7 @@
shape[-2:], padding, dilation, kernel_size, stride
)
)
- utils.check(
+ torch._check(
all(c > 0 for c in output_size),
lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
f"kernel_size={kernel_size}, dilation={dilation}, "
@@ -869,15 +869,15 @@
padding: List[int],
stride: List[int],
) -> Tensor:
- utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
- utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
- utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
- utils.check(len(padding) == 2, lambda: "only 2D padding supported")
- utils.check(len(stride) == 2, lambda: "only 2D stride supported")
+ torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
+ torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
+ torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
+ torch._check(len(padding) == 2, lambda: "only 2D padding supported")
+ torch._check(len(stride) == 2, lambda: "only 2D stride supported")
def check_positive(param, param_name, strict=True):
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
- utils.check(
+ torch._check(
cond, lambda: "{param_name} should be greater than zero, but got {param}"
)
@@ -889,13 +889,13 @@
shape = input.shape
ndim = len(shape)
- utils.check(
+ torch._check(
ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
f"and non-zero dimensions, but got: {tuple(shape)}",
)
prod_kernel_size = kernel_size[0] * kernel_size[1]
- utils.check(
+ torch._check(
shape[-2] % prod_kernel_size == 0,
lambda: "Expected size of input's first non-batch dimension to be divisible by the "
f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
@@ -908,13 +908,13 @@
)
]
L = col[0] * col[1]
- utils.check(
+ torch._check(
shape[-1] == L,
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
f"dilation={dilation}, padding={padding}, stride={stride}, "
f"expected input.size(-1) to be {L} but got {shape[-1]}.",
)
- utils.check(
+ torch._check(
L > 0,
lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
f"dilation={dilation}, padding={padding}, stride={stride}, "
@@ -961,7 +961,7 @@
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
# According to the CUDA kernel implementation we should have this test;
# but it seems to fail tests!
- # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
+ # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
# Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
# This different from TensorIterator's behavior
@@ -1221,21 +1221,21 @@
)
utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
- utils.check(
+ torch._check(
input.numel() == N * C * HxW,
lambda: f"Expect input to have { N * C * HxW} elements",
)
- utils.check(
+ torch._check(
mean.shape == (N, group),
lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
)
- utils.check(
+ torch._check(
gamma is None or gamma.numel() == C,
lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
)
cpg, _rem = divmod(C, group)
- utils.check(
+ torch._check(
_rem == 0,
lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
)
@@ -1834,12 +1834,12 @@
device = input.device
shape = input.shape
ndim = len(shape)
- utils.check(
+ torch._check(
ndim in (3, 4),
lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
)
for d in input.shape[-2:]:
- utils.check(
+ torch._check(
d != 0,
lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
f"non-batch dimensions, but input has shape {tuple(shape)}.",
@@ -1966,13 +1966,13 @@
alpha: NumberType = 1,
):
dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
+ torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
if alpha != 1:
python_type = utils.dtype_to_type(x.dtype)
- utils.check(
+ torch._check(
python_type == bool
or utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
@@ -2005,7 +2005,7 @@
x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
):
dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
+ torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
@@ -2060,19 +2060,19 @@
def upsample_compute_output_size(input_size, output_size, scale_factors):
spatial_dimensions = len(input_size) - 2
if output_size is not None:
- utils.check(
+ torch._check(
scale_factors is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
- utils.check(len(output_size) == spatial_dimensions, lambda: "")
+ torch._check(len(output_size) == spatial_dimensions, lambda: "")
return output_size
if scale_factors is not None:
# NB: this isn't necessary lol
- utils.check(
+ torch._check(
output_size is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
- utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
+ torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
output_size = []
for i, s in enumerate(scale_factors):
if int(s) == s:
@@ -2080,7 +2080,7 @@
else:
output_size.append(sym_int(input_size[i + 2] * s))
return output_size
- utils.check(
+ torch._check(
False, lambda: "Must specify exactly one of output_size and scale_factors"
)
@@ -2969,11 +2969,11 @@
padding_mode: int = 0,
align_corners: bool = False,
) -> Tensor:
- utils.check(
+ torch._check(
interpolation_mode in (0, 1, 2),
lambda: f"Invalid interpolation mode {interpolation_mode}",
)
- utils.check(
+ torch._check(
padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
)
@@ -3110,11 +3110,11 @@
@out_wrapper()
@pw_cast_for_opmath
def mv(self, vec):
- utils.check(
+ torch._check(
self.dim() == 2 and vec.dim() == 1,
lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
)
- utils.check(
+ torch._check(
self.size(1) == vec.size(0),
lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
)
@@ -3134,11 +3134,11 @@
elif other.is_conj():
return torch.vdot(other.conj(), self)
- utils.check(
+ torch._check(
self.dim() == 1 and other.dim() == 1,
lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
)
- utils.check(
+ torch._check(
self.dtype == other.dtype,
lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
)
@@ -3149,7 +3149,7 @@
f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
)
- utils.check(self.numel() == other.numel(), numel_error)
+ torch._check(self.numel() == other.numel(), numel_error)
return (self * other).sum()
@@ -3296,7 +3296,7 @@
return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
else:
- utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
+ torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
@register_decomposition(aten.upsample_bicubic2d.default)
@@ -3373,7 +3373,7 @@
align_corners: bool,
scale_factors: Optional[Tuple[float, float]] = None,
) -> Tensor:
- utils.check(
+ torch._check(
bool(output_size) + bool(scale_factors) == 1,
lambda: "Must specify exactly one of output_size and scale_factors.",
)
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index e74fe63..fccdad3 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -72,7 +72,6 @@
remove_unaligned_input_idxs,
static_input,
)
-from torch._prims_common import check
from torch.multiprocessing.reductions import StorageWeakRef
from torch.storage import UntypedStorage
from torch.utils import _pytree as pytree
@@ -1071,7 +1070,7 @@
self.output_storage_alias.append(UnaliasedStorage)
continue
- check(
+ torch._check(
o.is_cuda,
lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
),
@@ -1447,7 +1446,7 @@
for idx in self.cudagraph_managed_idxs:
inputs[idx] = None
- check(
+ torch._check(
self._check_liveness(
self.expected_dead_indices_after_graph, self.path_weakrefs
),
@@ -1522,7 +1521,7 @@
addr += block["size"]
- check(
+ torch._check(
len(unique_storages) == 0,
lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}",
)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e3c4f1b..c346045 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -13,7 +13,6 @@
from torch._ops import OpOverload
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
- check,
corresponding_complex_dtype,
corresponding_real_dtype,
elementwise_dtypes,
@@ -63,7 +62,7 @@
def check_inplace_broadcast(self_shape, *args_shape):
broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
- check(
+ torch._check(
broadcasted_shape == self_shape,
lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
)
@@ -73,15 +72,14 @@
@out_wrapper()
def meta_take(self, index):
# Type and device checks
- check(
+ torch._check(
index.dtype == torch.long,
lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
)
# Index checks
- check(
+ torch._check_index(
not (self.numel() == 0 and index.numel() != 0),
lambda: "take(): tried to take from an empty tensor",
- IndexError,
)
return self.new_empty(index.shape)
@@ -91,11 +89,11 @@
def linalg_cross(self, other, *, dim=-1):
x_d = self.ndim
y_d = other.ndim
- check(
+ torch._check(
x_d == y_d,
lambda: "linalg.cross: inputs must have the same number of dimensions.",
)
- check(
+ torch._check(
self.size(dim) == 3 and other.size(dim) == 3,
lambda: (
f"linalg.cross: inputs dimension {dim} must have length 3. "
@@ -334,7 +332,7 @@
A: Tensor,
name: str,
):
- check(
+ torch._check(
self.device == A.device,
lambda: (
f"Expected b and A to be on the same device, but found b on "
@@ -342,7 +340,7 @@
),
)
- check(
+ torch._check(
self.dtype == A.dtype,
lambda: (
f"Expected b and A to have the same dtype, but found b of type "
@@ -350,7 +348,7 @@
),
)
- check(
+ torch._check(
A.size(-1) == A.size(-2),
lambda: (
f"A must be batches of square matrices, "
@@ -358,7 +356,7 @@
),
)
- check(
+ torch._check(
A.size(-1) == self.size(-2),
lambda: (
f"Incompatible matrix sizes for {name}: each A "
@@ -373,12 +371,12 @@
t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
):
dtype = t.dtype
- check(
+ torch._check(
t.is_floating_point() or t.is_complex(),
lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
)
if not allow_low_precision_dtypes:
- check(
+ torch._check(
dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
)
@@ -386,7 +384,7 @@
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
- check(
+ torch._check(
A.dim() >= 2,
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
)
@@ -400,7 +398,7 @@
):
squareCheckInputs(A, f_name)
checkIsMatrix(B, f_name)
- check(
+ torch._check(
A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
lambda: (
f"{f_name}: Incompatible shapes of A and B for the equation "
@@ -413,7 +411,7 @@
def checkSameDevice(
fn_name: str, result: Tensor, input: Tensor, result_name: str = "result"
):
- check(
+ torch._check(
result.device == input.device,
lambda: (
f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
@@ -424,7 +422,7 @@
def checkUplo(UPLO: str):
UPLO_uppercase = UPLO.upper()
- check(
+ torch._check(
len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
)
@@ -477,20 +475,20 @@
)
@out_wrapper()
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
- check(
+ torch._check(
input.ndim >= 2,
lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
)
- check(
+ torch._check(
input.size(-2) >= input.size(-1),
lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
)
- check(
+ torch._check(
input.size(-1) >= tau.size(-1),
lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
)
- check(
+ torch._check(
input.ndim - tau.ndim == 1,
lambda: (
f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
@@ -500,7 +498,7 @@
if input.ndim > 2:
expected_batch_tau_shape = input.shape[:-2]
actual_batch_tau_shape = tau.shape[:-1]
- check(
+ torch._check(
actual_batch_tau_shape == expected_batch_tau_shape,
lambda: (
f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
@@ -508,7 +506,7 @@
),
)
- check(
+ torch._check(
tau.dtype == input.dtype,
lambda: (
f"torch.linalg.householder_product: tau dtype {tau.dtype}"
@@ -567,7 +565,7 @@
squareCheckInputs(LD, "torch.linalg.ldl_solve")
checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
- check(
+ torch._check(
B.ndim >= 2,
lambda: (
f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
@@ -575,18 +573,18 @@
),
)
expected_pivots_shape = LD.shape[:-1]
- check(
+ torch._check(
expected_pivots_shape == pivots.shape,
lambda: (
f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
f"but got pivots with shape {pivots.shape} instead"
),
)
- check(
+ torch._check(
utils.is_integer_dtype(pivots.dtype),
lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
)
- check(
+ torch._check(
LD.dtype == B.dtype,
lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
)
@@ -602,7 +600,7 @@
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
@out_wrapper("P", "L", "U")
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
- check(
+ torch._check(
A.ndim >= 2,
lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
)
@@ -632,7 +630,7 @@
def linalg_lu_factor_ex_meta(
A: Tensor, *, pivot: bool = True, check_errors: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
- check(
+ torch._check(
A.ndim >= 2,
lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
)
@@ -672,14 +670,14 @@
) -> Tensor:
# dtype
checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
- check(
+ torch._check(
LU.dtype == B.dtype,
lambda: (
f"linalg.lu_solve: Expected LU and B to have the same dtype, "
f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
),
)
- check(
+ torch._check(
pivots.dtype == torch.int,
lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
)
@@ -687,13 +685,13 @@
# matrix shapes
squareCheckInputs(LU, "torch.linalg.lu_solve")
checkInputsSolver(LU, B, left, "linalg.lu_solve")
- check(
+ torch._check(
LU.size(-1) == pivots.size(-1),
lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
)
# batches
- check(
+ torch._check(
LU.shape[:-1] == pivots.shape,
lambda: (
f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
@@ -770,7 +768,7 @@
compute_q = False
reduced = True # this is actually irrelevant in this mode
else:
- check(
+ torch._check(
False,
lambda: (
f"qr received unrecognized mode '{mode}' "
@@ -1043,11 +1041,11 @@
output_h = input_h + pad_t + pad_b
output_w = input_w + pad_l + pad_r
- check(
+ torch._check(
output_w == grad_output.shape[dim_w],
lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
)
- check(
+ torch._check(
output_h == grad_output.shape[dim_h],
lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
)
@@ -1057,7 +1055,7 @@
@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
valid_dims = self.size(1) != 0 and self.size(2) != 0
- check(
+ torch._check(
(self.ndim == 3 and valid_dims)
or (self.ndim == 4 and valid_dims and self.size(3) != 0),
lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
@@ -1086,9 +1084,9 @@
dim2 = batch1.size(1)
dim3 = batch2.size(2)
self = self.expand((dim1, dim2, dim3))
- check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
- check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
- check(
+ torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+ torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+ torch._check(
self.dtype == batch1.dtype == batch2.dtype,
lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
)
@@ -1096,7 +1094,7 @@
batch2_sizes = batch2.shape
bs = batch1_sizes[0]
contraction_size = batch1_sizes[2]
- check(
+ torch._check(
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
lambda: (
f"Expected size for first two dimensions of batch2 tensor to be: "
@@ -1140,7 +1138,7 @@
per_row_fake_quant=False,
symmetric_quant=False,
):
- check(
+ torch._check(
ch_axis < self.dim(),
lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
)
@@ -1149,7 +1147,7 @@
def dot_check(self, other):
- check(
+ torch._check(
self.dim() == 1 and other.dim() == 1,
lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
)
@@ -1163,11 +1161,11 @@
@register_meta([aten.mm.default])
def meta_mm(a, b):
- check(a.dim() == 2, lambda: "a must be 2D")
- check(b.dim() == 2, lambda: "b must be 2D")
+ torch._check(a.dim() == 2, lambda: "a must be 2D")
+ torch._check(b.dim() == 2, lambda: "b must be 2D")
N, M1 = a.shape
M2, P = b.shape
- check(
+ torch._check(
M1 == M2,
lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
)
@@ -1389,7 +1387,7 @@
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
- check(
+ torch._check(
tensor.dim() == dim and tensor.shape[dim_size] == size,
lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
+ f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
@@ -1407,7 +1405,7 @@
divisor_override=None,
):
def unpack(name, val):
- check(
+ torch._check(
len(val) in [1, 2],
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
)
@@ -1416,7 +1414,7 @@
return H, W
kH, kW = unpack("kernel_size", kernel_size)
- check(
+ torch._check(
len(stride) in [0, 1, 2],
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
)
@@ -1429,7 +1427,7 @@
padH, padW = unpack("padding", padding)
- check(
+ torch._check(
divisor_override is None or divisor_override != 0,
lambda: "divisor must be not zero",
)
@@ -1530,26 +1528,26 @@
divisor_override,
):
# From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
- check(
+ torch._check(
len(kernel_size) == 1 or len(kernel_size) == 2,
lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
)
kH = kernel_size[0]
kW = kH if len(kernel_size) == 1 else kernel_size[1]
- check(
+ torch._check(
len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
)
dH = kH if len(stride) == 0 else stride[0]
dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
- check(
+ torch._check(
len(padding) == 1 or len(padding) == 2,
lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
)
padH = padding[0]
padW = padH if len(padding) == 1 else padding[1]
- check(
+ torch._check(
divisor_override is None or divisor_override != 0,
lambda: "divisor must be not zero",
)
@@ -1602,7 +1600,7 @@
count_include_pad=True,
divisor_override=None,
):
- check(
+ torch._check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
@@ -1610,7 +1608,7 @@
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]
- check(
+ torch._check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
@@ -1618,7 +1616,7 @@
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
- check(
+ torch._check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
@@ -1626,12 +1624,12 @@
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]
- check(
+ torch._check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)
- check(
+ torch._check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)
@@ -1689,7 +1687,7 @@
count_include_pad,
divisor_override,
):
- check(
+ torch._check(
len(kernel_size) in (1, 3),
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
)
@@ -1697,7 +1695,7 @@
kH = kT if len(kernel_size) == 1 else kernel_size[1]
kW = kT if len(kernel_size) == 1 else kernel_size[2]
- check(
+ torch._check(
not stride or len(stride) in (1, 3),
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
)
@@ -1705,7 +1703,7 @@
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
- check(
+ torch._check(
len(padding) in (1, 3),
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
)
@@ -1713,12 +1711,12 @@
padH = padT if len(padding) == 1 else padding[1]
padW = padT if len(padding) == 1 else padding[2]
- check(
+ torch._check(
input.ndim in (4, 5),
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
)
- check(
+ torch._check(
not divisor_override or divisor_override != 0,
lambda: "divisor must be not zero",
)
@@ -1759,7 +1757,7 @@
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
- check(
+ torch._check(
self.ndim == 3 or self.ndim == 4,
lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
)
@@ -1777,7 +1775,7 @@
@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
- check(
+ torch._check(
self.ndim == 4 or self.ndim == 5,
lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
)
@@ -1788,16 +1786,16 @@
def meta__adaptive_avg_pool2d_backward(grad_out, self):
ndim = grad_out.ndim
for i in range(1, ndim):
- check(
+ torch._check(
grad_out.size(i) > 0,
lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
)
- check(
+ torch._check(
ndim == 3 or ndim == 4,
lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
)
- check(
+ torch._check(
self.dtype == grad_out.dtype,
lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
)
@@ -1852,30 +1850,28 @@
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
def meta_index_Tensor(self, indices):
- check(indices, lambda: "at least one index must be provided")
+ torch._check(bool(indices), lambda: "at least one index must be provided")
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- check(
+ torch._check(
index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
lambda: "tensors used as indices must be long, int, byte or bool tensors",
)
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
- check(
+ torch._check_index(
k + index.ndim <= self.ndim,
lambda: f"too many indices for tensor of dimension {self.ndim}",
- IndexError,
)
for j in range(index.ndim):
- check(
+ torch._check_index(
index.shape[j] == self.shape[k + j],
lambda: f"The shape of the mask {index.shape} at index {i} "
f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
- IndexError,
)
result.append(nonzero.select(1, j))
else:
@@ -1883,7 +1879,7 @@
else:
result.append(index)
indices = result
- check(
+ torch._check(
len(indices) <= self.ndim,
lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
)
@@ -1988,20 +1984,20 @@
dim1 = batch1.size(1)
dim2 = batch2.size(2)
self = self.expand((dim1, dim2))
- check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
- check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
- check(
+ torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+ torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+ torch._check(
batch1.size(0) == batch2.size(0),
lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
)
- check(
+ torch._check(
batch1.size(2) == batch2.size(1),
lambda: (
f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
f"and {batch2.size(1)}x{batch2.size(2)})"
),
)
- check(
+ torch._check(
self.size(0) == dim1 and self.size(1) == dim2,
lambda: "self tensor does not match matmul output shape",
)
@@ -2015,7 +2011,7 @@
]
)
def meta__foreach_unaop_(self):
- check(
+ torch._check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
)
@@ -2029,7 +2025,7 @@
]
)
def meta__foreach_unaop(self):
- check(
+ torch._check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
)
@@ -2037,14 +2033,14 @@
def _check_foreach_binop_tensor_lists(self, other):
- check(
+ torch._check(
isinstance(self, List) and isinstance(other, List),
lambda: (
"The first two arguments of must be List[Tensor], "
f"but got {type(self)} and {type(other)}."
),
)
- check(
+ torch._check(
len(self) > 0 and len(self) == len(other),
lambda: (
"self and other must be non-empty and match in length, "
@@ -2100,7 +2096,7 @@
]
)
def meta__foreach_binop__scalar(self, scalar=1):
- check(
+ torch._check(
isinstance(self, List),
lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
)
@@ -2115,7 +2111,7 @@
]
)
def meta__foreach_binop_scalar(self, scalar=1):
- check(
+ torch._check(
isinstance(self, List),
lambda: f"The first argument of must be List[Tensor], but got {type(self)}.",
)
@@ -2129,15 +2125,15 @@
]
)
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
- check(
+ torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: (
"All arguments of _foreach_addc*_ must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
),
)
- check(len(self) > 0, lambda: "input tensor list must not be empty.")
- check(
+ torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+ torch._check(
len(self) == len(tensor1) and len(self) == len(tensor2),
lambda: "All input tensor lists must have the same length",
)
@@ -2150,15 +2146,15 @@
]
)
def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
- check(
+ torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
lambda: (
"All arguments must be List[Tensor], "
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
),
)
- check(len(self) > 0, lambda: "input tensor list must not be empty.")
- check(
+ torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+ torch._check(
len(self) == len(tensor1) and len(self) == len(tensor2),
lambda: "All input tensor lists must have the same length",
)
@@ -2168,7 +2164,7 @@
@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
- check(
+ torch._check(
isinstance(exponent, List),
lambda: f"exponent must be a tensor list but got {type(exponent)}",
)
@@ -2177,7 +2173,7 @@
@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
- check(
+ torch._check(
all(isinstance(l, List) for l in [self, tensor1, tensor2])
and isinstance(scalars, torch.Tensor),
lambda: (
@@ -2185,8 +2181,8 @@
f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
),
)
- check(len(self) > 0, lambda: "input tensor list must not be empty.")
- check(
+ torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
+ torch._check(
len(self) == len(tensor1) and len(self) == len(tensor2),
lambda: "All input tensor lists must have the same length",
)
@@ -2212,7 +2208,7 @@
found_inf=None,
):
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
- check(
+ torch._check(
isinstance(l, List),
lambda: f"exponent must be a tensor list but got {type(l)}",
)
@@ -2238,7 +2234,7 @@
found_inf=None,
):
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
- check(
+ torch._check(
isinstance(l, List),
lambda: f"exponent must be a tensor list but got {type(l)}",
)
@@ -2258,17 +2254,17 @@
@register_meta([aten._int_mm])
@out_wrapper()
def meta__int_mm(a, b):
- check(a.dim() == 2, lambda: "a must be a 2D tensor")
- check(b.dim() == 2, lambda: "b must be a 2D tensor")
- check(
+ torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
+ torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
+ torch._check(
a.dtype is torch.int8,
lambda: f"expected self to be int8, got {a.dtype}",
)
- check(
+ torch._check(
b.dtype is torch.int8,
lambda: f"expected mat2 to be int8, got {b.dtype}",
)
- check(
+ torch._check(
a.size(1) == b.size(0),
lambda: (
f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
@@ -2280,28 +2276,28 @@
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
- check(
+ torch._check(
x1.dim() >= 2,
lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
)
- check(
+ torch._check(
x2.dim() >= 2,
lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
)
- check(
+ torch._check(
x1.size(-1) == x2.size(-1),
lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
)
- check(
+ torch._check(
utils.is_float_dtype(x1.dtype),
lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
)
- check(
+ torch._check(
utils.is_float_dtype(x2.dtype),
lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
)
- check(p >= 0, lambda: "cdist only supports non-negative p values")
- check(
+ torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
+ torch._check(
compute_mode in (None, 1, 2),
lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
)
@@ -2326,22 +2322,22 @@
include_last_offset=False,
padding_idx=-1,
):
- check(
+ torch._check(
indices.dtype in (torch.long, torch.int),
lambda: f"expected indices to be long or int, got {indices.dtype}",
)
- check(
+ torch._check(
offsets.dtype in (torch.long, torch.int),
lambda: f"expected offsets to be long or int, got {offsets.dtype}",
)
- check(
+ torch._check(
utils.is_float_dtype(weight.dtype),
lambda: f"expected weight to be floating point type, got {weight.dtype}",
)
num_bags = offsets.size(0)
if include_last_offset:
- check(
+ torch._check(
num_bags >= 1,
lambda: "include_last_offset: numBags should be at least 1",
)
@@ -2351,19 +2347,19 @@
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
if per_sample_weights is not None:
- check(
+ torch._check(
mode == MODE_SUM,
lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
)
- check(
+ torch._check(
per_sample_weights.dtype == weight.dtype,
lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
)
- check(
+ torch._check(
per_sample_weights.ndim == 1,
lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
)
- check(
+ torch._check(
per_sample_weights.numel() == indices.numel(),
lambda: (
f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
@@ -2408,7 +2404,7 @@
numBags = offsets.shape[0]
if mode == MODE_MAX:
if include_last_offset:
- check(
+ torch._check(
numBags >= 1,
lambda: "include_last_offset: numBags should be at least 1",
)
@@ -2477,7 +2473,7 @@
@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
- check(
+ torch._check(
len(repeats) >= self.dim(),
lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
)
@@ -2534,17 +2530,17 @@
def shift_dtype_check(fn_name, self, val):
- check(
+ torch._check(
utils.is_integer_dtype(self.dtype),
lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
)
if isinstance(val, torch.Tensor):
- check(
+ torch._check(
utils.is_integer_dtype(val.dtype),
lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
)
else:
- check(
+ torch._check(
isinstance(val, IntLike),
lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
)
@@ -2620,8 +2616,8 @@
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
- check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
- check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+ torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+ torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
batch1_sizes = batch1.size()
batch2_sizes = batch2.size()
@@ -2632,7 +2628,7 @@
res_cols = batch2_sizes[2]
output_size = (bs, res_rows, res_cols)
- check(
+ torch._check(
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
@@ -2643,8 +2639,8 @@
output = batch2.new_empty(output_size)
if not is_bmm and self_baddbmm is not None:
- check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
- check(
+ torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
+ torch._check(
self_baddbmm.size() == output_size,
lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
)
@@ -2689,9 +2685,9 @@
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
- check(stride != 0, lambda: "stride should not be zero")
- check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
- check(
+ torch._check(stride != 0, lambda: "stride should not be zero")
+ torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
+ torch._check(
pad <= kernelSize // 2,
lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
)
@@ -2720,15 +2716,15 @@
ndim = input.dim()
nOutputPlane = nInputPlane
- check(
+ torch._check(
kW > 0 and kH > 0,
lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
)
- check(
+ torch._check(
dW > 0 and dH > 0,
lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
)
- check(
+ torch._check(
dilationH > 0 and dilationW > 0,
lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
)
@@ -2736,25 +2732,25 @@
valid_dims = input.size(1) != 0 and input.size(2) != 0
if memory_format == torch.channels_last:
- check(
+ torch._check(
ndim == 4 and valid_dims and input.size(3) != 0,
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
" with optional 0 dim batch size for input, but got: {input.size()}",
)
else:
- check(
+ torch._check(
(ndim == 3 and input.size(0) != 0 and valid_dims)
or (ndim == 4 and valid_dims and input.size(3) != 0),
lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
)
- check(
+ torch._check(
kW // 2 >= padW and kH // 2 >= padH,
lambda: "pad should be smaller than or equal to half of kernel size, but got "
f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
)
- check(
+ torch._check(
outputWidth >= 1 and outputHeight >= 1,
lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
@@ -2788,21 +2784,21 @@
):
ndim = input.ndim
- check(
+ torch._check(
kT > 0 and kW > 0 and kH > 0,
lambda: (
f"kernel size should be greater than zero, but got "
f"kT: {kT}, kH: {kH}, kW: {kW}"
),
)
- check(
+ torch._check(
dT > 0 and dW > 0 and dH > 0,
lambda: (
f"stride should be greater than zero, but got "
f"dT: {dT}, dH: {dH}, dW: {dW}"
),
)
- check(
+ torch._check(
dilationT > 0 and dilationW > 0 and dilationH > 0,
lambda: (
f"dilation should be greater than zero, but got "
@@ -2810,7 +2806,7 @@
),
)
- check(
+ torch._check(
ndim in (4, 5),
lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
)
@@ -2819,7 +2815,7 @@
if ndim == 5 and i == 0:
# size of batch-dim can be 0.
continue
- check(
+ torch._check(
input.size(i) > 0,
lambda: (
f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
@@ -2829,7 +2825,7 @@
)
if check_input_size: # AveragePool3d
- check(
+ torch._check(
itime >= kT and iheight >= kH and iwidth >= kW,
lambda: (
f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
@@ -2837,7 +2833,7 @@
),
)
- check(
+ torch._check(
kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
lambda: (
f"pad should be smaller than or equal to half of kernel size, but got "
@@ -2845,7 +2841,7 @@
),
)
- check(
+ torch._check(
otime >= 1 and owidth >= 1 and oheight >= 1,
lambda: (
f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
@@ -2914,7 +2910,7 @@
):
# Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
def unpack(name, val):
- check(
+ torch._check(
len(val) in [1, 2],
lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
)
@@ -2924,7 +2920,7 @@
kH, kW = unpack("kernel_size", kernel_size)
- check(
+ torch._check(
len(stride) in [0, 1, 2],
lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
)
@@ -2941,17 +2937,17 @@
memory_format = utils.suggest_memory_format(input)
if memory_format == torch.channels_last:
- check(
+ torch._check(
input.dim() == 4,
lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
)
elif memory_format == torch.contiguous_format:
- check(
+ torch._check(
input.dim() in [3, 4],
lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
)
else:
- check(
+ torch._check(
False,
lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
)
@@ -2999,7 +2995,7 @@
self, kernel_size, stride, padding, dilation, ceil_mode
)
- check(
+ torch._check(
self.dtype == grad_output.dtype,
lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
)
@@ -3093,7 +3089,7 @@
memory_format=None,
):
if layout == torch.sparse_coo:
- check(
+ torch._check(
memory_format is None,
lambda: "memory format option is only supported by strided tensors",
)
@@ -3131,20 +3127,18 @@
@register_meta(aten.select.int)
def meta_select(self, dim, index):
ndim = self.dim()
- check(
+ torch._check_index(
ndim != 0,
lambda: "select() cannot be applied to a 0-dim tensor.",
- IndexError,
)
dim = dim if dim >= 0 else dim + ndim
size = self.size(dim)
- check(
+ torch._check_index(
not (-index > size or index >= size),
lambda: f"select(): index {index} out of range for tensor of size "
f"{self.size()} at dimension {dim}",
- IndexError,
)
index = index if index >= 0 else index + size
@@ -3190,13 +3184,13 @@
def gather_shape_check(self, dim, index):
self_dims = max(self.dim(), 1)
index_dims = max(index.dim(), 1)
- check(
+ torch._check(
self_dims == index_dims,
lambda: "Index tensor must have the same number of dimensions as input tensor",
)
for i in range(self_dims):
if i != dim:
- check(
+ torch._check(
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
lambda: f"Size does not match at dimension {i} expected index {index.shape}"
+ f" to be smaller than self {self.shape} apart from dimension {dim}",
@@ -3208,7 +3202,7 @@
wrapped_dim = maybe_wrap_dim(dim, self.dim())
is_index_empty = index.numel() == 0
if not is_index_empty:
- check(
+ torch._check(
index.dtype == torch.long,
lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
)
@@ -3229,7 +3223,7 @@
return "REDUCE_MAXIMUM"
elif reduce_ == "amin":
return "REDUCE_MINIMUM"
- check(
+ torch._check(
False,
lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
)
@@ -3239,20 +3233,20 @@
return "REDUCE_ADD"
elif reduce_ == "multiply":
return "REDUCE_MULTIPLY"
- check(False, lambda: "reduce argument must be either add or multiply.")
+ torch._check(False, lambda: "reduce argument must be either add or multiply.")
return
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
if index.numel() != 0:
- check(
+ torch._check(
index.dtype == torch.long,
lambda: f"{method_name}(): Expected dtype int64 for index",
)
if src_opt is not None:
- check(
+ torch._check(
self.dtype == src_opt.dtype,
lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
)
@@ -3266,7 +3260,7 @@
def scatter_shape_check(self, dim, index, src_opt=None):
if index.numel() == 0:
return
- check(
+ torch._check(
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
lambda: "Index tensor must have the same number of dimensions as self tensor",
)
@@ -3292,17 +3286,17 @@
break
if src_opt is not None:
- check(
+ torch._check(
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
lambda: "Index tensor must have the same number of dimensions as self tensor",
)
- check(
+ torch._check(
not is_wrong_shape,
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
+ f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
)
else:
- check(
+ torch._check(
not is_wrong_shape,
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
+ f" apart from dimension {dim}",
@@ -3588,7 +3582,7 @@
@register_meta([aten.multinomial.default, aten.multinomial.out])
@out_wrapper()
def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
- check(
+ torch._check(
0 < input.dim() <= 2,
lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
)
@@ -3607,17 +3601,17 @@
def upsample_common_check(input_size, output_size, num_spatial_dims):
- check(
+ torch._check(
len(output_size) == num_spatial_dims,
lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
)
expected_input_dims = num_spatial_dims + 2 # N, C, ...
- check(
+ torch._check(
len(input_size) == expected_input_dims,
lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
)
- check(
+ torch._check(
all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
lambda: f"Input and output sizes should be greater than 0, but got "
f"input size {input_size} and output size {output_size}",
@@ -3629,7 +3623,7 @@
@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
- check(
+ torch._check(
input.numel() != 0 or multiply_integers(input.size()[1:]),
lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
)
@@ -3643,7 +3637,7 @@
@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
- check(
+ torch._check(
input.numel() != 0 or multiply_integers(input.size()[1:]),
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
)
@@ -3676,12 +3670,12 @@
full_output_size = upsample_common_check(
input_size, output_size, num_spatial_dims=2
)
- check(
+ torch._check(
grad_output.ndim == 4,
lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
)
for i in range(4):
- check(
+ torch._check(
grad_output.size(i) == full_output_size[i],
lambda: (
f"Expected grad_output to have the same shape as output;"
@@ -3697,7 +3691,7 @@
@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
- check(
+ torch._check(
input.numel() != 0 or multiply_integers(input.size()[1:]),
lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
)
@@ -3739,29 +3733,29 @@
def rnn_cell_checkSizes(
input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
- check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
- check(
+ torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
+ torch._check(
input_gates.shape == hidden_gates.shape,
lambda: f"{input_gates.shape} != {hidden_gates.shape}",
)
gates_size = input_gates.size(1)
if input_bias is not None:
- check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
- check(
+ torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
+ torch._check(
input_bias.numel() == gates_size,
lambda: f"{input_bias.numel()} != {gates_size}",
)
- check(
+ torch._check(
input_bias.shape == hidden_bias.shape,
lambda: f"{input_bias.shape} != {hidden_bias.shape}",
)
- check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
+ torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
- check(
+ torch._check(
prev_hidden.numel() == expected_prev_hidden_numel,
lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
)
- check(
+ torch._check(
all(
x.device == input_gates.device
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
@@ -3879,16 +3873,14 @@
def zero_numel_check_dims(self, dim, fn_name):
if self.ndim == 0:
- check(
+ torch._check_index(
dim == 0 or dim == -1,
lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
- IndexError,
)
else:
- check(
+ torch._check_index(
self.size(dim) != 0,
lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
- IndexError,
)
@@ -3898,7 +3890,7 @@
dim = maybe_wrap_dim(dim, self.dim())
zero_numel_check_dims(self, dim, name)
else:
- check(
+ torch._check(
self.numel() != 0,
lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
)
@@ -3923,12 +3915,12 @@
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
# From aten/src/ATen/native/Sorting.cpp
dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
- check(
+ torch._check(
k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
lambda: "selected index k out of range",
)
sliceSize = 1 if self.dim() == 0 else self.size(dim)
- check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
+ torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
topKSize = list(self.shape)
if len(topKSize) > 0:
@@ -3942,16 +3934,16 @@
# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
defined_grad = grad_hy if grad_hy is not None else grad_cy
- check(defined_grad.dim() == 2, lambda: "")
+ torch._check(defined_grad.dim() == 2, lambda: "")
exp_size = defined_grad.size()
if grad_hy is not None:
- check(grad_hy.size() == exp_size, lambda: "")
+ torch._check(grad_hy.size() == exp_size, lambda: "")
if grad_cy is not None:
- check(grad_cy.size() == exp_size, lambda: "")
- check(cx.size() == exp_size, lambda: "")
- check(cy.size() == exp_size, lambda: "")
- check(workspace.dim() == 2, lambda: "")
- check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
+ torch._check(grad_cy.size() == exp_size, lambda: "")
+ torch._check(cx.size() == exp_size, lambda: "")
+ torch._check(cy.size() == exp_size, lambda: "")
+ torch._check(workspace.dim() == 2, lambda: "")
+ torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
# From aten/src/ATen/native/cuda/RNN.cu
@@ -4048,7 +4040,7 @@
full_output_size = upsample_common_check(
input.size(), output_size, num_spatial_dims=2
)
- check(
+ torch._check(
input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
)
@@ -4060,13 +4052,17 @@
# From aten/src/ATen/native/cuda/AmpKernels.cu
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
- check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
- check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
- check(
+ torch._check(
+ found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
+ )
+ torch._check(
+ inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
+ )
+ torch._check(
found_inf.dtype.is_floating_point,
lambda: "found_inf must be a float tensor.",
)
- check(
+ torch._check(
inv_scale.dtype.is_floating_point,
lambda: "inv_scale must be a float tensor.",
)
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 29cb75c..65a4080 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -16,7 +16,6 @@
from torch._prims.nvfuser_prims import register_nvprims
from torch._prims.rng_prims import register_rng_prims
from torch._prims_common import (
- check,
Dim,
DimsSequenceType,
DimsType,
@@ -422,7 +421,7 @@
def _complex_only_elementwise_meta(*args, **kwargs):
- utils.check(
+ torch._check(
utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
)
return _elementwise_meta(*args, **kwargs)
@@ -581,7 +580,7 @@
def _cbrt_aten(a: torch.Tensor) -> Tensor:
- utils.check(
+ torch._check(
not a.is_complex(),
lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
)
@@ -1293,10 +1292,9 @@
# Verifies end is strictly greater than start
# (Collapse requires a non-empty interval)
- utils.check(
+ torch._check_value(
end >= start,
lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
- ValueError,
)
@@ -1823,7 +1821,7 @@
utils.validate_strides(stride)
required_size = utils.compute_required_storage_length(size, stride, storage_offset)
- utils.check(
+ torch._check(
input.numel() >= required_size,
lambda: (
f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
@@ -1832,7 +1830,7 @@
f"for storage of size {input.numel() * input.element_size()}"
),
)
- utils.check(
+ torch._check(
utils.is_same_shape(src.shape, size),
lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
)
@@ -2432,11 +2430,11 @@
device: torch.device,
requires_grad: bool,
) -> TensorLikeType:
- utils.check(
+ torch._check(
utils.is_integer_dtype(dtype),
lambda: "prims.iota only supports integer dtypes",
)
- utils.check(step != 0, lambda: "step must be nonzero")
+ torch._check(step != 0, lambda: "step must be nonzero")
return torch.empty(
length,
dtype=dtype,
@@ -2532,7 +2530,7 @@
) -> TensorLikeType:
p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
dim = len(shape)
- utils.check(
+ torch._check(
len(physical_layout) == dim,
lambda: (
"Number of dimensions in the tensor input does not match the "
@@ -2543,7 +2541,7 @@
strides = [0] * len(shape)
seen_dims = set()
for p, l in enumerate(physical_layout):
- utils.check(
+ torch._check(
0 <= l < dim,
lambda: (
f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
@@ -2551,7 +2549,7 @@
"not currently supported; file an issue if you want it."
),
)
- utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
+ torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
strides[l] = p_strides[p]
seen_dims.add(l)
return TensorMeta(
@@ -2779,12 +2777,12 @@
device: torch.device,
requires_grad: bool,
) -> TensorLikeType:
- utils.check(
+ torch._check(
std >= 0.0,
lambda: f"expected non-negative standard deviation, but got std={std}",
)
- utils.check(
+ torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
)
diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py
index b0c8664..0807197 100644
--- a/torch/_prims_common/__init__.py
+++ b/torch/_prims_common/__init__.py
@@ -7,6 +7,7 @@
import operator
import sympy
import weakref
+import warnings
import torch
from torch import sym_float, sym_int, sym_max
@@ -268,7 +269,7 @@
def validate_memory_format(memory_format: torch.memory_format):
- check(
+ torch._check(
memory_format in _memory_formats,
lambda: f"Received unknown memory format {memory_format}!",
)
@@ -286,7 +287,7 @@
if memory_format == torch.channels_last_3d:
return is_channels_last_contiguous_3d(a)
- check(
+ torch._check(
False,
lambda: f"is_contiguous received unsupported memory format {memory_format}",
)
@@ -795,13 +796,13 @@
newsize = 1
for i, d in enumerate(shape):
if d == -1:
- check(dim is None, lambda: "only one dimension can be inferred")
+ torch._check(dim is None, lambda: "only one dimension can be inferred")
dim = i
elif d >= 0:
newsize *= d
else:
- check(False, lambda: f"invalid shape dimension {d}")
- check(
+ torch._check(False, lambda: f"invalid shape dimension {d}")
+ torch._check(
numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
)
@@ -809,7 +810,7 @@
# Convert to list to produce a compatible error message with core
# PyTorch, which prints sequences in square brackets.
shape = list(shape)
- check(
+ torch._check(
newsize != 0,
lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
f"unspecified dimension size -1 can be any value and is ambiguous"),
@@ -954,18 +955,18 @@
Checks whether the input is floating point or complex.
If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
"""
- check(
+ torch._check(
is_float_dtype(dtype) or is_complex_dtype(dtype),
lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
)
- check(
+ torch._check(
allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
)
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
- check(
+ torch._check(
len(A.shape) >= 2,
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
)
@@ -1060,11 +1061,11 @@
def check_pin_memory(pin_memory: bool):
- check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError)
+ torch._check_not_implemented(not pin_memory, lambda: "PrimTorch does not support pinned memory")
def check_layout(layout: torch.layout):
- check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError)
+ torch._check_not_implemented(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}")
# TODO: maybe unify with can_cast_to?
@@ -1485,7 +1486,7 @@
def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
- check(
+ torch._check(
len(shape) == 3,
lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
)
@@ -1503,7 +1504,7 @@
def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
# TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
- check(
+ torch._check(
len(shape) == 4,
lambda: "Only tensors of rank 4 can use the channels_last memory format",
)
@@ -1520,7 +1521,7 @@
def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
- check(
+ torch._check(
len(shape) == 5,
lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
)
@@ -1654,6 +1655,9 @@
raise ValueError(msg)
+# NOTE: This function should ideally be removed, but some Meta internal models
+# packaged with `torch.package` are using it, so it will have to be removed
+# at some point in the future when those models no longer use this function.
def check(
b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
@@ -1662,9 +1666,14 @@
Error message is a callable producing a string (to avoid wasting time
string formatting in non-error case, and also to make it easier for torchdynamo
to trace.)
+
+ .. note:: This function is planned for removal in the future. Please use
+ `torch._check*` functions instead.
"""
- if not b:
- raise exc_type(s())
+ warnings.warn(DeprecationWarning((
+ "'torch._prims_common.check' will be removed in the future. Please use "
+ "'torch._check*' functions instead")))
+ torch._check_with(exc_type, b, s)
# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index 5f10a2a..041fb76 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -176,13 +176,13 @@
# Checks safe cast
if exact_dtype:
- utils.check(
+ torch._check(
copy_from.dtype == copy_to.dtype,
lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
f"but got {copy_to.dtype} instead",
)
else:
- utils.check(
+ torch._check(
utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
"but this can't be cast because it is not safe!",
@@ -255,10 +255,9 @@
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
assert isinstance(out, Tuple) # type: ignore[arg-type]
- utils.check(
+ torch._check_type(
len(out) == len(result),
lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
- TypeError,
)
for r, o in zip(result, out):
# These two operations are done in-place
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index f1778eb..a1e66cb 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -16,7 +16,6 @@
import torch._prims_common as utils
from torch import sym_float, sym_int
from torch._prims_common import (
- check,
DeviceLikeType,
Dim,
DimsSequenceType,
@@ -626,7 +625,7 @@
# imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType:
assert isinstance(a, TensorLike)
- utils.check(
+ torch._check(
utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
)
return prims.imag(a)
@@ -654,7 +653,7 @@
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType:
- utils.check(
+ torch._check(
not utils.is_complex_dtype(a.dtype),
lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
)
@@ -665,7 +664,7 @@
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType:
- utils.check(
+ torch._check(
not utils.is_complex_dtype(a.dtype),
lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
)
@@ -788,7 +787,7 @@
def _neg_meta(a: TensorLikeType):
- check(
+ torch._check(
a.dtype is not torch.bool,
lambda: (
"Negation, the `-` operator, on a bool tensor is not supported. "
@@ -935,23 +934,20 @@
a: Union[Tensor, NumberType],
b: Union[Tensor, NumberType],
) -> Tensor:
- check(
+ torch._check_value(
supports_lhs_python_scalar or not isinstance(a, Number),
lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
"operation that does not accept lhs scalars!",
- ValueError,
)
- check(
+ torch._check_value(
supports_rhs_python_scalar or not isinstance(b, Number),
lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
"operation that does not accept rhs scalars!",
- ValueError,
)
- check(
+ torch._check_value(
supports_two_python_scalars
or not (isinstance(a, Number) and isinstance(b, Number)),
lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
- ValueError,
)
a, b = _maybe_broadcast(a, b)
return prim(a, b)
@@ -1230,7 +1226,7 @@
elif utils.is_integer_dtype(dtype):
return _floor_divide_integer(a, b)
else:
- check(False, lambda: f"{dtype} not supported for floor_divide")
+ torch._check(False, lambda: f"{dtype} not supported for floor_divide")
def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
@@ -1374,20 +1370,19 @@
rtol: float,
atol: float,
) -> None:
- check(
+ torch._check_value(
a.dtype == b.dtype,
lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
name, a.dtype, b.dtype
),
- ValueError,
)
- check(
+ torch._check(
rtol >= 0,
lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
name, rtol
),
)
- check(
+ torch._check(
atol >= 0,
lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
name, atol
@@ -1678,7 +1673,7 @@
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
- utils.check(
+ torch._check(
isinstance(a, TensorLike) or isinstance(b, TensorLike),
lambda: 'Expected either argument a or b to be a Tensor"',
)
@@ -1736,12 +1731,11 @@
if value is not None:
dtype = self.dtype # no scalars allowed, see add
python_type = utils.dtype_to_type(dtype)
- check(
+ torch._check_value(
utils.is_weakly_lesser_type(type(value), python_type),
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
type(value), python_type
),
- exc_type=ValueError,
)
return self + value * tensor1 / tensor2
@@ -1766,12 +1760,11 @@
if value is not None:
dtype = self.dtype # no scalars allowed, see add
python_type = utils.dtype_to_type(dtype)
- check(
+ torch._check_value(
utils.is_weakly_lesser_type(type(value), python_type),
lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
type(value), python_type
),
- exc_type=ValueError,
)
return self + value * tensor1 * tensor2
@@ -1851,7 +1844,7 @@
raise NotImplementedError
utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
- check(
+ torch._check(
pred.dtype is torch.bool,
lambda: f"expected predicate to be bool, got {pred.dtype}",
)
@@ -2229,7 +2222,7 @@
*shape,
) -> Tensor:
shape = utils.extract_shape_from_varargs(shape, validate=False)
- utils.check(
+ torch._check(
utils.is_expandable_to(shape, a.shape),
lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
)
@@ -2402,7 +2395,7 @@
if dtype is None:
dtype = a.dtype
# can't use out wrapper because of this argument
- check(
+ torch._check(
out is None or out.dtype == dtype,
lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
)
@@ -2415,7 +2408,7 @@
out=None,
output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
)
- check(
+ torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: (
f"mean(): could not infer output dtype. "
@@ -2491,22 +2484,22 @@
beta: NumberType = 1,
alpha: NumberType = 1,
) -> TensorLikeType:
- check(
+ torch._check(
vec1.ndim == 1,
lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
)
- check(
+ torch._check(
vec2.ndim == 1,
lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
)
self = self.expand(vec1.shape[0], vec2.shape[0])
if utils.is_boolean_dtype(self.dtype):
# Integers are accepted for booleans
- check(
+ torch._check(
is_weakly_lesser_type(type(beta), int),
lambda: f"expected bool/int beta but got {type(beta)}",
)
- check(
+ torch._check(
is_weakly_lesser_type(type(alpha), int),
lambda: f"expected bool/int alpha but got {type(beta)}",
)
@@ -2518,11 +2511,11 @@
torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
)
else:
- check(
+ torch._check(
is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
)
- check(
+ torch._check(
is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
)
@@ -2712,7 +2705,7 @@
def constant_pad_nd(
input: TensorLikeType, pad: List[int], value: NumberType = 0
) -> TensorLikeType:
- check(
+ torch._check(
len(pad) % 2 == 0,
lambda: f"Length of pad must be even but instead it equals {len(pad)}",
)
@@ -2723,7 +2716,7 @@
l_pad = len(pad) // 2
l_diff = l_inp - l_pad
- check(
+ torch._check(
l_inp >= l_pad,
lambda: "Length of pad should be no more than twice the number of "
f"dimensions of the input. Pad length is {len(pad)} while the input has "
@@ -2748,7 +2741,7 @@
for i in range(l_pad):
pad_idx = len(pad) - ((i + 1) * 2)
new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
- check(
+ torch._check(
new_dim > 0,
lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
@@ -2787,7 +2780,7 @@
def contiguous(
a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor:
- check(
+ torch._check(
memory_format != torch.preserve_format,
lambda: "preserve memory format is unsupported by the contiguous operator",
)
@@ -2800,7 +2793,7 @@
@out_wrapper()
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
+ torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
aligned_tensors = atleast_3d(*tensors)
return cat(aligned_tensors, 2)
@@ -2813,7 +2806,7 @@
if len(shape) == 1 and isinstance(shape[0], Sequence):
shape = tuple(shape[0])
- check(
+ torch._check(
len(shape) >= len(a.shape),
lambda: "expand: the requested shape has too few dimensions!",
)
@@ -2823,7 +2816,7 @@
for idx, x in enumerate(a.shape):
offset_idx = idx + offset
requested_length = shape[offset_idx]
- check(
+ torch._check(
requested_length == x or x == 1 or requested_length == -1,
lambda: f"expand: attempting to expand a dimension of length {x}!",
)
@@ -2917,13 +2910,13 @@
# Supports Tensor overload that was added for XLA:
# https://github.com/pytorch/pytorch/issues/31558
if isinstance(start, TensorLike):
- check(
+ torch._check(
start.dim() == 0 and utils.is_integer_dtype(start.dtype),
lambda: "start must be an 0-dim integral Tensor.",
)
start = start.item() # type: ignore[assignment]
- check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
- check(length >= 0, lambda: "narrow(): length must be non-negative.")
+ torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
+ torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
dim = utils.canonicalize_dim(a.ndim, dim)
dim_length = a.size(dim)
# Start being the end is usually invalid since it's out of bounds. So it's
@@ -2934,7 +2927,7 @@
# Note: a dimension isn't being canonicalized here, this reuses
# canonicalize_dim because the semantics are similar.
start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
- check(
+ torch._check(
start <= dim_length - length, # type: ignore[arg-type]
lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
)
@@ -2993,11 +2986,11 @@
num_groups: int,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
- utils.check(
+ torch._check(
input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
)
- utils.check(
+ torch._check(
num_channels % num_groups == 0,
lambda: "Expected number of channels in input to be divisible by num_groups, "
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
@@ -3044,7 +3037,7 @@
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape)
- utils.check(
+ torch._check(
normalized_ndim >= 1,
lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
+ "containing at least one element, but got normalized_shape = "
@@ -3053,7 +3046,7 @@
# torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape)
- utils.check(
+ torch._check(
weight is None or weight.shape == tuple(normalized_shape),
lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape "
@@ -3061,7 +3054,7 @@
+ " and normalized_shape = "
+ str(normalized_shape),
)
- utils.check(
+ torch._check(
bias is None or bias.shape == tuple(normalized_shape),
lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape "
@@ -3069,7 +3062,7 @@
+ " and normalized_shape = "
+ str(normalized_shape),
)
- utils.check(
+ torch._check(
input.ndim >= normalized_ndim
and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
lambda: "Given normalized_shape="
@@ -3123,12 +3116,12 @@
max_size = 1 if a_ndim == 0 else a_shape[dim]
last_stride = 1 if a_ndim == 0 else a_stride[dim]
- utils.check(
+ torch._check(
size <= max_size,
lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
)
- utils.check(
+ torch._check(
step > 0,
lambda: f"Step is {step} but must be > 0",
)
@@ -3146,7 +3139,7 @@
@register_decomposition(aten.repeat)
def repeat(a: Tensor, *repeat_shape) -> Tensor:
repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
- utils.check(
+ torch._check(
len(repeat_shape) >= len(a.shape),
lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
)
@@ -3452,7 +3445,7 @@
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
+ torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
aligned_tensors = atleast_1d(*tensors)
if aligned_tensors[0].ndim == 1:
return cat(aligned_tensors, 0)
@@ -3462,7 +3455,7 @@
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
+ torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
aligned_tensors = atleast_2d(*tensors)
return cat(aligned_tensors, 0)
@@ -3470,17 +3463,16 @@
# CompositeImplicitAutograd - don't register decomp
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
dim = utils.canonicalize_dim(a.ndim, dim)
- utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
+ torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
dim = utils.canonicalize_dim(t.ndim, dim)
- check(
+ torch._check_index(
len(t.shape) > 0,
lambda: "Dimension specified as 0 but tensor has no dimensions",
- IndexError,
)
if t.shape[dim] == 0:
return tuple()
@@ -3499,7 +3491,7 @@
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
+ torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
@@ -3532,12 +3524,12 @@
*,
inplace: bool,
):
- utils.check(
+ torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
if isinstance(value, TensorLike):
- utils.check(
+ torch._check(
value.ndim == 0,
lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
f"Got a tensor with {value.ndim} dimensions.",
@@ -3589,7 +3581,7 @@
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
+ torch._check(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
@@ -3713,7 +3705,7 @@
def hsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
- check(
+ torch._check(
a.ndim >= 1,
lambda: (
"torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
@@ -3724,7 +3716,7 @@
dim = 0 if a.ndim == 1 else 1
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
- check(
+ torch._check(
(split_size != 0 and a.shape[dim] % split_size == 0),
lambda: (
"torch.hsplit attempted to split along dimension "
@@ -3738,14 +3730,13 @@
)
return tensor_split(a, split_size, dim)
- check(
+ torch._check_type(
isinstance(indices_or_sections, (list, tuple)),
lambda: (
"hsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}"
),
- exc_type=TypeError,
)
split_sizes = indices_or_sections
@@ -3756,7 +3747,7 @@
def vsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
- check(
+ torch._check(
a.ndim >= 2,
lambda: (
"torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
@@ -3766,7 +3757,7 @@
)
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
- check(
+ torch._check(
(split_size != 0 and a.shape[0] % split_size == 0),
lambda: (
f"torch.vsplit attempted to split along dimension 0"
@@ -3779,14 +3770,13 @@
)
return tensor_split(a, split_size, 0)
- check(
+ torch._check_type(
isinstance(indices_or_sections, (list, tuple)),
lambda: (
"vsplit(): received an invalid combination of arguments. "
"Expected indices_or_sections to be of type int, list of ints or tuple of ints "
f"but got type {type(indices_or_sections)}"
),
- exc_type=TypeError,
)
split_sizes = indices_or_sections
@@ -3800,7 +3790,7 @@
offset: int = 0,
) -> TensorLikeType:
ndim = self.dim()
- utils.check(
+ torch._check(
ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
)
if ndim == 1:
@@ -3820,7 +3810,7 @@
) -> TensorLikeType:
out = utils.clone_preserve_strides(input)
diag = out.diagonal(offset, dim1, dim2)
- check(
+ torch._check(
diag.shape == src.shape,
lambda: "expected src to have a size equal to the diagonal of the input."
f"Got {src.shape} for a diagonal of shape {diag.shape}",
@@ -3843,7 +3833,7 @@
dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
- check(
+ torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)
@@ -3896,7 +3886,7 @@
dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
- check(
+ torch._check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)
@@ -3967,7 +3957,7 @@
# CompositeImplicitAutograd - don't register decomp
def T(a: TensorLikeType) -> TensorLikeType:
# n != 2 && n != 0 is deprecated in regular PyTorch.
- check(
+ torch._check(
a.ndim in (0, 2),
lambda: (
"The use of `x.T` on tensors of dimension other than 0 or 2 "
@@ -4102,7 +4092,7 @@
pin_memory: bool = False,
memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType:
- check(
+ torch._check(
memory_format != torch.preserve_format,
lambda: "torch.empty: the Preserve memory format is not supported",
)
@@ -4114,7 +4104,7 @@
elif memory_format == torch.channels_last_3d:
strides = utils.make_channels_last_3d_strides_for(shape)
else: # memory_format == torch.channels_last
- check(
+ torch._check(
memory_format == torch.channels_last,
lambda: f"torch.empty: received an unknown memory format {memory_format}!",
)
@@ -4398,8 +4388,8 @@
if end is None:
end = start
start = 0
- utils.check(step != 0, lambda: "step must be nonzero")
- utils.check(
+ torch._check(step != 0, lambda: "step must be nonzero")
+ torch._check(
(step > 0 and end >= start) or (step < 0 and end <= start),
lambda: "upper bound and lower bound inconsistent with step sign",
)
@@ -4407,11 +4397,11 @@
def is_finite(x):
return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
- utils.check(
+ torch._check(
is_finite(start) and is_finite(end),
lambda: f"unsupported range: {start} -> {end}",
)
- utils.check(
+ torch._check(
is_finite(step),
lambda: f"step must be finite but got {step}",
)
@@ -4514,7 +4504,7 @@
if dtype is None:
dtype = default_complex_dtype
else:
- check(
+ torch._check(
utils.is_complex_dtype(dtype),
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
)
@@ -4523,13 +4513,12 @@
assert isinstance(dtype, torch.dtype)
# steps does not participate in the computation of the dtype
- check(
+ torch._check_type(
isinstance(steps, IntLike),
lambda: "steps must be int, not float",
- exc_type=TypeError,
)
assert isinstance(steps, IntLike) # for mypy
- check(steps >= 0, lambda: "number of steps must be non-negative")
+ torch._check(steps >= 0, lambda: "number of steps must be non-negative")
factory_kwargs = {
"layout": layout,
@@ -4631,19 +4620,19 @@
assert len(tensors) == 1
tensors = tuple(tensors[0])
- check(
+ torch._check(
py_all(isinstance(a, TensorLike) for a in tensors),
lambda: "meshgrid expects its inputs to be tensors",
)
- check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
+ torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
for i in range(len(tensors) - 1):
- check(
+ torch._check(
tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same dtype",
)
- check(
+ torch._check(
tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
lambda: "meshgrid expects all tensors to have the same device",
)
@@ -4654,7 +4643,7 @@
if swap_first_and_second_tensors:
tensors = (tensors[1], tensors[0], *tensors[2:])
else:
- check(
+ torch._check(
indexing == "ij",
lambda: (
'torch.meshgrid: indexing must be one of "xy" or "ij", '
@@ -4665,7 +4654,7 @@
result_shape: List[int] = []
for t in tensors:
assert isinstance(t, TensorLike) # mypy
- check(
+ torch._check(
t.ndim == 0 or t.ndim == 1,
lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
)
@@ -4701,7 +4690,7 @@
# Converts to list to produce a compatible error message with core PyTorch,
# which prints sequences in square brackets.
- utils.check(
+ torch._check(
len(source) == len(destination), # type: ignore[arg-type]
lambda: (
"movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
@@ -4718,11 +4707,11 @@
dss = set(ds)
# See above on why this converts to list in error messages.
- utils.check(
+ torch._check(
len(ss) == len(sss),
lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
)
- utils.check(
+ torch._check(
len(ds) == len(dss),
lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
)
@@ -4795,8 +4784,8 @@
if m is None:
m = n
- check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
- check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
+ torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
+ torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
@@ -4994,13 +4983,13 @@
# NOTE: Could not use value = item(value) as it resulted in
# RuntimeError: Cannot cast FakeTensor(cpu) to number
value_ndim = value.ndim
- check(
+ torch._check(
value_ndim == 0,
lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
)
# `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise.
is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu"
- check(
+ torch._check(
is_cpu_scalar or value.device == a.device,
lambda: "Expected `value` to be on same device as `a`",
)
@@ -5011,7 +5000,7 @@
# We allow casting `value` to lower type for other case
# Eg. float -> int.
# Ref: https://github.com/pytorch/pytorch/issues/79195
- check(
+ torch._check(
utils.is_weakly_lesser_type(value_type, python_type),
lambda: f"could not convert to type {python_type} without overflow",
)
@@ -5101,7 +5090,7 @@
@register_decomposition(aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType:
- utils.check(
+ torch._check(
self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}"
)
return torch.sum(torch.diag(self, 0))
@@ -5125,7 +5114,7 @@
@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- utils.check(
+ torch._check(
a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
)
h, w = a.shape[-2:]
@@ -5142,7 +5131,7 @@
@register_decomposition(aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- utils.check(
+ torch._check(
a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
)
h, w = a.shape[-2:]
@@ -5187,9 +5176,9 @@
layout: torch.layout,
pin_memory: bool,
):
- check(row >= 0, lambda: f"row must be non-negative, got {row}")
- check(col >= 0, lambda: f"col must be non-negative, got {col}")
- check(
+ torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
+ torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
+ torch._check(
dtype in (torch.int32, torch.int64),
lambda: f"\"{name}\" not implemented for '{dtype}'",
)
@@ -5306,7 +5295,7 @@
out_int32: bool = False,
right: bool = False,
):
- utils.check(
+ torch._check(
boundaries.dim() == 1,
lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
)
@@ -5364,14 +5353,14 @@
)
def cauchy(self, median=0, sigma=1, generator=None):
assert generator is None
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"Cauchy distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}",
)
- utils.check(
+ torch._check(
sigma > 0.0,
lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
)
@@ -5386,14 +5375,14 @@
)
def exponential(self, rate=1, generator=None):
assert generator is None
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"Exponential distribution is a continuous probability distribution. \
dtype must be a floating point but you specified {self.dtype}",
)
- utils.check(
+ torch._check(
rate > 0.0,
lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
)
@@ -5409,12 +5398,12 @@
def geometric(self, p, generator=None):
assert generator is None
# TODO: fix inductor rand_like for integer, bool dtypes
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"geometric not implemented for {self.dtype}",
)
- utils.check(
+ torch._check(
0 < p and p < 1,
lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
)
@@ -5429,13 +5418,13 @@
)
def log_normal(self, mean=1, std=2, generator=None):
assert generator is None
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype)
and not utils.is_integer_dtype(self.dtype)
and not utils.is_boolean_dtype(self.dtype),
lambda: f"log_normal not implemented for {self.dtype}",
)
- utils.check(
+ torch._check(
0 < std,
lambda: f"log_normal_ expects std > 0.0, but found std={std}",
)
@@ -5451,7 +5440,7 @@
)
def normal(self, mean=0, std=1, generator=None):
assert generator is None
- utils.check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
+ torch._check(std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}")
normal_samples = prims.normal(
self.shape,
mean=0.0,
@@ -5465,7 +5454,7 @@
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype),
lambda: "rad2deg is not supported for complex tensors.",
)
@@ -5475,7 +5464,7 @@
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
- utils.check(
+ torch._check(
not utils.is_complex_dtype(self.dtype),
lambda: "deg2rad is not supported for complex tensors.",
)
diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py
index d0607d8..fa1ca24 100644
--- a/torch/_refs/_conversions.py
+++ b/torch/_refs/_conversions.py
@@ -4,7 +4,7 @@
# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
-from torch._prims_common import check, TensorLikeType
+from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
@@ -79,14 +79,14 @@
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
allowed_dtypes = (torch.float32, torch.float64, torch.float16)
- check(
+ torch._check(
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
lambda: (
f"Expected both inputs to be Half, Float or Double tensors but got "
f"{real.dtype} and {imag.dtype}"
),
)
- check(
+ torch._check(
real.dtype == imag.dtype,
lambda: (
f"Expected object of scalar type {real.dtype} but got "
diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py
index 54a98c2..4b22d86 100644
--- a/torch/_refs/fft.py
+++ b/torch/_refs/fft.py
@@ -6,7 +6,7 @@
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
-from torch._prims_common import check, DimsType, ShapeType, TensorLikeType
+from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
__all__ = [
@@ -43,7 +43,7 @@
x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
"""Apply normalization to the un-normalized FFT result"""
- check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
+ torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
if norm == "ortho":
return x * (1 / math.sqrt(signal_numel))
@@ -116,7 +116,9 @@
input = _maybe_promote_tensor_fft(input, require_complex=True)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
- check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified")
+ torch._check(
+ last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified"
+ )
if n is not None:
input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))
@@ -138,7 +140,7 @@
onesided: bool,
) -> TensorLikeType:
"""Common code for performing any real to complex FFT (rfft or ihfft)"""
- check(
+ torch._check(
not input.dtype.is_complex,
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
)
@@ -162,7 +164,7 @@
forward: bool,
) -> TensorLikeType:
"""Common code for performing any complex to complex FFT (fft or ifft)"""
- check(
+ torch._check(
input.dtype.is_complex,
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
)
@@ -265,20 +267,20 @@
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
# Check dims are unique
- check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
+ torch._check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique")
if shape is not None:
if not isinstance(shape, Sequence):
shape = (shape,)
# Has shape, might have dim
- check(
+ torch._check(
dim is None or len(dim) == len(shape),
lambda: "When given, dim and shape arguments must have the same length",
)
transform_ndim = len(shape)
- check(
+ torch._check(
transform_ndim <= input_dim,
lambda: f"Got shape with {transform_ndim} values but input tensor "
f"only has {input_dim} dimensions.",
@@ -301,7 +303,7 @@
ret_shape = tuple(input_sizes[d] for d in ret_dims)
for n in ret_shape:
- check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
+ torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
@@ -323,7 +325,7 @@
forward: bool,
) -> TensorLikeType:
"""Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
- check(
+ torch._check(
input.dtype.is_complex,
lambda: f"{function_name} expects a complex input tensor, "
f"but got {input.dtype}",
@@ -367,7 +369,7 @@
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
- check(
+ torch._check(
not input.dtype.is_complex,
lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
)
@@ -386,12 +388,12 @@
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
- check(
+ torch._check(
not input.dtype.is_complex,
lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
)
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
- check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
+ torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
input = _maybe_promote_tensor_fft(input, require_complex=False)
input = _resize_fft_input(input, dim, shape)
@@ -421,14 +423,14 @@
"""Canonicalize shape and dim arguments for n-dimensional c2r transforms,
as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
- check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
+ torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
if s is None or s[-1] == -1:
last_dim_size = 2 * (input.shape[dim[-1]] - 1)
else:
last_dim_size = shape[-1]
- check(
+ torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({last_dim_size}) specified",
)
diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py
index f22926c..97fb1d9 100644
--- a/torch/_refs/linalg/__init__.py
+++ b/torch/_refs/linalg/__init__.py
@@ -11,7 +11,6 @@
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
- check,
check_fp_or_complex,
check_is_matrix,
Dim,
@@ -29,11 +28,11 @@
Checks related to the dtype kwarg in `linalg.*norm` functions
"""
if dtype is not None:
- check(
+ torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
)
- check(
+ torch._check(
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
fn_name=fn_name,
@@ -41,7 +40,7 @@
dtype=dtype,
),
)
- check(
+ torch._check(
utils.get_higher_dtype(dtype, x_dtype) == dtype,
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
"without narrowing to the specified dtype ({dtype})",
@@ -79,7 +78,7 @@
dim = [dim] # type: ignore[assignment]
if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
- check(
+ torch._check(
dim is not None and len(dim) != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity",
@@ -87,7 +86,7 @@
shape = x.shape
assert dim is not None # mypy does not seem to be able to see through check?
for d in dim:
- check(
+ torch._check(
shape[d] != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the "
@@ -147,8 +146,10 @@
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
- check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
- check(
+ torch._check(
+ len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
+ )
+ torch._check(
dim[0] != dim[1],
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)
@@ -157,7 +158,7 @@
if isinstance(ord, str):
# ord
- check(
+ torch._check(
ord in ("fro", "nuc"),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)
@@ -180,7 +181,7 @@
else:
# ord
abs_ord = abs(ord)
- check(
+ torch._check(
abs_ord in (2, 1, float("inf")),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)
@@ -224,12 +225,12 @@
if dim is not None:
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
- check(
+ torch._check(
len(dim) in (1, 2),
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
)
elif ord is not None:
- check(
+ torch._check(
A.ndim in (1, 2),
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
)
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index be82d0a..eaa6618 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -8,7 +8,6 @@
import torch._refs as refs
from torch._decomp import register_decomposition
from torch._prims_common import (
- check,
ELEMENTWISE_TYPE_PROMOTION_KIND,
NumberType,
ShapeType,
@@ -98,7 +97,7 @@
if not training:
return self
- utils.check(
+ torch._check(
p <= 1 and p >= 0,
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
)
@@ -134,7 +133,7 @@
@wraps(fn)
def _fn(a, *args, inplace=False, **kwargs):
if inplace:
- check(
+ torch._check(
"out" not in kwargs,
lambda: "Cannot set inplace=True and pass out= at the same time",
)
@@ -193,7 +192,7 @@
if not training:
return a
- utils.check(
+ torch._check(
p <= 1 and p >= 0,
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
)
@@ -232,15 +231,15 @@
# nb. This should be factored out into a can_cast aux function
python_type = utils.dtype_to_type(a.dtype)
- check(
+ torch._check(
utils.is_weakly_lesser_type(type(input_scale), python_type),
lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
)
- check(
+ torch._check(
utils.is_weakly_lesser_type(type(scale), python_type),
lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
)
- check(
+ torch._check(
utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
)
@@ -276,14 +275,14 @@
"""
Reference implementation of :func:`torch.nn.functional.group_norm`.
"""
- utils.check(
+ torch._check(
input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
)
batch_size = input.shape[0]
num_channels = input.shape[1]
- utils.check(
+ torch._check(
num_channels % num_groups == 0,
lambda: "Expected number of channels in input to be divisible by num_groups, "
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
@@ -394,7 +393,7 @@
# deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform
# users how to update their calls.
- check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@@ -409,7 +408,7 @@
# deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform
# users how to update their calls.
- check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@@ -469,7 +468,7 @@
# softshrink(x) = x - lambd if x > lambd
# = x + lambd if x < -lambd
# = 0 otherwise
- check(
+ torch._check(
lambd >= 0,
lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
)
@@ -596,7 +595,7 @@
# deprecated. For PrimTorch, it's fine to drop support for deprecated
# behavior because it requires explicit opt in. This error is to inform
# users how to update their calls.
- check(dim is not None, lambda: "implicit dim not supported, use dim=X")
+ torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@@ -668,12 +667,12 @@
reduction: str,
ignore_index: int,
) -> TensorLikeType:
- utils.check(
+ torch._check(
input.ndim > 0 and input.ndim <= 3,
lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
)
- utils.check(
+ torch._check(
(input.ndim == 1) or (input.shape[0] == target.shape[0]),
lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
)
@@ -693,7 +692,7 @@
(flat_target >= 0), (flat_target < num_classes)
)
class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
- utils.check(
+ torch._check(
isinstance(target, FakeTensor) or bool(class_check.item()),
lambda: "A target class is out-of-bounds and not the ignore index.",
)
@@ -758,7 +757,7 @@
"""
Reference implementation of torch.nn.functional.nll_loss
"""
- utils.check(
+ torch._check(
input.ndim > 0,
lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
)
@@ -796,9 +795,13 @@
# For ndim > 3, we reshape the input and target to 3-D case.
# Input (N batch-size, C classes, k-dimensions)
# Target (N batch-size, k-dimensions)
- utils.check(
+ torch._check(
input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
- lambda: f"Expected target shape {out_size} but got {target.shape}",
+ lambda: (
+ "Expected input and target to both have ndim > 0 and "
+ "target.shape[1:] == input.shape[2:], but got "
+ f"target.shape {target.shape} and input.shape {input.shape}"
+ ),
)
batch_size = input.shape[0]
@@ -837,7 +840,7 @@
if type(reduction) is int:
reduction = _reduction_int_to_str(reduction)
_check_reduction_value(reduction) # type: ignore[arg-type]
- check(
+ torch._check(
delta > 0,
lambda: "huber_loss does not support non-positive values for delta.",
)
@@ -938,7 +941,7 @@
a_dim = anchor.ndim
p_dim = positive.ndim
n_dim = negative.ndim
- check(
+ torch._check(
a_dim == p_dim and p_dim == n_dim,
lambda: (
f"The anchor, positive, and negative tensors are expected to have "
@@ -1075,25 +1078,25 @@
"""
Reference implementation of torch.nn.functional.prelu
"""
- check(
+ torch._check(
isinstance(a, TensorLike),
lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
)
- check(
+ torch._check(
isinstance(weight, TensorLike),
lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
)
if weight.numel() != 1:
- check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
+ torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
channel_size = a.shape[1] if a.ndim >= 2 else 1
- check(
+ torch._check(
weight.numel() == channel_size,
lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
f" {weight.numel()} and channel size = {channel_size}.",
)
- check(
+ torch._check(
weight.ndim == 0 or weight.ndim == 1,
lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
f"ndim = {weight.ndim}",
@@ -1132,7 +1135,7 @@
)
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
dim = utils.canonicalize_dims(a.ndim, dim)
- check(
+ torch._check(
a.shape[dim] % 2 == 0,
lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
)
@@ -1160,8 +1163,8 @@
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
- check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
- check(p >= 0, lambda: "pdist only supports non-negative p values")
+ torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
+ torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
# For p == 2 we can use an efficient implementation, but other values of p
# require creating a much bigger tensor for an intermediate step
if p == 2:
diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py
index 4369265..048de83 100644
--- a/torch/_refs/special/__init__.py
+++ b/torch/_refs/special/__init__.py
@@ -148,7 +148,7 @@
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
- utils.check(
+ torch._check(
isinstance(a, TensorLike) or isinstance(b, TensorLike),
lambda: 'Expected either argument a or b to be a Tensor"',
)
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 3ba091a..fe1dd93 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -15,7 +15,6 @@
from torch._guards import Source
from torch._ops import OpOverload
from torch._prims_common import (
- check,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
is_boolean_dtype,
@@ -1495,7 +1494,7 @@
) = FakeTensor._find_common_device(func, args, kwargs)
if isinstance(e, FakeTensor):
- check(
+ torch._check(
e.device == common_device,
lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}",
)