add `OpInfo` for `torch.linalg.tensorinv` (#62326)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/53739.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62326

Reviewed By: H-Huang

Differential Revision: D30136376

Pulled By: zou3519

fbshipit-source-id: 04ec9450e8866667649af401c7559b96ddc91491
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index bbb6fce..10576a0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2651,12 +2651,9 @@
   shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend());
 
   // If the reshaped self is not invertible catch this error
-  Tensor result;
-  try {
-    result = at::inverse(self.reshape({prod_ind_end, prod_ind_end}));
-  } catch (...) {
-    TORCH_CHECK(false, "Failed to invert the input tensor, because it is singular.");
-  }
+  Tensor result, info;
+  std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
+  TORCH_CHECK(info.item<int64_t>() == 0, "Failed to invert the input tensor, because it is singular.");
 
   return result.reshape(shape_ind_end);
 }
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b281c5e..f06d3ce 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4986,6 +4986,22 @@
         SampleInput(make_input(low=1), kwargs=dict(threshold=1)),
     ]
 
+def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input():
+        input = make_fullrank_matrices_with_distinct_singular_values(12, 12, device=device, dtype=dtype)
+        return input.requires_grad_(requires_grad)
+
+    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
+    shapes = [
+        ((2, 2, 3), (12, 1)),
+        ((4, 3), (6, 1, 2)),
+    ]
+
+    return [
+        SampleInput(make_input().reshape(*shape_lhs, *shape_rhs), kwargs=dict(ind=len(shape_lhs)))
+        for shape_lhs, shape_rhs in shapes
+    ]
+
 def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
     _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -8674,6 +8690,19 @@
         ),
     ),
     OpInfo(
+        "linalg.tensorinv",
+        ref=np.linalg.tensorinv,
+        dtypes=floating_and_complex_types(),
+        skips=(
+            # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":159,
+            # please report a bug to PyTorch.
+            SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+        ),
+        sample_inputs_func=sample_inputs_tensorinv,
+        supports_forward_ad=True,
+    ),
+    OpInfo(
         "nn.functional.mse_loss",
         ref=reference_mse_loss,
         sample_inputs_func=sample_inputs_mse_loss,