Strengthen preconditions of linalg.cross (#83798)
This makes `linalg.cross` array API compliant (https://github.com/data-apis/array-api/issues/415) and fixes a few bugs.
Fixes https://github.com/pytorch/pytorch/issues/77629
Fixes https://github.com/pytorch/pytorch/issues/83756
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83798
Approved by: https://github.com/mruberry
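
For reference, a minimal sketch (not part of the diff) of the stricter behavior the new `TORCH_CHECK`s enforce; the error messages below are copied from the checks added in `Cross.cpp`:

```python
import torch

# Mixed ranks used to broadcast silently and pick `dim` ambiguously;
# now both inputs must have the same number of dimensions.
try:
    torch.linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
except RuntimeError as e:
    print(e)  # linalg.cross: inputs must have the same number of dimensions.

# The dimension being operated over must have length 3 in *both* inputs
# (previously only the broadcasted size was checked).
try:
    torch.linalg.cross(torch.randn(4, 2), torch.randn(4, 2)))
except RuntimeError as e:
    print(e)  # linalg.cross: inputs dimension -1 must have length 3. Got 2 and 2

# Batch dimensions still broadcast as before.
out = torch.linalg.cross(torch.randn(1, 3), torch.randn(5, 3))
print(out.shape)  # torch.Size([5, 3])
```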
diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h
index 9712461..07631c3 100644
--- a/aten/src/ATen/TensorMeta.h
+++ b/aten/src/ATen/TensorMeta.h
@@ -71,6 +71,7 @@
struct TORCH_API MetaBase {
virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
+ // Note: [set_output_*]
// See: https://github.com/pytorch/pytorch/issues/69813
// Whenever defining the output properties in the META function of a
// structured kernel (what was usually done with `set_output`), use one of
diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp
index 4b3e43d..9b268c6 100644
--- a/aten/src/ATen/native/Cross.cpp
+++ b/aten/src/ATen/native/Cross.cpp
@@ -9,17 +9,20 @@
namespace at {
namespace meta {
-TORCH_PRECOMPUTE_META_FUNC(linalg_cross)
-(const Tensor & input, const Tensor & other, const int64_t dimension) {
- auto out_size = infer_size(input.sizes(), other.sizes());
- Tensor input_broadcasted = input.expand(out_size);
- Tensor other_broadcasted = other.expand(out_size);
+TORCH_META_FUNC(linalg_cross)
+(const Tensor & input, const Tensor & other, int64_t dim) {
+ auto x_d = input.dim();
+ auto y_d = other.dim();
+ // This is to avoid things like
+ // linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
+ TORCH_CHECK(x_d == y_d, "linalg.cross: inputs must have the same number of dimensions.");
+ TORCH_CHECK(input.size(dim) == 3 && other.size(dim) == 3, "linalg.cross: inputs dimension ", dim, " must have length 3. Got ", input.size(dim), " and ", other.size(dim));
- int64_t dim = maybe_wrap_dim(dimension, input.dim()); // default dim = -1
- TORCH_CHECK(input_broadcasted.size(dim) == 3, "dimension ", dimension, " does not have size 3");
+ // Broadcast the batch dimensions of input and other.
+ // Since the non-batch dimensions agree, this is the same as broadcasting all the inputs.
+ auto out_size = infer_size(input.sizes(), other.sizes());
set_output_raw_strided(0, out_size, {}, input.options());
- return TORCH_PRECOMPUTE_STRUCT(linalg_cross)().set_dim(dim);
}
}
@@ -56,8 +59,9 @@
TORCH_IMPL_FUNC(linalg_cross_out)
-(const Tensor & input, const Tensor & other, const int64_t dim, const Tensor & out) {
- auto out_size = infer_size(input.sizes(), other.sizes());
+(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
+ dim = maybe_wrap_dim(dim, input.dim());
+ auto out_size = out.sizes();
Tensor input_broadcasted = input.expand(out_size);
Tensor other_broadcasted = other.expand(out_size);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 78b7071..daab8f0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -12155,8 +12155,6 @@
- func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
structured: True
- precomputed:
- - dim -> int dim
dispatch:
CPU, CUDA: linalg_cross_out
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index da38270..2bb19ba 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -958,6 +958,7 @@
xfail('svd_lowrank', ''),
xfail('pca_lowrank', ''),
xfail('clamp'),
+ xfail('cross'), # The defaults of this op are *very* weird. No wonder it doesn't work
# something weird happening with channels_last
xfail('bfloat16'),
xfail('double'),
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 7f00909..05dab7b 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -3222,6 +3222,7 @@
xfail('eye', ''), # non-tensor input
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
xfail('sparse.sampled_addmm'), # sparse
+ xfail('cross'), # The default value of dim in this op is *very* weird. No wonder it doesn't work
xfail('svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
xfail('linalg.svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 8df810d..655150c 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4504,48 +4504,6 @@
torch.linalg.cross(x, y, dim=1, out=res2)
self.assertEqual(res1, res2)
- # non contiguous case 1
- x = torch.rand((4, 4, 4, 3), dtype=dtype,
- device=device).contiguous(memory_format=torch.channels_last) # non-contiguous
- y = torch.rand((4, 4, 4, 3), dtype=dtype,
- device=device).contiguous(memory_format=torch.channels_last) # non-contiguous
- np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=-1)
- res = torch.linalg.cross(x, y, dim=-1)
- # numpy reference compared to torch result
- self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
- # non contiguous case 2
- x = torch.rand(1, 3, 2, dtype=dtype, device=device) # contiguous
- y = torch.rand(1, 3, 4, dtype=dtype, device=device).permute(2, 1, 0) # non-contiguous
- np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
- res = torch.linalg.cross(x, y, dim=1)
- # numpy reference compared to torch result
- self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
- # non contiguous case 3
- x = torch.rand(2, 3, 1, dtype=dtype, device=device).permute(2, 1, 0) # non-contiguous
- y = torch.rand(1, 3, 4, dtype=dtype, device=device).permute(2, 1, 0) # non-contiguous
- np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
- res = torch.linalg.cross(x, y, dim=1)
- # numpy reference compared to torch result
- self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
- # non contiguous case 4
- x = torch.randn(12, 3, device=device, dtype=dtype)[::2, :] # non-contiguous
- y = torch.randn(18, 3, device=device, dtype=dtype)[::3, :] # non-contiguous
- np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
- res = torch.linalg.cross(x, y, dim=1)
- # numpy reference compared to torch result
- self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
- # non contiguous case 5
- x = torch.randn(1, device=device, dtype=dtype) # contiguous
- y = torch.randn(6, device=device, dtype=dtype)[::2] # non-contiguous
- np_expected_ref = np.cross(x.expand(3).cpu().numpy(), y.cpu().numpy())
- res = torch.linalg.cross(x, y)
- # numpy reference compared to torch result
- self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
@dtypes(torch.float32, torch.complex64)
def test_cross_with_and_without_dim(self, device, dtype):
x = torch.rand(100, 3, dtype=dtype, device=device)
@@ -4566,46 +4524,6 @@
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
- def test_cross_errors(self, device):
- self.assertRaisesRegex(
- RuntimeError, "must match the size of tensor",
- lambda: torch.cross(torch.rand(100, 3, device=device), torch.rand(100, 3, 10, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "must match the size of tensor",
- lambda: torch.cross(torch.rand(5, 3, device=device), torch.rand(3, 5, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "no dimension of size 3 in input",
- lambda: torch.cross(torch.rand(5, 4, device=device), torch.rand(5, 4, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "dimension 0 does not have size 3",
- lambda: torch.cross(torch.rand(5, 4, 3, device=device), torch.rand(5, 4, 3, device=device), dim=0))
- self.assertRaisesRegex(
- RuntimeError, "dimension -1 does not have size 3",
- lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-1))
- self.assertRaisesRegex(
- IndexError, "Dimension out of range",
- lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-5))
-
- def test_linalg_cross_errors(self, device):
- self.assertRaisesRegex(
- RuntimeError, "dimension -1 does not have size 3",
- lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "must match the size of tensor",
- lambda: torch.linalg.cross(torch.rand(100, 3, device=device), torch.rand(100, 3, 10, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "must match the size of tensor",
- lambda: torch.linalg.cross(torch.rand(5, 3, device=device), torch.rand(3, 5, device=device)))
- self.assertRaisesRegex(
- RuntimeError, "dimension 0 does not have size 3",
- lambda: torch.linalg.cross(torch.rand(5, 4, 3, device=device), torch.rand(5, 4, 3, device=device), dim=0))
- self.assertRaisesRegex(
- RuntimeError, "dimension -1 does not have size 3",
- lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-1))
- self.assertRaisesRegex(
- IndexError, "Dimension out of range",
- lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-5))
-
def test_renorm(self, device):
m1 = torch.randn(20, 20, device=device) # big enough to exercise vectorized path
res1 = torch.tensor((), device=device)
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index bf82cc2..b070628 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -28,8 +28,7 @@
Supports input of float, double, cfloat and cdouble dtypes. Also supports batches
of vectors, for which it computes the product along the dimension :attr:`dim`.
-In this case, the output has the same batch dimensions as the inputs broadcast to
-a common shape.
+It broadcasts over the batch dimensions.
Args:
input (Tensor): the first input tensor.
@@ -39,9 +38,6 @@
Keyword args:
out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`.
-Raises:
- RuntimeError: If after broadcasting :attr:`input`\ `.size(\ `:attr:`dim`\ `) != 3`
- or :attr:`other`\ `.size(\ `:attr:`dim`\ `) != 3`.
Example:
>>> a = torch.randn(4, 3)
>>> a
diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py
index b5f5f0d..fe58719 100644
--- a/torch/testing/_internal/opinfo/definitions/linalg.py
+++ b/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -118,32 +118,36 @@
def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
- yield SampleInput(
- make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
- args=(
- make_tensor(
- (S, 3), device=device, dtype=dtype, requires_grad=requires_grad
- ),
- ),
+ make_arg = partial(
+ make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
)
+ yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
yield SampleInput(
- make_tensor((S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad),
- args=(
- make_tensor(
- (S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad
- ),
- ),
- kwargs={"dim": 1},
+ make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
)
- yield SampleInput(
- make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
- args=(
- make_tensor(
- (S, 3), device=device, dtype=dtype, requires_grad=requires_grad
- ),
- ),
- kwargs={"dim": -1},
+ yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
+
+
+def error_inputs_cross(op_info, device, **kwargs):
+ make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+ sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
+ err = "inputs dimension -1 must have length 3"
+ yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+ sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
+ err = "inputs must have the same number of dimensions"
+ yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+ sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
+ err = "must have length 3"
+ yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+ sample = SampleInput(
+ input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
)
+ err = "Dimension out of range"
+ yield ErrorInput(sample, error_regex=err, error_type=IndexError)
def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
@@ -1225,6 +1229,7 @@
dtypesIfCUDA=all_types_and_complex_and(torch.half),
aten_name="linalg_cross",
sample_inputs_func=sample_inputs_cross,
+ error_inputs_func=error_inputs_cross,
supports_out=True,
supports_fwgrad_bwgrad=True,
supports_forward_ad=True,