Strengthen preconditions of linalg.cross (#83798)

This makes `linalg.cross` array API compliant (https://github.com/data-apis/array-api/issues/415) and fixes a few bugs.

Fixes https://github.com/pytorch/pytorch/issues/77629
Fixes https://github.com/pytorch/pytorch/issues/83756
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83798
Approved by: https://github.com/mruberry
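
For reference, a minimal sketch of the behavior after this change (illustrative only; the error text is taken from the `TORCH_CHECK` calls added in this patch):

```python
import torch

# Batch dimensions broadcast; dimension `dim` must have length 3 on both inputs.
a = torch.randn(1, 3)
b = torch.randn(5, 3)
print(torch.linalg.cross(a, b).shape)  # torch.Size([5, 3])

# Inputs must now have the same number of dimensions:
try:
    torch.linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
except RuntimeError as e:
    print(e)  # linalg.cross: inputs must have the same number of dimensions.

# ... and dimension `dim` must have length 3 on *both* inputs, before broadcasting:
try:
    torch.linalg.cross(torch.randn(5, 3), torch.randn(5, 1))
except RuntimeError as e:
    print(e)  # linalg.cross: inputs dimension -1 must have length 3. Got 3 and 1
```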
diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h
index 9712461..07631c3 100644
--- a/aten/src/ATen/TensorMeta.h
+++ b/aten/src/ATen/TensorMeta.h
@@ -71,6 +71,7 @@
 struct TORCH_API MetaBase {
   virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
 
+  // Note: [set_output_*]
   // See: https://github.com/pytorch/pytorch/issues/69813
   // Whenever defining the output properties in the META function of a
   // structured kernel (what was usually done with `set_output`), use one of
diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp
index 4b3e43d..9b268c6 100644
--- a/aten/src/ATen/native/Cross.cpp
+++ b/aten/src/ATen/native/Cross.cpp
@@ -9,17 +9,20 @@
 namespace at {
 namespace meta {
 
-TORCH_PRECOMPUTE_META_FUNC(linalg_cross)
-(const Tensor & input, const Tensor & other, const int64_t dimension) {
-  auto out_size = infer_size(input.sizes(), other.sizes());
-  Tensor input_broadcasted = input.expand(out_size);
-  Tensor other_broadcasted = other.expand(out_size);
+TORCH_META_FUNC(linalg_cross)
+(const Tensor & input, const Tensor & other, int64_t dim) {
+  auto x_d = input.dim();
+  auto y_d = other.dim();
+  // This is to avoid things like
+  // linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
+  TORCH_CHECK(x_d == y_d, "linalg.cross: inputs must have the same number of dimensions.");
+  TORCH_CHECK(input.size(dim) == 3 && other.size(dim) == 3, "linalg.cross: inputs dimension ", dim, " must have length 3. Got ", input.size(dim), " and ", other.size(dim));
 
-  int64_t dim = maybe_wrap_dim(dimension, input.dim()); // default dim = -1
-  TORCH_CHECK(input_broadcasted.size(dim) == 3, "dimension ", dimension, " does not have size 3");
+  // Broadcast the batch dimensions of input and other.
+  // Since the non-batch dimensions agree, this is the same as broadcasting all the inputs.
+  auto out_size = infer_size(input.sizes(), other.sizes());
 
   set_output_raw_strided(0, out_size, {}, input.options());
-  return TORCH_PRECOMPUTE_STRUCT(linalg_cross)().set_dim(dim);
 }
 
 }
@@ -56,8 +59,9 @@
 
 
 TORCH_IMPL_FUNC(linalg_cross_out)
-(const Tensor & input, const Tensor & other, const int64_t dim, const Tensor & out) {
-  auto out_size = infer_size(input.sizes(), other.sizes());
+(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
+  dim = maybe_wrap_dim(dim, input.dim());
+  auto out_size = out.sizes();
   Tensor input_broadcasted = input.expand(out_size);
   Tensor other_broadcasted = other.expand(out_size);
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 78b7071..daab8f0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -12155,8 +12155,6 @@
 - func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
   structured: True
-  precomputed:
-  - dim -> int dim
   dispatch:
     CPU, CUDA: linalg_cross_out
 
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index da38270..2bb19ba 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -958,6 +958,7 @@
         xfail('svd_lowrank', ''),
         xfail('pca_lowrank', ''),
         xfail('clamp'),
+        xfail('cross'),  # The defaults of this op are *very* weird. No wonder it doesn't work
         # something weird happening with channels_last
         xfail('bfloat16'),
         xfail('double'),
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 7f00909..05dab7b 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -3222,6 +3222,7 @@
         xfail('eye', ''),  # non-tensor input
         xfail('broadcast_shapes', ''),  # test runner can't handle non-Tensor ops
         xfail('sparse.sampled_addmm'),  # sparse
+        xfail('cross'),  # The default value of dim in op is *very* weird. No wonder it doesn't work
         xfail('svd', device_type='cuda'),  # not unique, see test_linalg_svd for manual test
         xfail('linalg.svd', device_type='cuda'),  # not unique, see test_linalg_svd for manual test
         skip('linalg.eigh', ''),  # not unique, see test_linalg_eigh for manual test
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 8df810d..655150c 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4504,48 +4504,6 @@
         torch.linalg.cross(x, y, dim=1, out=res2)
         self.assertEqual(res1, res2)
 
-        # non contiguous case 1
-        x = torch.rand((4, 4, 4, 3), dtype=dtype,
-                       device=device).contiguous(memory_format=torch.channels_last)  # non-contiguous
-        y = torch.rand((4, 4, 4, 3), dtype=dtype,
-                       device=device).contiguous(memory_format=torch.channels_last)  # non-contiguous
-        np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=-1)
-        res = torch.linalg.cross(x, y, dim=-1)
-        # numpy reference compared to torch result
-        self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
-        # non contiguous case 2
-        x = torch.rand(1, 3, 2, dtype=dtype, device=device)  # contiguous
-        y = torch.rand(1, 3, 4, dtype=dtype, device=device).permute(2, 1, 0)  # non-contiguous
-        np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
-        res = torch.linalg.cross(x, y, dim=1)
-        # numpy reference compared to torch result
-        self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
-        # non contiguous case 3
-        x = torch.rand(2, 3, 1, dtype=dtype, device=device).permute(2, 1, 0)  # non-contiguous
-        y = torch.rand(1, 3, 4, dtype=dtype, device=device).permute(2, 1, 0)  # non-contiguous
-        np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
-        res = torch.linalg.cross(x, y, dim=1)
-        # numpy reference compared to torch result
-        self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
-        # non contiguous case 4
-        x = torch.randn(12, 3, device=device, dtype=dtype)[::2, :]  # non-contiguous
-        y = torch.randn(18, 3, device=device, dtype=dtype)[::3, :]  # non-contiguous
-        np_expected_ref = np.cross(x.cpu().numpy(), y.cpu().numpy(), axis=1)
-        res = torch.linalg.cross(x, y, dim=1)
-        # numpy reference compared to torch result
-        self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
-        # non contiguous case 5
-        x = torch.randn(1, device=device, dtype=dtype)  # contiguous
-        y = torch.randn(6, device=device, dtype=dtype)[::2]  # non-contiguous
-        np_expected_ref = np.cross(x.expand(3).cpu().numpy(), y.cpu().numpy())
-        res = torch.linalg.cross(x, y)
-        # numpy reference compared to torch result
-        self.assertEqual(res.cpu().numpy(), np_expected_ref)
-
     @dtypes(torch.float32, torch.complex64)
     def test_cross_with_and_without_dim(self, device, dtype):
         x = torch.rand(100, 3, dtype=dtype, device=device)
@@ -4566,46 +4524,6 @@
         self.assertEqual(res1, res2)
         self.assertEqual(res1, res3)
 
-    def test_cross_errors(self, device):
-        self.assertRaisesRegex(
-            RuntimeError, "must match the size of tensor",
-            lambda: torch.cross(torch.rand(100, 3, device=device), torch.rand(100, 3, 10, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "must match the size of tensor",
-            lambda: torch.cross(torch.rand(5, 3, device=device), torch.rand(3, 5, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "no dimension of size 3 in input",
-            lambda: torch.cross(torch.rand(5, 4, device=device), torch.rand(5, 4, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "dimension 0 does not have size 3",
-            lambda: torch.cross(torch.rand(5, 4, 3, device=device), torch.rand(5, 4, 3, device=device), dim=0))
-        self.assertRaisesRegex(
-            RuntimeError, "dimension -1 does not have size 3",
-            lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-1))
-        self.assertRaisesRegex(
-            IndexError, "Dimension out of range",
-            lambda: torch.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-5))
-
-    def test_linalg_cross_errors(self, device):
-        self.assertRaisesRegex(
-            RuntimeError, "dimension -1 does not have size 3",
-            lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "must match the size of tensor",
-            lambda: torch.linalg.cross(torch.rand(100, 3, device=device), torch.rand(100, 3, 10, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "must match the size of tensor",
-            lambda: torch.linalg.cross(torch.rand(5, 3, device=device), torch.rand(3, 5, device=device)))
-        self.assertRaisesRegex(
-            RuntimeError, "dimension 0 does not have size 3",
-            lambda: torch.linalg.cross(torch.rand(5, 4, 3, device=device), torch.rand(5, 4, 3, device=device), dim=0))
-        self.assertRaisesRegex(
-            RuntimeError, "dimension -1 does not have size 3",
-            lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-1))
-        self.assertRaisesRegex(
-            IndexError, "Dimension out of range",
-            lambda: torch.linalg.cross(torch.rand(5, 3, 4, device=device), torch.rand(5, 3, 4, device=device), dim=-5))
-
     def test_renorm(self, device):
         m1 = torch.randn(20, 20, device=device)  # big enough to exercise vectorized path
         res1 = torch.tensor((), device=device)
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index bf82cc2..b070628 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -28,8 +28,7 @@
 
 Supports input of float, double, cfloat and cdouble dtypes. Also supports batches
 of vectors, for which it computes the product along the dimension :attr:`dim`.
-In this case, the output has the same batch dimensions as the inputs broadcast to
-a common shape.
+It broadcasts over the batch dimensions.
 
 Args:
     input (Tensor): the first input tensor.
@@ -39,9 +38,6 @@
 Keyword args:
     out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`.
 
-Raises:
-    RuntimeError: If after broadcasting :attr:`input`\ `.size(\ `:attr:`dim`\ `) != 3`
-                  or :attr:`other`\ `.size(\ `:attr:`dim`\ `) != 3`.
 Example:
     >>> a = torch.randn(4, 3)
     >>> a
diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py
index b5f5f0d..fe58719 100644
--- a/torch/testing/_internal/opinfo/definitions/linalg.py
+++ b/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -118,32 +118,36 @@
 
 
 def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
-    yield SampleInput(
-        make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
-        args=(
-            make_tensor(
-                (S, 3), device=device, dtype=dtype, requires_grad=requires_grad
-            ),
-        ),
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
     )
+    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
     yield SampleInput(
-        make_tensor((S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad),
-        args=(
-            make_tensor(
-                (S, 3, S), device=device, dtype=dtype, requires_grad=requires_grad
-            ),
-        ),
-        kwargs={"dim": 1},
+        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
     )
-    yield SampleInput(
-        make_tensor((S, 3), device=device, dtype=dtype, requires_grad=requires_grad),
-        args=(
-            make_tensor(
-                (S, 3), device=device, dtype=dtype, requires_grad=requires_grad
-            ),
-        ),
-        kwargs={"dim": -1},
+    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
+
+
+def error_inputs_cross(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
+    err = "inputs dimension -1 must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
+    err = "inputs must have the same number of dimensions"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
+    err = "must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(
+        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
     )
+    err = "Dimension out of range"
+    yield ErrorInput(sample, error_regex=err, error_type=IndexError)
 
 
 def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
@@ -1225,6 +1229,7 @@
         dtypesIfCUDA=all_types_and_complex_and(torch.half),
         aten_name="linalg_cross",
         sample_inputs_func=sample_inputs_cross,
+        error_inputs_func=error_inputs_cross,
         supports_out=True,
         supports_fwgrad_bwgrad=True,
         supports_forward_ad=True,