add `OpInfo` for `torch.linalg.tensorinv` (#62326)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/53739.
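This adds an `OpInfo` entry for `torch.linalg.tensorinv` together with a `sample_inputs_tensorinv` helper. The samples start from a full-rank 12x12 matrix and reshape it so that the leading `ind` dimensions and the trailing dimensions each multiply out to 12, which is the shape constraint `tensorinv` requires. A minimal sketch of the kind of call the new samples exercise (the matrix construction here is illustrative only; the actual helper uses `make_fullrank_matrices_with_distinct_singular_values`):

```python
import torch

# Illustrative input: a comfortably invertible 12x12 matrix
# (the test helper builds one with distinct singular values instead).
a = torch.randn(12, 12, dtype=torch.float64)
a = a @ a.t() + 12 * torch.eye(12, dtype=torch.float64)

# Reshape so the first `ind` dims and the remaining dims both multiply to 12,
# matching the (2, 2, 3) / (12, 1) sample shape added below.
t = a.reshape(2, 2, 3, 12, 1)

inv = torch.linalg.tensorinv(t, ind=3)  # ind = len(shape_lhs)
print(inv.shape)  # torch.Size([12, 1, 2, 2, 3])
```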
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62326
Reviewed By: H-Huang
Differential Revision: D30136376
Pulled By: zou3519
fbshipit-source-id: 04ec9450e8866667649af401c7559b96ddc91491
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index bbb6fce..10576a0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2651,12 +2651,9 @@
shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend());
// If the reshaped self is not invertible catch this error
- Tensor result;
- try {
- result = at::inverse(self.reshape({prod_ind_end, prod_ind_end}));
- } catch (...) {
- TORCH_CHECK(false, "Failed to invert the input tensor, because it is singular.");
- }
+ Tensor result, info;
+ std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
+ TORCH_CHECK(info.item<int64_t>() == 0, "Failed to invert the input tensor, because it is singular.");
return result.reshape(shape_ind_end);
}
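On the ATen side, the singularity check in `linalg_tensorinv` no longer wraps `at::inverse` in a try/catch; it calls `at::linalg_inv_ex` with `check_errors=false` and raises only when the returned `info` is nonzero. A rough Python-level analogue of that pattern, for illustration only (the real check happens inside the kernel above):

```python
import torch

# linalg.inv_ex returns (inverse, info) without throwing when check_errors=False;
# a nonzero info flags a singular input.
m = torch.zeros(3, 3)  # deliberately singular
inverse, info = torch.linalg.inv_ex(m, check_errors=False)
if info.item() != 0:
    raise RuntimeError("Failed to invert the input tensor, because it is singular.")
```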
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b281c5e..f06d3ce 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4986,6 +4986,22 @@
SampleInput(make_input(low=1), kwargs=dict(threshold=1)),
]
+def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+ def make_input():
+ input = make_fullrank_matrices_with_distinct_singular_values(12, 12, device=device, dtype=dtype)
+ return input.requires_grad_(requires_grad)
+
+ # lhs / rhs shape can have any number of dimensions as long as their product equals 12
+ shapes = [
+ ((2, 2, 3), (12, 1)),
+ ((4, 3), (6, 1, 2)),
+ ]
+
+ return [
+ SampleInput(make_input().reshape(*shape_lhs, *shape_rhs), kwargs=dict(ind=len(shape_lhs)))
+ for shape_lhs, shape_rhs in shapes
+ ]
+
def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs):
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -8674,6 +8690,19 @@
),
),
OpInfo(
+ "linalg.tensorinv",
+ ref=np.linalg.tensorinv,
+ dtypes=floating_and_complex_types(),
+ skips=(
+ # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end()
+ # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":159,
+ # please report a bug to PyTorch.
+ SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+ ),
+ sample_inputs_func=sample_inputs_tensorinv,
+ supports_forward_ad=True,
+ ),
+ OpInfo(
"nn.functional.mse_loss",
ref=reference_mse_loss,
sample_inputs_func=sample_inputs_mse_loss,