| // Copyright (c) Facebook, Inc. and its affiliates. |
| // All rights reserved. |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| #include <ATen/functorch/BatchRulesHelper.h> |
| |
| namespace at { namespace functorch { |
| |
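| // Each batch rule pairs every output Tensor with an optional<int64_t> that reports which dimension |
| // of that output is the vmap batch dim (nullopt if that output is not batched). |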
| typedef std::tuple<Tensor, optional<int64_t>> oneOutput; |
| typedef std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>> twoOutputs; |
| typedef std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>, Tensor, optional<int64_t>> threeOutputs; |
| typedef std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>, Tensor, optional<int64_t>, Tensor, optional<int64_t>> fourOutputs; |
| |
| // Note [Batching rules for matmul-like operators] |
| // at::matmul doesn't "de-expand" arguments to get better performance (maybe |
| // it should). In the batching rules for matmul-like operators (dot, mv, mm), |
| // we should be careful not to expand any unnecessary dimensions. i.e., if |
| // only one of the two arguments is a BatchedTensor, then we should try |
| // not to expand batch dimensions onto the other arg. |
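| // For example, in dot_batch_rule below, if only A is batched (A_ is [B, N]) while B stays a plain |
| // [N] vector, we call matmul([B, N], [N]) directly rather than expanding B to [B, N] first. |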
| |
| std::tuple<Tensor, optional<int64_t>> dot_batch_rule(const Tensor& A, optional<int64_t> A_bdim, const Tensor& B, optional<int64_t> B_bdim) { |
| TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot"); |
| auto A_ = moveBatchDimToFront(A, A_bdim); |
| auto B_ = moveBatchDimToFront(B, B_bdim); |
| if (A_bdim && B_bdim) { |
| return std::make_tuple(at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0); |
| } else { |
| return std::make_tuple(at::matmul(A_, B_.t()), 0); |
| } |
| } |
| |
| // NB: I wrote it this way because we *might* want it for a future matmul |
| // batch rule that isn't decomposed... |
| // "tv" = tensor @ vector |
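| // Shape legend for the comments below: B = vmap batch dim, O/I = output/contracted dims, "..." = any leading dims |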
| static std::tuple<Tensor, optional<int64_t>> tv_batch_rule( |
| const Tensor& self, optional<int64_t> self_bdim, |
| const Tensor& other, optional<int64_t> other_bdim) { |
| if (self_bdim && other_bdim) { |
| // See Note [Batching rules for matmul-like operators] |
| // B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO |
| auto self_ = at::movedim(self, *self_bdim, -3); |
| auto other_ = moveBatchDimToFront(other, other_bdim); |
| other_ = other_.unsqueeze(-1); |
| auto result = at::matmul(self_, other_).squeeze(-1); |
| auto result_bdim = result.dim() - 2; |
| return std::make_tuple( std::move(result), result_bdim ); |
| } |
| else if (self_bdim && !other_bdim) { |
| // B...OI, I -> B...O |
| auto self_ = moveBatchDimToFront(self, self_bdim); |
| return std::make_tuple( at::matmul(self_, other), 0 ); |
| } |
| else if (!self_bdim && other_bdim) { |
| // ...OI, BI -> ...OI, IB -> ...OB |
| auto other_ = at::movedim(other, *other_bdim, -1); |
| auto result = at::matmul(self, other_); |
| auto result_bdim = result.dim() - 1; |
| return std::make_tuple( std::move(result), result_bdim ); |
| } |
| TORCH_INTERNAL_ASSERT(false, "can't get here"); |
| } |
| |
| static std::tuple<Tensor, optional<int64_t>> mv_batch_rule( |
| const Tensor& self, optional<int64_t> self_bdim, |
| const Tensor& other, optional<int64_t> other_bdim) { |
| auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); |
| auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); |
| TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1, |
| "Shape mismatch: Got incorrect dims for mv(a, b). ", |
| "a has dim ", self_logical_rank, |
| " and b has dim ", other_logical_rank, |
| " but expected them to have dim 2 and dim 1"); |
| return tv_batch_rule(self, self_bdim, other, other_bdim); |
| } |
| |
| static std::tuple<Tensor, optional<int64_t>> mm_batch_rule( |
| const Tensor& self, optional<int64_t> self_bdim, |
| const Tensor& other, optional<int64_t> other_bdim) { |
| auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); |
| auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); |
| TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2, |
| "Shape mismatch: Got incorrect dims for mm(a, b). ", |
| "a has dim ", self_logical_rank, |
| " and b has dim ", other_logical_rank, |
| " but expected them to have dim 2 and dim 2"); |
| auto self_ = moveBatchDimToFront(self, self_bdim); |
| auto other_ = moveBatchDimToFront(other, other_bdim); |
| return std::make_tuple( at::matmul(self_, other_), 0 ); |
| } |
| |
| static std::tuple<Tensor, optional<int64_t>> bmm_batch_rule( |
| const Tensor& self, optional<int64_t> self_bdim, |
| const Tensor& other, optional<int64_t> other_bdim) { |
| auto self_logical_rank = rankWithoutBatchDim(self, self_bdim); |
| auto other_logical_rank = rankWithoutBatchDim(other, other_bdim); |
| TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3, |
| "Shape mismatch: Got incorrect dims for bmm(a, b). ", |
| "a has dim ", self_logical_rank, |
| " and b has dim ", other_logical_rank, |
| " but expected them to have dim 3 and dim 3"); |
| auto self_ = moveBatchDimToFront(self, self_bdim); |
| auto other_ = moveBatchDimToFront(other, other_bdim); |
| return std::make_tuple( at::matmul(self_, other_), 0 ); |
| } |
| |
| // AFAICT, nothing here can be batched. So we decompose :) |
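| // addmv/addbmm/baddbmm below all compute beta * input + alpha * <product>, skipping the scale when |
| // alpha == 1 and the add when beta == 0. |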
| Tensor addmv_decomp( |
| const Tensor& input, const Tensor& mat, const Tensor& vec, const Scalar& beta, const Scalar& alpha) { |
| Tensor out = at::mv(mat, vec); |
| if (!alpha.equal(1)) { |
| out = alpha * out; |
| } |
| if (!beta.equal(0)) { |
| out = beta * input + out; |
| } |
| return out; |
| } |
| |
| Tensor addbmm_decomp( |
| const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { |
| Tensor out = at::bmm(batch1, batch2).sum(0); |
| if (!alpha.equal(1)) { |
| out = alpha * out; |
| } |
| if (!beta.equal(0)) { |
| out = beta * input + out; |
| } |
| return out; |
| } |
| |
| Tensor baddbmm_decomp( |
| const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { |
| Tensor out = at::bmm(batch1, batch2); |
| if (!alpha.equal(1)) { |
| out = alpha * out; |
| } |
| if (!beta.equal(0)) { |
| out = beta * input + out; |
| } |
| return out; |
| } |
| |
| Tensor linear_decomp( |
| const Tensor& input, const Tensor& weight, |
| const c10::optional<Tensor>& bias_opt) { |
| auto result = input.matmul(weight.t()); |
| if (bias_opt) { |
| // NB: It's too much work to figure out how to actually fuse the bias so |
| // we're not going to. |
| // TODO: if the result isn't batched but bias is, then we need the out-of-place |
| // add below. Otherwise, it could just be done in-place. We should write a more |
| // nuanced decomposition rule. |
| return result.add(*bias_opt); |
| } |
| return result; |
| } |
| |
| Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { |
| // Decomposition that is probably not very fast... |
| return at::add(self * beta, at::mm(mat1, mat2), alpha); |
| } |
| |
| void _linalg_check_errors_batch_rule(const Tensor& info, optional<int64_t> info_bdim, c10::string_view api_name, bool is_matrix) { |
| auto info_ = moveBatchDimToFront(info, info_bdim); |
| // Under vmap the info tensor always represents a batch of matrices, so we pass is_matrix = false below. |
| at::_linalg_check_errors(info_, api_name, false); |
| } |
| |
| std::tuple<Tensor, c10::optional<int64_t>> |
| householder_product_batch_rule(const Tensor &input, c10::optional<int64_t> input_bdim, |
| const Tensor &tau, c10::optional<int64_t> tau_bdim) |
| { |
| auto input_ = moveBatchDimToFront(input, input_bdim); |
| auto tau_ = moveBatchDimToFront(tau, tau_bdim); |
| |
| auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim); |
| |
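| // If only one of input/tau is batched, expand a batch dim of size batch_size onto the other so that |
| // the leading dims line up for linalg_householder_product. |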
| input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size); |
| tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size); |
| return std::make_tuple(at::linalg_householder_product(input_, tau_), 0); |
| } |
| |
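| // Generic batch rule helper for linalg ops that take a single (batch of) matrix: check that the input |
| // has at least 2 logical dims, move the vmap dim to the front, call the op, and report bdim 0 for every output. |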
| template <char const *op_name, typename A, A a, typename C> |
| struct LinalgCheckMatrixUnaryRuleHelper; |
| |
| template <char const *op_name, typename F, F Func, typename A, typename... T> |
| struct LinalgCheckMatrixUnaryRuleHelper<op_name, F, Func, typelist<A, T...>> { |
| static inline Tensor check_and_reshape_input(const Tensor& tensor, optional<int64_t> batch_dim) { |
| TORCH_CHECK(rankWithoutBatchDim(tensor, batch_dim) >= 2, op_name, ": The input tensor A must have at least 2 dimensions."); |
| return moveBatchDimToFront(tensor, batch_dim); |
| } |
| |
| static oneOutput apply_one( |
| const Tensor& tensor, |
| optional<int64_t> batch_dim, |
| T... extra_args) { |
| const auto tensor_ = check_and_reshape_input(tensor, batch_dim); |
| return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0); |
| } |
| |
| static twoOutputs apply_two( |
| const Tensor& tensor, |
| optional<int64_t> batch_dim, |
| T... extra_args) { |
| const auto tensor_ = check_and_reshape_input(tensor, batch_dim); |
| const auto res = Func(tensor_, std::forward<T>(extra_args)...); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0); |
| } |
| |
| static threeOutputs apply_three( |
| const Tensor& tensor, |
| optional<int64_t> batch_dim, |
| T... extra_args) { |
| const auto tensor_ = check_and_reshape_input(tensor, batch_dim); |
| const auto res = Func(tensor_, std::forward<T>(extra_args)...); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0); |
| } |
| |
| static fourOutputs apply_four( |
| const Tensor& tensor, |
| optional<int64_t> batch_dim, |
| T... extra_args) { |
| const auto tensor_ = check_and_reshape_input(tensor, batch_dim); |
| const auto res = Func(tensor_, std::forward<T>(extra_args)...); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0); |
| } |
| }; |
| |
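| // Same idea for linalg ops that take two matrix arguments: both inputs are checked and their batch dims |
| // aligned via _binary_pointwise_helper (without type promotion) before calling the op. |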
| template <char const *op_name, typename A, A a, typename C> |
| struct LinalgCheckMatrixBinaryRuleHelper; |
| |
| template <char const *op_name, typename F, F Func, typename A, typename B, typename... T> |
| struct LinalgCheckMatrixBinaryRuleHelper<op_name, F, Func, typelist<A, B, T...>> { |
| static inline std::tuple<Tensor, Tensor> check_inputs_and_reshape_inputs( |
| const Tensor& first, optional<int64_t> first_bdim, |
| const Tensor& second, optional<int64_t> second_bdim) { |
| TORCH_CHECK(rankWithoutBatchDim(first, first_bdim) >= 2, |
| op_name, ": The input tensor A must have at least 2 dimensions."); |
| TORCH_CHECK(rankWithoutBatchDim(second, second_bdim) >= 2, |
| op_name, ": The input tensor B must have at least 2 dimensions."); |
| return _binary_pointwise_helper(first, first_bdim, second, second_bdim, false); |
| } |
| |
| static oneOutput apply_one( |
| const Tensor& first, optional<int64_t> first_bdim, |
| const Tensor& second, optional<int64_t> second_bdim, |
| T... extra_args) { |
| const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim); |
| const auto tensor_ = std::get<0>(tensor_other); |
| const auto other_ = std::get<1>(tensor_other); |
| return std::make_tuple(Func(tensor_, other_, std::forward<T>(extra_args)...), 0); |
| } |
| |
| static twoOutputs apply_two( |
| const Tensor& first, optional<int64_t> first_bdim, |
| const Tensor& second, optional<int64_t> second_bdim, |
| T... extra_args) { |
| const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim); |
| const auto tensor_ = std::get<0>(tensor_other); |
| const auto other_ = std::get<1>(tensor_other); |
| const auto res = Func(tensor_, other_, std::forward<T>(extra_args)...); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0); |
| } |
| }; |
| |
| static void expect_at_least_rank( |
| const Tensor& tensor, |
| optional<int64_t> tensor_bdim, |
| int64_t expected_rank, |
| const char* name) { |
| auto rank = rankWithoutBatchDim(tensor, tensor_bdim); |
| TORCH_CHECK(rank >= expected_rank, |
| name, " should have at least ", expected_rank, " dimensions, but has ", |
| rank, " dimensions instead."); |
| } |
| |
| oneOutput linalg_lu_solve_batch_rule( |
| const Tensor& LU, optional<int64_t> LU_bdim, |
| const Tensor& pivots, optional<int64_t> pivots_bdim, |
| const Tensor& B, optional<int64_t> B_bdim, |
| bool left, bool adjoint) { |
| const auto LU_min_rank = 2; |
| const auto pivots_min_rank = 1; |
| const auto B_min_rank = 2; |
| |
| expect_at_least_rank(LU, LU_bdim, LU_min_rank, "LU"); |
| expect_at_least_rank(pivots, pivots_bdim, pivots_min_rank, "pivots"); |
| expect_at_least_rank(B, B_bdim, B_min_rank, "B"); |
| |
| auto LU_ = moveBatchDimToFront(LU, LU_bdim); |
| auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim); |
| auto B_ = moveBatchDimToFront(B, B_bdim); |
| |
| // The leading dims of LU (all but the last 2) and of pivots (all but the last 1) must match, |
| // so if only one of them is being vmapped over, we must expand the batch dimension onto the other. |
| if (LU_bdim.has_value() ^ pivots_bdim.has_value()) { |
| auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim); |
| LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size); |
| pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size); |
| pivots_bdim = 0; |
| LU_bdim = 0; |
| } |
| |
| // Now, {LU, pivots} and B's first dimensions are allowed to broadcast. |
| // The rest of the logic handles that. |
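| // Each operand is padded with leading size-1 logical dims until they all carry the same number of |
| // batch dims; linalg_lu_solve then broadcasts those batch dims together. |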
| const auto LU_num_batch_dims = rankWithoutBatchDim(LU_, LU_bdim) - LU_min_rank; |
| const auto pivots_num_batch_dims = rankWithoutBatchDim(pivots_, pivots_bdim) - pivots_min_rank; |
| const auto B_num_batch_dims = rankWithoutBatchDim(B_, B_bdim) - B_min_rank; |
| const auto max_num_batch_dims = std::max(std::max(LU_num_batch_dims, pivots_num_batch_dims), B_num_batch_dims); |
| |
| LU_ = maybePadToLogicalRank(LU_, LU_bdim, max_num_batch_dims + LU_min_rank); |
| pivots_ = maybePadToLogicalRank(pivots_, pivots_bdim, max_num_batch_dims + pivots_min_rank); |
| B_ = maybePadToLogicalRank(B_, B_bdim, max_num_batch_dims + B_min_rank); |
| |
| const auto result = at::linalg_lu_solve(LU_, pivots_, B_, left, adjoint); |
| return std::make_tuple(result, 0); |
| } |
| |
| oneOutput cholesky_solve_batch_rule( |
| const Tensor& self, c10::optional<int64_t> self_bdim, |
| const Tensor& A, c10::optional<int64_t> A_bdim, |
| bool upper) { |
| TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, |
| "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); |
| TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2, |
| "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead"); |
| |
| const auto tensor_other = _binary_pointwise_helper(self, self_bdim, A, A_bdim, /*do_type_promotion=*/false); |
| const auto tensor_ = std::get<0>(tensor_other); |
| const auto other_ = std::get<1>(tensor_other); |
| return std::make_tuple(at::cholesky_solve(tensor_, other_, upper), 0); |
| } |
| |
| threeOutputs linalg_lu_factor_ex_batch_rule( |
| const Tensor& A, c10::optional<int64_t> A_bdim, bool pivot, bool check_errors) { |
| TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2, "torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead"); |
| const auto A_ = moveBatchDimToFront(A, A_bdim); |
| const auto res = at::linalg_lu_factor_ex(A_, pivot, check_errors); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0); |
| } |
| |
| oneOutput matrix_exp_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim) { |
| TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.matrix_exp: The input tensor A must have at least 2 dimensions."); |
| const auto self_ = moveBatchDimToFront(self, self_bdim).contiguous(); // .contiguous() works around what seems to be a bug when matrix_exp is given a non-contiguous input |
| return std::make_tuple(at::matrix_exp(self_), 0); |
| } |
| |
| fourOutputs solve_ex_batch_rule( |
| const Tensor& A, optional<int64_t> A_bdim, |
| const Tensor& B, optional<int64_t> B_bdim, |
| bool left, bool check_errors) { |
| auto batch_size = get_bdim_size2(A, A_bdim, B, B_bdim); |
| const auto A_logical_rank = rankWithoutBatchDim(A, A_bdim); |
| const auto B_logical_rank = rankWithoutBatchDim(B, B_bdim); |
| const auto max_logical_rank = std::max(A_logical_rank, B_logical_rank); |
| |
| TORCH_CHECK(A_logical_rank >= 2, |
| "linalg.solve: The input tensor A must have at least 2 dimensions."); |
| |
| int b_logical_rank = max_logical_rank; |
| if (A_logical_rank > B_logical_rank) { // vector case: B was a vector or batched vector |
| // the "2 dimensions" wording is not accurate for this case, but it matches linalg's own error message |
| TORCH_CHECK(B_logical_rank >= 1, "linalg.solve: The input tensor B must have at least 2 dimensions."); |
| b_logical_rank = max_logical_rank - 1; |
| } else { // matrix case: A and B are both matrices or batches of matrices |
| TORCH_CHECK(B_logical_rank >= 2, "linalg.solve: The input tensor B must have at least 2 dimensions."); |
| } |
| |
| // This is basically the binary pointwise helper, except that if B came in as a vector we must pad it to one dim fewer than A. |
| auto A_ = moveBatchDimToFront(A, A_bdim); |
| auto B_ = moveBatchDimToFront(B, B_bdim); |
| A_ = maybePadToLogicalRank(A_, A_bdim, max_logical_rank); |
| B_ = maybePadToLogicalRank(B_, B_bdim, b_logical_rank); |
| |
| A_ = ensure_has_bdim(A_, A_bdim.has_value(), batch_size); |
| B_ = ensure_has_bdim(B_, B_bdim.has_value(), batch_size); |
| |
| // NOTE [ solve_ex Batch Rule Contiguity ] |
| // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on |
| // A as a BatchedTensor since it might have been saved by autograd (specifically by the jvp), and the autograd behavior |
| // differs based on whether or not the optimized path was taken. |
| const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous(); |
| if (batched_A_was_contiguous && !A.is_complex()) { |
| A_ = A_.contiguous(); |
| } |
| const auto res = _linalg_solve_ex(A_, B_, left, check_errors); |
| return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0); |
| } |
| |
| oneOutput cross_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim, |
| const Tensor& other, c10::optional<int64_t> other_bdim, const int64_t dim) { |
| const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim); |
| |
| const auto self_other_bundled = _binary_pointwise_helper(self, self_bdim, other, other_bdim, false); |
| |
| // This may be working around a bug, but it doesn't incur an extra perf hit because the input would be expanded by cross's broadcast anyway. |
| const auto self_ = ensure_has_bdim(std::get<0>(self_other_bundled), self_bdim.has_value(), batch_size); |
| // needed because of the same bug, since other_.conj() is used as an input to cross in the backward formula |
| const auto other_ = ensure_has_bdim(std::get<1>(self_other_bundled), other_bdim.has_value(), batch_size); |
| |
| // translate the logical `dim` into a physical dim on self_, which now carries a leading batch dim |
| const auto dim_ = getPhysicalDim(self_, true, dim); |
| |
| return std::make_tuple(linalg_cross(self_, other_, dim_), 0); |
| } |
| |
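| // linalg_lstsq's optional outputs come back as empty 1-D tensors when they aren't computed; those |
| // carry no batch dim, so we report nullopt for them and bdim 0 otherwise. |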
| c10::optional<int64_t> batch_dim_if_not_empty(const Tensor& t) { |
| if (t.dim() == 1 && t.size(0) == 0) { |
| return c10::optional<int64_t>(); |
| } |
| return c10::optional<int64_t>(0); |
| } |
| |
| fourOutputs linalg_lstsq_batch_rule( |
| const Tensor& self, c10::optional<int64_t> self_bdim, const Tensor& b, c10::optional<int64_t> b_bdim, |
| c10::optional<double> rcond, c10::optional<c10::string_view> driver) { |
| TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions."); |
| TORCH_CHECK(rankWithoutBatchDim(b, b_bdim) >= 1, "torch.linalg.lstsq: other must have at least 1 dimension."); |
| |
| const auto batch_size = get_bdim_size2(self, self_bdim, b, b_bdim); |
| const auto tensor_other = _binary_pointwise_helper(self, self_bdim, b, b_bdim, /*do_type_promotion=*/false); |
| |
| // Because of the ambiguity with the vector case, lstsq can broadcast [1, 2] -> [batch_size, 2] but not [2] -> [batch_size, 2], |
| // so we could either unsqueeze when there is no bdim or just call ensure_has_bdim (we do the latter). |
| const auto self_ = ensure_has_bdim(std::get<0>(tensor_other), self_bdim.has_value(), batch_size); |
| const auto b_ = ensure_has_bdim(std::get<1>(tensor_other), b_bdim.has_value(), batch_size); |
| |
| Tensor res, res_1, res_2, res_3; |
| std::tie(res, res_1, res_2, res_3) = at::linalg_lstsq(self_, b_, rcond, driver); |
| |
| // Every output except the 0th is only sometimes computed; when it isn't, it comes back as an empty tensor without a bdim. |
| const auto res_1_bdim = batch_dim_if_not_empty(res_1); |
| const auto res_2_bdim = batch_dim_if_not_empty(res_2); |
| const auto res_3_bdim = batch_dim_if_not_empty(res_3); |
| return std::make_tuple(res, 0, res_1, res_1_bdim, res_2, res_2_bdim, res_3, res_3_bdim); |
| } |
| |
| template<typename F> |
| std::tuple<Tensor, c10::optional<int64_t>> |
| atol_rtol_tensor_batch_rule( |
| F Func, const Tensor& input, optional<int64_t> input_bdim, |
| const optional<Tensor>& atol, const optional<int64_t> atol_bdim, |
| const optional<Tensor>& rtol, const optional<int64_t> rtol_bdim, bool hermitian, char const *op_name) { |
| auto input_logical_rank = rankWithoutBatchDim(input, input_bdim); |
| |
| TORCH_CHECK(input_logical_rank >= 2, |
| op_name, ": The input tensor input must have at least 2 dimensions."); |
| |
| // atol and rtol's dims must be broadcastable to the number of batch dims of input |
| // which is input's dim - 2 (input represents a batch of matrices, so 2 is for the matrix dimensions) |
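| // e.g. an input with logical shape [5, 3, 3] has one logical batch dim, so a per-matrix atol would have logical shape [5] |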
| const auto input_logical_num_bdims = input_logical_rank - 2; |
| const int64_t atol_logical_num_bdims = atol.has_value() ? rankWithoutBatchDim(*atol, atol_bdim) : 0; |
| const int64_t rtol_logical_num_bdims = rtol.has_value() ? rankWithoutBatchDim(*rtol, rtol_bdim) : 0; |
| const auto max_logical_bdims = std::max({input_logical_num_bdims, atol_logical_num_bdims, rtol_logical_num_bdims}); |
| |
| auto input_ = moveBatchDimToFront(input, input_bdim); |
| auto atol_ = atol.has_value() ? moveBatchDimToFront(*atol, atol_bdim) : atol; |
| auto rtol_ = rtol.has_value() ? moveBatchDimToFront(*rtol, rtol_bdim) : rtol; |
| |
| // pad all inputs to have the same number of (non-vmap) batch dimensions |
| input_ = maybePadToLogicalRank(input_, input_bdim, max_logical_bdims + 2); |
| atol_ = atol_.has_value() ? maybePadToLogicalRank(*atol_, atol_bdim, max_logical_bdims) : atol_; |
| rtol_ = rtol_.has_value() ? maybePadToLogicalRank(*rtol_, rtol_bdim, max_logical_bdims) : rtol_; |
| |
| return std::make_tuple(Func(input_, atol_, rtol_, hermitian), 0); |
| } |
| |
| std::tuple<Tensor, c10::optional<int64_t>> |
| matrix_rank_atol_rtol_tensor_batch_rule( |
| const Tensor& input, c10::optional<int64_t> input_bdim, const optional<Tensor>& atol, |
| const c10::optional<int64_t> atol_bdim, const optional<Tensor>& rtol, |
| const c10::optional<int64_t> rtol_bdim, bool hermitian) { |
| return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_matrix_rank, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "torch.linalg.matrix_rank"); |
| } |
| |
| std::tuple<Tensor, c10::optional<int64_t>> |
| pinv_batch_rule( |
| const Tensor& input, c10::optional<int64_t> input_bdim, const optional<Tensor>& atol, |
| const c10::optional<int64_t> atol_bdim, const optional<Tensor>& rtol, |
| const c10::optional<int64_t> rtol_bdim, bool hermitian) { |
| return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv"); |
| } |
| |
| std::tuple<Tensor, optional<int64_t>> |
| matrix_rank_atol_rtol_float_batch_rule( |
| const Tensor& input, optional<int64_t> input_bdim, optional<double> atol, optional<double> rtol, bool hermitian) { |
| TORCH_CHECK(rankWithoutBatchDim(input, input_bdim) >= 2, |
| "torch.linalg.matrix_rank: The input tensor input must have at least 2 dimensions."); |
| return std::make_tuple(linalg_matrix_rank(moveBatchDimToFront(input, input_bdim), atol, rtol, hermitian), 0); |
| } |
| |
| #define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\ |
| LinalgCheckMatrixUnaryRuleHelper<\ |
| func_string_##fn,\ |
| decltype(&ATEN_FN(fn)),\ |
| &ATEN_FN(fn),\ |
| c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out) |
| |
| #define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, num_out) SINGLE_ARG(\ |
| LinalgCheckMatrixUnaryRuleHelper<\ |
| func_string_##fn_##overload,\ |
| decltype(&ATEN_FN2(fn, overload)),\ |
| &ATEN_FN2(fn, overload),\ |
| c10::guts::function_traits<decltype(ATEN_FN2(fn, overload))>::parameter_types>::apply_##num_out) |
| |
| #define LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\ |
| LinalgCheckMatrixBinaryRuleHelper<\ |
| func_string_##fn,\ |
| decltype(&ATEN_FN(fn)),\ |
| &ATEN_FN(fn),\ |
| c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out) |
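| |
| // For example (roughly), LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(cholesky, one) names |
| // LinalgCheckMatrixUnaryRuleHelper<func_string_cholesky, ...>::apply_one specialized for at::cholesky, |
| // which the *_ONE_OUT macros below register as the vmap batch rule via VMAP_SUPPORT. |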
| |
| |
| // Define string constants with the function names. These will be used as template parameters. |
| // C++ doesn't let us use string literals as template parameters, so we have to declare them as consts first. |
| // What is going on with these macros? |
| // - clang-5 seems to require the constexpr |
| // - windows compiles with or without the constexpr, but the constexpr causes test problems |
| // - as a result we have some macro guards. |
| #if defined(_MSC_VER) |
| #define LINALG_STRING_CONST(fn, op_name) \ |
| const char func_string_##fn[] = #op_name;\ |
| |
| #define LINALG_STRING_CONST2(fn, overload, op_name) \ |
| const char func_string_##fn_##overload[] = #op_name;\ |
| |
| #else |
| #define LINALG_STRING_CONST(fn, op_name) \ |
| constexpr const char func_string_##fn[] = #op_name;\ |
| |
| #define LINALG_STRING_CONST2(fn, overload, op_name) \ |
| constexpr const char func_string_##fn_##overload[] = #op_name;\ |
| |
| #endif |
| |
| #define LINALG_CHECK_MATRIX_UNARY_ONE_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, one));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_UNARY_ONE_OUT2(fn, overload, op_name) \ |
| LINALG_STRING_CONST2(fn, overload, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT2(fn, overload, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, one));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_UNARY_TWO_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, two));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_UNARY_THREE_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, three));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, four));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_BINARY_ONE_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, one));\ |
| } |
| |
| #define LINALG_CHECK_MATRIX_BINARY_TWO_OUT(fn, op_name) \ |
| LINALG_STRING_CONST(fn, op_name);\ |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\ |
| VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, two));\ |
| } |
| |
| // These need to be outside the TORCH_LIBRARY_IMPL block below: the string constants must be declared |
| // at namespace scope to be usable as template params. |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky, cholesky); |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky_inverse, cholesky_inverse); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_cholesky_ex, linalg.cholesky); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_eig, linalg.eig); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_inv_ex, linalg.inv_ex); |
| LINALG_CHECK_MATRIX_UNARY_THREE_OUT(linalg_ldl_factor_ex, torch.linalg.ldl_factor_ex); |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT(linalg_matrix_power, linalg.matrix_power); |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT(linalg_pinv, linalg.pinv); |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT2(linalg_pinv, atol_rtol_float, linalg.pinv); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_qr, linalg.qr); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_slogdet, linalg.slogdet); |
| LINALG_CHECK_MATRIX_BINARY_ONE_OUT(linalg_solve_triangular, linalg.solve_triangular); |
| |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(geqrf, geqrf); |
| LINALG_CHECK_MATRIX_UNARY_ONE_OUT(logdet, logdet); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(symeig, symeig); |
| LINALG_CHECK_MATRIX_BINARY_TWO_OUT(triangular_solve, triangular_solve); |
| LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_det, linalg.det); |
| LINALG_CHECK_MATRIX_UNARY_TWO_OUT(_linalg_eigh, linalg.eigh); |
| LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet); |
| LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd); |
| |
| TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { |
| VMAP_SUPPORT(bmm, bmm_batch_rule); |
| m.impl("addmv", addmv_decomp); |
| m.impl("addmm", addmm_decomp); |
| m.impl("addbmm", addbmm_decomp); |
| m.impl("baddbmm", baddbmm_decomp); |
| VMAP_SUPPORT(dot, dot_batch_rule); |
| VMAP_SUPPORT(mv, mv_batch_rule); |
| VMAP_SUPPORT(mm, mm_batch_rule); |
| m.impl("linear", linear_decomp); |
| VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule); |
| VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule); |
| VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule); // custom dim error |
| VMAP_SUPPORT(linalg_lstsq, linalg_lstsq_batch_rule); // custom errors and sometimes empty return |
| VMAP_SUPPORT(linalg_lu_factor_ex, linalg_lu_factor_ex_batch_rule); |
| VMAP_SUPPORT(linalg_matrix_exp, matrix_exp_batch_rule); |
| VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule); |
| VMAP_SUPPORT(linalg_cross, cross_batch_rule); |
| VMAP_SUPPORT2(linalg_matrix_rank, atol_rtol_tensor, matrix_rank_atol_rtol_tensor_batch_rule); |
| VMAP_SUPPORT2(linalg_matrix_rank, atol_rtol_float, matrix_rank_atol_rtol_float_batch_rule); |
| VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule); |
| |
| VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule); |
| } |
| }} // namespace at::functorch |