// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/WrapDimUtils.h>

namespace at { namespace functorch {

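// Returns `tensor` with its batch dimension (if any) moved to dim 0, so that
// batch rules can treat dims 1..N as the per-example (logical) shape.
// Illustrative example: a tensor of shape [3, 5, B] with maybe_batch_dim == 2
// comes back as [B, 3, 5]; nullopt or a batch dim already at 0 returns the
// input unchanged.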
Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  if (maybe_batch_dim.value() == 0) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}

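// Logical (per-example) rank: the physical rank minus one if a batch dim is
// present. Illustrative example: a [B, 3, 5] tensor with a batch dim has
// logical rank 2.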
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}

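// Logical (per-example) number of elements: the physical numel divided by the
// size of the batch dim, if one is present.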
int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim) {
    return tensor.numel();
  }
  return tensor.numel() / tensor.size(*maybe_batch_dim);
}

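// Returns `new_val` if `maybe_empty` holds a value, otherwise nullopt.
// Typically used by batch rules to report where an output's batch dim ends up
// only when the corresponding input actually had one.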
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val) {
  if (maybe_empty.has_value()) {
    return new_val;
  }
  return nullopt;
}

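// Maps a user-facing (logical) dim index to a physical dim index on a tensor
// whose batch dim, if present, sits at the front. Negative dims are wrapped
// against the logical rank first. Illustrative example: for a [B, 3, 5]
// tensor with has_batch_dim == true, logical dim -1 maps to physical dim 2.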
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  // NB: assumes the batch dim is at the front of the tensor
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
  if (has_batch_dim) {
    return wrapped_dim + 1;
  }
  return wrapped_dim;
}

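// Vector version of getPhysicalDim: maps each logical dim in `logical_dims`
// to its physical counterpart, again assuming any batch dim is at the front.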
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
  // NB: assumes the batch dim is at the front of the tensor
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  VmapDimVector result;
  result.reserve(logical_dims.size());
  for (auto d : logical_dims) {
    if (has_batch_dim) {
      result.push_back(maybe_wrap_dim(d, rank) + 1);
    } else {
      result.push_back(maybe_wrap_dim(d, rank));
    }
  }
  return result;
}

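// Unsqueezes size-1 dims right after the batch dim until the tensor's logical
// rank reaches `logical_rank`, so it broadcasts correctly against the other
// operand. Tensors without a batch dim are returned unchanged (regular
// broadcasting already handles them). Illustrative example: a [B, 3] tensor
// padded to logical rank 3 becomes [B, 1, 1, 3].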
Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view(new_sizes);
}

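// Validates the vmap randomness setting for a random operation: 'error' mode
// always throws, and 'same' mode is rejected when any input is batched, since
// shared randomness over per-example inputs is not currently supported.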
void check_randomness(RandomnessType randomness, bool any_tensor_batched) {
  TORCH_CHECK(
    randomness != RandomnessType::Error,
    "vmap: called random operation while in randomness error mode. Please either use the "
    "'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap"
  );

  TORCH_CHECK(
    !(randomness == RandomnessType::Same && any_tensor_batched),
    "Vmap does not currently support same randomness with a batched tensor input. ",
    "Please file an issue with functorch"
  );
}

void check_randomness(RandomnessType randomness) {
  // Ops that take no tensor inputs have no batched arguments, so they can
  // never trigger the 'same'-randomness-with-batched-input error above.
  check_randomness(randomness, false);
}

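// Merges dim `src` into dim `dst`, producing a tensor with one fewer dim.
// Batch rules commonly use this to fold the batch dim into another dim.
// Illustrative example: reshape_dim_into(0, 0, x) on a [B, 3, 5] tensor
// returns a [B * 3, 5] tensor.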
Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x) {
  auto x_dim = x.dim();
  src = maybe_wrap_dim(src, x_dim);
  dst = maybe_wrap_dim(dst, x_dim - 1);  // Returned Tensor has one fewer dim
  VmapDimVector new_shape(x.sizes().begin(), x.sizes().end());
  new_shape.erase(new_shape.begin() + src);
  new_shape[dst] *= x.sizes()[src];
  return at::reshape(x.movedim(src, dst), new_shape);
}

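// Inverse of reshape_dim_into: splits dim `src` into two dims of sizes
// (size1, size(src) / size1). Illustrative example:
// reshape_dim_outof(0, B, x) on a [B * 3, 5] tensor returns [B, 3, 5].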
Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  VmapDimVector shape(x.sizes().begin(), x.sizes().end());
  TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  int64_t size2 = shape[src] / size1;
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape(x, shape);
}

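// Error raised by in-place batch rules when `self` is unbatched but another
// argument is batched, so the in-place op would need to write more elements
// than `self` holds.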
void vmapIncompatibleInplaceError(const char* schema_name) {
  TORCH_CHECK(false,
    "vmap: ", schema_name, "(self, *extra_args) is not possible because ",
    "there exists a Tensor `other` in extra_args that has more elements ",
    "than `self`. This happened due to `other` being vmapped over but ",
    "`self` not being vmapped over in a vmap. ",
    "Please try to use out-of-place operators instead of ", schema_name, ". ",
    "If said operator is being called inside the PyTorch framework, ",
    "please file a bug report instead.");
}

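// Handles the (logical scalar, N-dim tensor) case: computes the promoted
// dtype from a single example (logical_scalar_tensor[0], a 0-dim tensor) so
// promotion follows the 0-dim rules, then casts both operands to that dtype.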
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}

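// Common preprocessing for binary pointwise batch rules: moves any batch dims
// to the front, optionally applies the scalar-tensor type promotion fix-up
// above, and pads batched operands with size-1 dims so both sides broadcast
// at the same logical rank.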
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
    const Tensor& other, optional<int64_t> other_batch_dim,
    bool do_type_promotion) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  if (do_type_promotion) {
    auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
    auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
    if (tensor_is_logical_scalar && !other_is_logical_scalar) {
      handleScalarTypePromotion(tensor_, other_);
    }
    if (other_is_logical_scalar && !tensor_is_logical_scalar) {
      handleScalarTypePromotion(other_, tensor_);
    }
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}

}} // namespace at::functorch