[pt2] add meta for `linalg_lu_factor_ex` (#101375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101375
Approved by: https://github.com/lezcano
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 4e1e010..f7f3c9b 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1443,8 +1443,6 @@
     xfail('isin', ''),  # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('kron', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.lu_factor', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.lu_factor_ex', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.matrix_power'),  # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
     xfail('linalg.multi_dot', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -1459,7 +1457,6 @@
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logdet', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('lu', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
     xfail('lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
     xfail('lu_unpack', ''),  # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition
     xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 38c6f01..8fe7551 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -640,6 +640,39 @@
     return P, L, U
 
 
+@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
+@out_wrapper("LU", "pivots", "info")
+def linalg_lu_factor_ex_meta(
+    A: Tensor, *, pivot: bool = True, check_errors: bool = False
+) -> Tuple[Tensor, Tensor, Tensor]:
+    check(
+        A.ndim >= 2,
+        lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
+    )
+
+    sizes = list(A.shape)
+    m = sizes[-2]
+    n = sizes[-1]
+
+    LU = torch.empty_strided(
+        size=sizes,
+        stride=make_contiguous_strides_for(sizes, row_major=False),
+        dtype=A.dtype,
+        device=A.device,
+    )
+
+    # Sets sizes to the size of pivots
+    sizes.pop()
+    sizes[-1] = min(m, n)
+    pivots = A.new_empty(sizes, dtype=torch.int)
+
+    # Sets sizes to the size of info
+    sizes.pop()
+    info = A.new_empty(sizes, dtype=torch.int)
+
+    return LU, pivots, info
+
+
 # parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
 def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
     if mode == "reduced":