[pt2] add meta function for `solve_triangular` (#100829)

Registers a meta implementation for `aten.linalg_solve_triangular` (both the
`default` and `out` overloads), ports the input checks from
`aten/src/ATen/native/LinearAlgebraUtils.h`, and factors `_resize_output_check`
out of `_maybe_resize_out` so the meta function can resize its output to an
F-contiguous (column-major) layout, matching the eager kernel. The
corresponding expected-failure entries in `test_aotdispatch.py`,
`test_meta.py`, and `test_proxy_tensor.py` are removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
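
With the registration in place, a minimal sketch of what now works (an
illustrative example, not taken from the PR's tests):

```python
import torch

# solve_triangular on "meta" tensors now computes output metadata
# (shape, dtype, device) without running the actual solver.
A = torch.empty(3, 4, 4, device="meta")
B = torch.empty(3, 4, 2, device="meta")
X = torch.linalg.solve_triangular(A, B, upper=True)
assert X.shape == (3, 4, 2)
assert X.device.type == "meta"
```
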
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 523503b..e160051 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2490,7 +2490,6 @@
     xfail('index_fill', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('linalg.cholesky_ex', ''),  # could not find kernel for aten.linalg_solve_triangular.default
     xfail('linalg.det', ''),  # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.det', 'singular'),  # aten._linalg_det.default - couldn't find symbolic meta function/deco...
     xfail('linalg.eigh', ''),  # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
@@ -2511,7 +2510,6 @@
     xfail('linalg.slogdet', ''),  # aten._linalg_slogdet.default - couldn't find symbolic meta function/decom...
     xfail('linalg.solve', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomp...
     xfail('linalg.solve_ex', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/dec...
-    xfail('linalg.solve_triangular', ''),  # aten.linalg_solve_triangular.default - couldn't find symbolic me...
     xfail('linalg.tensorinv', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.tensorsolve', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.vander', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/test_meta.py b/test/test_meta.py
index 4b33a98..693636b 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -598,7 +598,6 @@
     torch.functional.istft : {f64, c64, c128, f32},
     torch.geqrf : {f64, c64, c128, f32},
     torch.linalg.householder_product : {f64, c64, c128, f32},
-    torch.linalg.solve_triangular : {f64, c64, c128, f32},
     torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
     torch.matrix_exp : {f64, c128, c64, bf16, f32},
     torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
@@ -718,7 +717,6 @@
     torch.histc: {i16, i32, i64, i8},  # aten::histc, aten::histc.out
     torch.kthvalue: {f16},  # aten::kthvalue.values
     torch.linalg.householder_product: {f32, f64},  # aten::linalg_householder_product, aten::linalg_householder_product.out
-    torch.linalg.solve_triangular: {f32, f64},  # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
     torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
     torch.median: {f16},  # aten::median, aten::median.dim_values
     torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
@@ -830,8 +828,6 @@
     aten.linalg_householder_product.out : {c64, c128, f64, f32},
     aten.linalg_lstsq.default : {c64, c128, f64, f32},
     aten.linalg_matrix_exp.default : {c64, bf16, f32, f64, c128},
-    aten.linalg_solve_triangular.default : {c64, c128, f64, f32},
-    aten.linalg_solve_triangular.out : {c64, c128, f64, f32},
     aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
@@ -929,8 +925,6 @@
     aten.linalg_householder_product.default: {f32, f64},  # aten::linalg_householder_product
     aten.linalg_householder_product.out: {f32, f64},  # aten::linalg_householder_product.out
     aten.linalg_matrix_exp.default: {f16},  # aten::linalg_matrix_exp
-    aten.linalg_solve_triangular.default: {f32, f64},  # aten::linalg_solve_triangular
-    aten.linalg_solve_triangular.out: {f32, f64},  # aten::linalg_solve_triangular.out
     aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
     aten.max_pool3d_with_indices.default: {bf16, f16},  # aten::max_pool3d_with_indices
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 36035d2..67e45bf 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1458,7 +1458,6 @@
     xfail('linalg.slogdet', ''),  # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.solve', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.solve_ex', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.solve_triangular', ''),  # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de...
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index fdb8abd..9eafb03 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -18,7 +18,12 @@
     TensorLike,
 )
 
-from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
+from torch._prims_common.wrappers import (
+    _maybe_resize_out,
+    _resize_output_check,
+    _safe_copy_out,
+    out_wrapper,
+)
 from torch._refs import _broadcast_shapes
 
 from torch.utils._pytree import tree_map
@@ -315,6 +320,48 @@
     ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
 
 
+# Validates input shapes and devices
+# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
+# From aten/src/ATen/native/LinearAlgebraUtils.h
+def linearSolveCheckInputs(
+    self: Tensor,
+    A: Tensor,
+    name: str,
+):
+    check(
+        self.device == A.device,
+        lambda: (
+            f"Expected b and A to be on the same device, but found b on "
+            f"{self.device} and A on {A.device} instead."
+        ),
+    )
+
+    check(
+        self.dtype == A.dtype,
+        lambda: (
+            f"Expected b and A to have the same dtype, but found b of type "
+            f"{self.dtype} and A of type {A.dtype} instead."
+        ),
+    )
+
+    check(
+        A.size(-1) == A.size(-2),
+        lambda: (
+            f"A must be batches of square matrices, "
+            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
+        ),
+    )
+
+    check(
+        A.size(-1) == self.size(-2),
+        lambda: (
+            f"Incompatible matrix sizes for {name}: each A "
+            f"matrix is {A.size(-1)} by {A.size(-1)}"
+            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
+        ),
+    )
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def checkFloatingOrComplex(
     t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
@@ -339,6 +386,24 @@
     )
 
 
+def checkInputsSolver(
+    A: Tensor,
+    B: Tensor,
+    left: bool,
+    f_name: str,
+):
+    squareCheckInputs(A, f_name)
+    checkIsMatrix(B, f_name)
+    check(
+        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
+        lambda: (
+            f"{f_name}: Incompatible shapes of A and B for the equation "
+            f"{'AX = B' if left else 'XA = B'}"
+            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
+        ),
+    )
+
+
 def checkUplo(uplo: str):
     uplo_uppercase = uplo.upper()
     assert (
@@ -483,6 +548,66 @@
     return U, S, V
 
 
+def _linalg_broadcast_batch_dims(
+    arg1: Tensor, arg2: Tensor
+) -> Tuple[List[int], List[int]]:
+    # broadcast the batch dimensions of arg1 and arg2.
+    arg1_batch_sizes = arg1.shape[:-2]
+    arg2_batch_sizes = arg2.shape[:-2]
+    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
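+    # e.g. batch dims (2, 1) and (3,) broadcast to (2, 3): tensors of shape
+    # (2, 1, 4, 4) and (3, 4, 2) expand to (2, 3, 4, 4) and (2, 3, 4, 2).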
+
+    arg1_expand_size = list(expand_batch_portion)
+    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
+
+    arg2_expand_size = list(expand_batch_portion)
+    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
+    return arg1_expand_size, arg2_expand_size
+
+
+def _linalg_broadcast_batch_dims_name(
+    arg1: Tensor, arg2: Tensor, name: Optional[str]
+) -> Tuple[Tensor, Tensor]:
+    # If no name is given, skip the input checks
+    if name:
+        linearSolveCheckInputs(arg1, arg2, name)
+
+    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
+
+    arg1_broadcasted = (
+        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
+    )
+    arg2_broadcasted = (
+        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
+    )
+    return arg1_broadcasted, arg2_broadcasted
+
+
+@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
+def linalg_solve_triangular_meta(
+    A: Tensor,
+    B: Tensor,
+    *,
+    upper: bool,
+    left: bool = True,
+    unitriangular: bool = False,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out is None:
+        out = A.new_empty([0])
+    assert isinstance(out, TensorLike)
+    checkInputsSolver(A, B, left, "linalg.solve_triangular")
+    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
+    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
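+    # Mirror the eager kernel's output allocation (BatchLinearAlgebra.cpp):
+    # when A can be used in place, a plain resize of out suffices; otherwise
+    # the result is written column-major, so out needs F-contiguous strides.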
+    if avoid_copy_A:
+        out = _maybe_resize_out(out, B_.shape)
+    else:
+        # reimplementation of resize_output with result F-contig
+        if _resize_output_check(out, B_.shape):
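+            # Resize to the transposed shape, then transpose back in place:
+            # this gives out column-major (Fortran-order) strides.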
+            out.resize_(B_.transpose(-2, -1).shape)
+            out.transpose_(-2, -1)
+    return out  # type: ignore[return-value]
+
+
 # From aten/src/ATen/native/LinearAlgebra.cpp
 @register_meta(aten._linalg_det.default)
 def _linalg_det_meta(A):
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index ac19d43..4f602d8 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -139,22 +139,29 @@
         return _fn
 
 
-# TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+# Returns True if resize is necessary
+def _resize_output_check(out: TensorLikeType, shape: ShapeType):
     # If the shapes are correct there's nothing to do
     if utils.same_shape(out.shape, shape):
-        return out
-    else:
-        if out.numel() != 0:
-            msg = (
-                f"An output with one or more elements was resized since it had shape {str(out.shape)} "
-                "which does not match the required output shape {str(shape)}. "
-                "This behavior is deprecated, and in a future PyTorch release outputs will not "
-                "be resized unless they have zero elements. "
-                "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
-            )
-            warnings.warn(msg)
+        return False
+    if out.numel() != 0:
+        msg = (
+            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
+            "which does not match the required output shape {str(shape)}. "
+            "This behavior is deprecated, and in a future PyTorch release outputs will not "
+            "be resized unless they have zero elements. "
+            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
+        )
+        warnings.warn(msg)
+    return True
+
+
+# TODO: handle tuples of tensors
+def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+    if _resize_output_check(out, shape):
         return out.resize_(shape)
+    else:
+        return out
 
 
 def _safe_copy_out(