// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/WrapDimUtils.h>

namespace at { namespace functorch {
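
// Moves the batch dimension (if present) to the front of the tensor.
// Illustrative example: a tensor of physical shape [3, B, 5] with
// maybe_batch_dim = 1 is returned with shape [B, 3, 5]; if there is no batch
// dim, or it is already at dim 0, the tensor is returned unchanged.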
Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  if (maybe_batch_dim.value() == 0) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}
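
// Returns the logical rank of the tensor: tensor.dim() minus one if a batch
// dim is present. E.g. a tensor of physical shape [B, 3, 5] with a batch dim
// has logical rank 2.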
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}
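
// Returns the number of elements in one logical slice of the tensor:
// tensor.numel() divided by the batch size when a batch dim is present.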
int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim) {
    return tensor.numel();
  }
  return tensor.numel() / tensor.size(*maybe_batch_dim);
}
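
// Returns new_val if maybe_empty holds a value, otherwise nullopt. Typically
// used by batch rules to report the output's batch dim position only when the
// corresponding input actually had a batch dim.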
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val) {
  if (maybe_empty.has_value()) {
    return new_val;
  }
  return nullopt;
}
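
// Converts a user-facing (logical) dim index into a physical dim index for a
// tensor whose batch dim, if any, is at the front. Illustrative example: for
// a tensor of physical shape [B, 3, 5] with has_batch_dim = true,
// logical_dim = -1 maps to physical dim 2.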
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  // NB: assumes the batch dim is at the front of the tensor
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
  if (has_batch_dim) {
    return wrapped_dim + 1;
  }
  return wrapped_dim;
}
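
// Vectorized version of getPhysicalDim: converts each logical dim index in
// logical_dims into the corresponding physical dim index.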
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
  // NB: assumes the batch dim is at the front of the tensor
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  VmapDimVector result;
  result.reserve(logical_dims.size());
  for (auto d : logical_dims) {
    if (has_batch_dim) {
      result.push_back(maybe_wrap_dim(d, rank) + 1);
    } else {
      result.push_back(maybe_wrap_dim(d, rank));
    }
  }
  return result;
}
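
// Pads a batched tensor with size-1 dims directly after the batch dim until
// its logical rank reaches logical_rank, so it broadcasts like its unbatched
// counterpart. Illustrative example: a tensor of shape [B, 3] padded to
// logical rank 3 becomes [B, 1, 1, 3]. Unbatched tensors are returned as-is.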
Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view(new_sizes);
}
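
// Validates the vmap `randomness` flag before a random operation runs:
// 'error' always throws, and 'same' is rejected when any input tensor is
// batched, since that combination is not supported.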
void check_randomness(RandomnessType randomness, bool any_tensor_batched) {
  TORCH_CHECK(
      randomness != RandomnessType::Error,
      "vmap: called random operation while in randomness error mode. Please either use the "
      "'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap"
  );
  TORCH_CHECK(
      !(randomness == RandomnessType::Same && any_tensor_batched),
      "Vmap does not currently support same randomness with a batched tensor input. ",
      "Please file an issue with functorch"
  );
}
void check_randomness(RandomnessType randomness) {
  // Ops that take no tensor arguments can never trip the 'same'-with-batched-input check.
  check_randomness(randomness, false);
}
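
// Folds dim `src` into dim `dst` (indexed in the output, which has one fewer
// dim), with the src extent becoming the outer part of the merged dim.
// Illustrative example: for x of shape [B, 3, 5], reshape_dim_into(0, 0, x)
// returns a tensor of shape [B * 3, 5].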
Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x) {
  auto x_dim = x.dim();
  src = maybe_wrap_dim(src, x_dim);
  dst = maybe_wrap_dim(dst, x_dim - 1); // Returned Tensor has one fewer dim
  VmapDimVector new_shape(x.sizes().begin(), x.sizes().end());
  new_shape.erase(new_shape.begin() + src);
  new_shape[dst] *= x.sizes()[src];
  return at::reshape(x.movedim(src, dst), new_shape);
}
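
// Inverse of reshape_dim_into: splits dim `src` into two dims of sizes
// [size1, shape[src] / size1] at the same position. Illustrative example:
// for x of shape [6, 5], reshape_dim_outof(0, 2, x) returns shape [2, 3, 5].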
Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  VmapDimVector shape(x.sizes().begin(), x.sizes().end());
  TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  int64_t size2 = shape[src] / size1;
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape(x, shape);
}
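
// Raises the standard error for in-place ops that vmap cannot support:
// `self` is not being vmapped over but another argument is, so the result
// has more elements than `self` can hold.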
void vmapIncompatibleInplaceError(const char* schema_name) {
  TORCH_CHECK(false,
      "vmap: ", schema_name, "(self, *extra_args) is not possible because ",
      "there exists a Tensor `other` in extra_args that has more elements ",
      "than `self`. This happened due to `other` being vmapped over but ",
      "`self` not being vmapped over in a vmap. ",
      "Please try to use out-of-place operators instead of ", schema_name, ". ",
      "If said operator is being called inside the PyTorch framework, ",
      "please file a bug report instead.");
}
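
// Casts both operands to the dtype that eager mode would produce for a
// (0-d scalar tensor, n-d tensor) pair. Indexing logical_scalar_tensor[0]
// peels off the batch dim so result_type sees a genuine 0-d tensor.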
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}
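
// Prepares the operands of a binary pointwise op for a batched computation:
// moves batch dims to the front, optionally applies the scalar-style type
// promotion above, and pads batched operands with size-1 dims so the two
// tensors broadcast against each other correctly.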
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
    const Tensor& other, optional<int64_t> other_batch_dim,
    bool do_type_promotion) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  if (do_type_promotion) {
    auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
    auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
    if (tensor_is_logical_scalar && !other_is_logical_scalar) {
      handleScalarTypePromotion(tensor_, other_);
    }
    if (other_is_logical_scalar && !tensor_is_logical_scalar) {
      handleScalarTypePromotion(other_, tensor_);
    }
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}

}} // namespace at::functorch