[pt2] add meta for `linalg_lu_factor_ex` (#101375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101375
Approved by: https://github.com/lezcano
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 4e1e010..f7f3c9b 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1443,8 +1443,6 @@
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -1459,7 +1457,6 @@
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 38c6f01..8fe7551 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -640,6 +640,39 @@
return P, L, U
+@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
+@out_wrapper("LU", "pivots", "info")
+def linalg_lu_factor_ex_meta(
+ A: Tensor, *, pivot: bool = True, check_errors: bool = False
+) -> Tuple[Tensor, Tensor, Tensor]:
+ check(
+ A.ndim >= 2,
+ lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
+ )
+
+ sizes = list(A.shape)
+ m = sizes[-2]
+ n = sizes[-1]
+
+ LU = torch.empty_strided(
+ size=sizes,
+ stride=make_contiguous_strides_for(sizes, row_major=False),
+ dtype=A.dtype,
+ device=A.device,
+ )
+
+ # Sets sizes to the size of pivots
+ sizes.pop()
+ sizes[-1] = min(m, n)
+ pivots = A.new_empty(sizes, dtype=torch.int)
+
+ # Sets sizes to the size of info
+ sizes.pop()
+ info = A.new_empty(sizes, dtype=torch.int)
+
+ return LU, pivots, info
+
+
# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
if mode == "reduced":