#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
namespace at {
// Checks if the batch dims in `bdims` appear at the front of the tensor in
// order, i.e. bdims[i].dim() == i for every i.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
for (const auto idx : c10::irange(static_cast<int64_t>(bdims.size()))) {
if (bdims[idx].dim() != idx) {
return false;
}
}
return true;
}
// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
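// For example (hypothetical sizes and levels): if the physical tensor has
// size [3, B0, 5] and a single batch dim at level 0 located at physical dim 1,
// the computed permutation is {1, 0, 2} and the returned tensor has size
// [B0, 3, 5].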
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
auto bdims = batched->bdims();
const Tensor& physical_tensor = batched->value();
if (areBdimsAtFrontInOrder(bdims)) {
return physical_tensor;
}
const auto sizes = physical_tensor.sizes();
VmapDimVector permutation(sizes.size(), 0);
const auto is_bdim = createBatchDimBitset(bdims);
int64_t idx = 0;
for (const auto& bdim : bdims) {
permutation[idx++] = bdim.dim();
}
for (const auto ptr : c10::irange(sizes.size())) {
if (is_bdim[ptr]) {
continue;
}
permutation[idx++] = ptr;
}
return physical_tensor.permute(permutation);
}
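// Unwraps a single BatchedTensor into a VmapPhysicalView: the underlying
// physical tensor with all batch dims permuted to the front, paired with the
// bitset of vmap levels those batch dims belong to.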
VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
TORCH_INTERNAL_ASSERT(
batched,
"logicalToPhysical(tensor) should only be passed a BatchedTensor");
return { permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims()) };
}
int64_t VmapPhysicalView::numBatchDims() const {
return levels_.count();
}
int64_t VmapPhysicalView::numLogicalDims() const {
return /*physical*/tensor_.dim() - numBatchDims();
}
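// Maps logical dims (what the vmap user sees) to physical dims by wrapping
// negative dims and shifting past the batch dims at the front. For example
// (hypothetical shape): on a physical tensor of size [B0, B1, 2, 3] with two
// batch dims, logical dim -1 maps to physical dim 3.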
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
auto logical_ndim = numLogicalDims();
// NB: fmap doesn't have a SmallVector variant, so we don't use it here.
VmapDimVector result;
result.reserve(logical_ndim);
if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
auto logical_dims = opt_logical_dims.value();
for (auto dim : logical_dims) {
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
}
} else {
for (int64_t dim = 0; dim < logical_ndim; dim++) {
result.push_back(dim + numBatchDims());
}
}
return result;
}
int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
auto logical_ndim = numLogicalDims();
return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
}
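// Prepends this view's batch-dim sizes to `logical_shape`. For example
// (hypothetical sizes): with one batch dim of size B0 and
// logical_shape = [2, 3], this returns [B0, 2, 3].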
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
VmapDimVector result;
result.reserve(logical_shape.size() + numBatchDims());
auto tensor_sizes = tensor_.sizes();
result.insert(result.end(), tensor_sizes.begin(), tensor_sizes.begin() + numBatchDims());
result.insert(result.end(), logical_shape.begin(), logical_shape.end());
return result;
}
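// Builds front-aligned BatchDims from a bitset of vmap levels. For example
// (hypothetical levels): a bitset with levels {1, 4} set produces
// bdims = [(level 1, dim 0), (level 4, dim 1)].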
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
BatchDims bdims;
int64_t dim = 0;
for (const auto level : c10::irange(kVmapNumLevels)) {
if (!levels_bitset[level]) {
continue;
}
bdims.emplace_back(level, dim++);
}
return bdims;
}
// Given a Tensor or a BatchedTensor, returns the underlying physical tensor
// with all vmapped dimensions permuted to the front, if they exist, and a
// bitset of vmap levels that were present in the tensor.
static std::pair<Tensor, std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
auto* batched = maybeGetBatchedImpl(self);
if (batched) {
return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
}
return {self, 0};
}
// Given a Tensor or a BatchedTensor, creates a physical view of the tensor
// such that it has a batch dimension for each level in `requested_levels`
// and `requested_example_dim` non-batch dimensions.
//
// This function is useful in preparing physical views on tensors that can
// then be passed into broadcasting operations. For example, when adding
// two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are the
// batch dimensions, we must align the batch dimensions and non-batch-dimensions
// (henceforth referred to as the "example" dimensions) separately to produce
// tensors of size [B0, 1, 1, 3] and [B0, B1, 2, 3] so that they can be added.
//
// Here's a direct example of using alignBatchDimsAtFront on the above two tensors.
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, 1, 1, 3] by adding an extra dimension for
// level 1 and another extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
// returns a physical view of size [B0, B1, 2, 3]
static Tensor alignBatchDimsAtFront(
const Tensor& self,
std::bitset<kVmapNumLevels> requested_levels,
int64_t requested_example_dim) {
auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);
TORCH_INTERNAL_ASSERT(
(tensor_levels | requested_levels) == requested_levels,
"`requested_levels` must be a superset of `self`'s levels");
auto physical_sizes = physical_tensor.sizes();
const auto tensor_example_dim = (
static_cast<int64_t>(physical_sizes.size())
- /*num_batch_dims*/static_cast<int64_t>(tensor_levels.count())
);
TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);
if (tensor_levels == requested_levels && tensor_example_dim == requested_example_dim) {
// Optimization: no need to do another view if the physical tensor is
// already the correct shape
return physical_tensor;
}
VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);
// align the example dims (non-batch dims) first
// aligned_sizes[-tensor_example_dim:] = tensor_sizes[-tensor_example_dim:]
std::copy(
physical_sizes.rbegin(),
physical_sizes.rbegin() + tensor_example_dim,
aligned_sizes.rbegin());
// align the bdims
int64_t level = 0;
int64_t tensor_dim = 0;
for (const auto bdim : c10::irange(requested_levels.count())) {
// Determine the level of the bdim
while (!requested_levels[level]) level++;
if (tensor_levels[level]) {
aligned_sizes[bdim] = physical_sizes[tensor_dim++];
}
level++;
}
return physical_tensor.view(aligned_sizes);
}
// The algorithm is as follows:
// 1. Figure out which vmap levels are collectively present in `logical_tensors`.
// 2. Move all batch dims to the front of the tensors and add extra dims
// of size 1. At this point, every tensor will have a dimension for
// each of the collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that its batch sizes equal `batch_sizes`.
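// For example (hypothetical sizes and levels): given logical tensors of sizes
// [3] (batched at level 0 with batch size B0) and [2, 3] (batched at levels 0
// and 1 with batch sizes B0 and B1), the collective levels are {0, 1} and the
// returned physical views have sizes [B0, B1, 3] and [B0, B1, 2, 3].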
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
// Figure out all of the collective vmap levels in `logical_tensors`.
std::bitset<kVmapNumLevels> collective_levels;
for (const auto& logical_tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(logical_tensor);
if (batched) {
collective_levels |= createVmapLevelsBitset(batched->bdims());
}
}
// Populate physical_tensors.
// This contains a list of regular (non-Batched) Tensors where all of the
// batch dims have been moved to the front of the tensor. Any previously
// non-existing batch dims get added to the tensors as new dimensions of size 1.
std::vector<Tensor> physical_tensors;
int64_t num_batch_dims = collective_levels.count();
for (const auto& logical_tensor : logical_tensors) {
auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
auto physical_tensor = alignBatchDimsAtFront(
logical_tensor, collective_levels, requested_example_dim);
physical_tensors.push_back(std::move(physical_tensor));
}
// Compute batch_sizes
VmapDimVector batch_sizes(num_batch_dims, 1);
for (const auto& physical_tensor : physical_tensors) {
auto physical_sizes = physical_tensor.sizes();
for (const auto dim : c10::irange(num_batch_dims)) {
if (physical_sizes[dim] != 1) {
batch_sizes[dim] = physical_sizes[dim];
}
}
}
// Expand each physical_tensor so that it has batch sizes `batch_sizes`
VmapPhysicalViewVec result;
for (const auto& physical_tensor : physical_tensors) {
VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
auto physical_sizes = physical_tensor.sizes();
expanded_size.insert(
expanded_size.end(),
physical_sizes.begin() + num_batch_dims,
physical_sizes.end());
result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
}
return result;
}
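// Returns the union of vmap levels present in `logical_tensors` together with
// the largest logical (non-batch) dimensionality among them.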
static std::pair<std::bitset<kVmapNumLevels>, int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
std::bitset<kVmapNumLevels> levels;
int64_t largest_logical_dim = -1;
for (const auto& tensor : logical_tensors) {
auto* batched = maybeGetBatchedImpl(tensor);
if (batched) {
levels = levels | createVmapLevelsBitset(batched->bdims());
}
auto tensor_logical_dim = /*logical dim*/tensor.dim();
if (tensor_logical_dim > largest_logical_dim) {
largest_logical_dim = tensor_logical_dim;
}
}
return { levels, largest_logical_dim };
}
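// Unlike MultiBatchVmapTransform::logicalToPhysical above, this also pads each
// tensor's example (non-batch) dims with size-1 dims up to the largest logical
// dim among the inputs, so the resulting physical views broadcast against each
// other directly (see the NB example inside the loop).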
VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
TORCH_INTERNAL_ASSERT(
logical_tensors.size() == 2,
"This function has only been tested for two tensors. Please add more tests ",
"before removing this check ");
VmapPhysicalViewVec result;
auto [levels, largest_logical_dim] = getLevelsAndLargestLogicalDim(logical_tensors);
for (const auto& tensor : logical_tensors) {
// NB: It's possible that we didn't actually need to align `tensor`.
// For example, when adding two tensors of size (B, 2), and (3, 2), where
// the first Tensor is a BatchedTensor with batch dim B and the second is
// a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
// However, the view on the second tensor is unnecessary: broadcasting
// semantics allow for the addition of two tensors of size (B, 1, 2) and (3, 2)!
//
// If this unnecessary view is a problem, consider optimizing it away in
// the future. This may involve creating a new type of VmapPhysicalView.
auto aligned = alignBatchDimsAtFront(tensor, levels, largest_logical_dim);
result.emplace_back(std::move(aligned), levels);
}
return result;
}
VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
return VmapPhysicalToLogicalMap(levels_);
}
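// Rewraps a physical tensor as a BatchedTensor whose batch dims sit at the
// front, one per level recorded in `levels_`.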
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
return makeBatched(physical_tensor, computeFrontBatchDimsFromLevels(levels_));
}
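// Maps each physical tensor in `physical_tensors` back to a logical
// (Batched) tensor, in place.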
void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
for (auto & physical_tensor : physical_tensors) {
physical_tensor = apply(physical_tensor);
}
}
} // namespace at