[pt2] add meta function for `solve_triangular` (#100829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
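
Before this change `linalg.solve_triangular` had no meta kernel, so PT2
tracing failed with "could not find kernel for
aten.linalg_solve_triangular.default" (see the xfails removed below).
A minimal sketch of what now works on the meta device (illustrative only,
not part of the patch; meta tensors carry shapes, dtypes and strides but
never values):

    import torch

    A = torch.empty(2, 3, 3, device="meta")  # batch of 3x3 matrices
    B = torch.empty(2, 3, 4, device="meta")
    # Only shape checking, batch broadcasting, and output allocation
    # happen here; no actual triangular solve runs on the meta device.
    X = torch.linalg.solve_triangular(A, B, upper=True)
    assert X.shape == (2, 3, 4) and X.device.type == "meta"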
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 523503b..e160051 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2490,7 +2490,6 @@
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/deco...
xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
@@ -2511,7 +2510,6 @@
xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decom...
xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomp...
xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/dec...
- xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic me...
xfail('linalg.tensorinv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/test_meta.py b/test/test_meta.py
index 4b33a98..693636b 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -598,7 +598,6 @@
torch.functional.istft : {f64, c64, c128, f32},
torch.geqrf : {f64, c64, c128, f32},
torch.linalg.householder_product : {f64, c64, c128, f32},
- torch.linalg.solve_triangular : {f64, c64, c128, f32},
torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
torch.matrix_exp : {f64, c128, c64, bf16, f32},
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
@@ -718,7 +717,6 @@
torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
torch.kthvalue: {f16}, # aten::kthvalue.values
torch.linalg.householder_product: {f32, f64}, # aten::linalg_householder_product, aten::linalg_householder_product.out
- torch.linalg.solve_triangular: {f32, f64}, # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
@@ -830,8 +828,6 @@
aten.linalg_householder_product.out : {c64, c128, f64, f32},
aten.linalg_lstsq.default : {c64, c128, f64, f32},
aten.linalg_matrix_exp.default : {c64, bf16, f32, f64, c128},
- aten.linalg_solve_triangular.default : {c64, c128, f64, f32},
- aten.linalg_solve_triangular.out : {c64, c128, f64, f32},
aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
@@ -929,8 +925,6 @@
aten.linalg_householder_product.default: {f32, f64}, # aten::linalg_householder_product
aten.linalg_householder_product.out: {f32, f64}, # aten::linalg_householder_product.out
aten.linalg_matrix_exp.default: {f16}, # aten::linalg_matrix_exp
- aten.linalg_solve_triangular.default: {f32, f64}, # aten::linalg_solve_triangular
- aten.linalg_solve_triangular.out: {f32, f64}, # aten::linalg_solve_triangular.out
aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 36035d2..67e45bf 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1458,7 +1458,6 @@
xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.solve_triangular', ''), # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de...
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index fdb8abd..9eafb03 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -18,7 +18,12 @@
TensorLike,
)
-from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper
+from torch._prims_common.wrappers import (
+ _maybe_resize_out,
+ _resize_output_check,
+ _safe_copy_out,
+ out_wrapper,
+)
from torch._refs import _broadcast_shapes
from torch.utils._pytree import tree_map
@@ -315,6 +320,48 @@
), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
+# Validates input shapes and devices
+# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
+# From aten/src/ATen/native/LinearAlgebraUtils.h
+def linearSolveCheckInputs(
+ self: Tensor,
+ A: Tensor,
+ name: str,
+):
+ check(
+ self.device == A.device,
+ lambda: (
+ f"Expected b and A to be on the same device, but found b on "
+ f"{self.device} and A on {A.device} instead."
+ ),
+ )
+
+ check(
+ self.dtype == A.dtype,
+ lambda: (
+ f"Expected b and A to have the same dtype, but found b of type "
+ f"{self.dtype} and A of type {A.dtype} instead."
+ ),
+ )
+
+ check(
+ A.size(-1) == A.size(-2),
+ lambda: (
+ f"A must be batches of square matrices, "
+ f"but they are {A.size(-2)} by {A.size(-1)} matrices"
+ ),
+ )
+
+ check(
+ A.size(-1) == self.size(-2),
+ lambda: (
+ f"Incompatible matrix sizes for {name}: each A "
+ f"matrix is {A.size(-1)} by {A.size(-1)}"
+ f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
+ ),
+ )
+
+
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
@@ -339,6 +386,24 @@
)
+def checkInputsSolver(
+ A: Tensor,
+ B: Tensor,
+ left: bool,
+ f_name: str,
+):
+ squareCheckInputs(A, f_name)
+ checkIsMatrix(B, f_name)
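+    # left=True solves AX = B, so A's rows must match B's rows;
+    # left=False solves XA = B, so A's columns must match B's columns.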
+ check(
+ A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
+ lambda: (
+ f"{f_name}: Incompatible shapes of A and B for the equation "
+ f"{'AX = B' if left else 'XA = B'}"
+ f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
+ ),
+ )
+
+
def checkUplo(uplo: str):
uplo_uppercase = uplo.upper()
assert (
@@ -483,6 +548,66 @@
return U, S, V
+def _linalg_broadcast_batch_dims(
+ arg1: Tensor, arg2: Tensor
+) -> Tuple[List[int], List[int]]:
+ # broadcast the batch dimensions of arg1 and arg2.
+ arg1_batch_sizes = arg1.shape[:-2]
+ arg2_batch_sizes = arg2.shape[:-2]
+ expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
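+    # e.g. arg1: (2, 1, 4, 4) and arg2: (5, 4, 3) have batch shapes (2, 1)
+    # and (5,), which broadcast to (2, 5), so the expanded sizes below are
+    # (2, 5, 4, 4) and (2, 5, 4, 3).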
+
+ arg1_expand_size = list(expand_batch_portion)
+ arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
+
+ arg2_expand_size = list(expand_batch_portion)
+ arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
+ return arg1_expand_size, arg2_expand_size
+
+
+def _linalg_broadcast_batch_dims_name(
+ arg1: Tensor, arg2: Tensor, name: Optional[str]
+) -> Tuple[Tensor, Tensor]:
+    # If there's no name, we skip the input checks
+ if name:
+ linearSolveCheckInputs(arg1, arg2, name)
+
+ arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
+
+ arg1_broadcasted = (
+ arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
+ )
+ arg2_broadcasted = (
+ arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
+ )
+ return arg1_broadcasted, arg2_broadcasted
+
+
+@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
+def linalg_solve_triangular_meta(
+ A: Tensor,
+ B: Tensor,
+ *,
+ upper: bool,
+ left: bool = True,
+ unitriangular: bool = False,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+ if out is None:
+ out = A.new_empty([0])
+ assert isinstance(out, TensorLike)
+ checkInputsSolver(A, B, left, "linalg.solve_triangular")
+ B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
+ avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
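+    # Match the strides the eager kernel would produce: if A can be used
+    # without a copy, out gets default strides; otherwise the result is
+    # column-major (F-contig), which the transposed resize below recreates.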
+ if avoid_copy_A:
+ out = _maybe_resize_out(out, B_.shape)
+ else:
+ # reimplementation of resize_output with result F-contig
+ if _resize_output_check(out, B_.shape):
+ out.resize_(B_.transpose(-2, -1).shape)
+ out.transpose_(-2, -1)
+ return out # type: ignore[return-value]
+
+
# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index ac19d43..4f602d8 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -139,22 +139,29 @@
return _fn
-# TODO: handle tuples of tensors
-def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+# Returns True if resize is necessary
+def _resize_output_check(out: TensorLikeType, shape: ShapeType):
# If the shapes are correct there's nothing to do
if utils.same_shape(out.shape, shape):
- return out
- else:
- if out.numel() != 0:
- msg = (
- f"An output with one or more elements was resized since it had shape {str(out.shape)} "
- "which does not match the required output shape {str(shape)}. "
- "This behavior is deprecated, and in a future PyTorch release outputs will not "
- "be resized unless they have zero elements. "
- "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
- )
- warnings.warn(msg)
+ return False
+ if out.numel() != 0:
+ msg = (
+ f"An output with one or more elements was resized since it had shape {str(out.shape)} "
+            f"which does not match the required output shape {str(shape)}. "
+ "This behavior is deprecated, and in a future PyTorch release outputs will not "
+ "be resized unless they have zero elements. "
+ "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
+ )
+ warnings.warn(msg)
+ return True
+
+
+# TODO: handle tuples of tensors
+def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
+ if _resize_output_check(out, shape):
return out.resize_(shape)
+ else:
+ return out
def _safe_copy_out(