// blob: 5ae5476e68dd04bd8ac869a9701e262f34ca9418 [file] [log] [blame]
#include <torch/csrc/autograd/functions/comm.h>
#include <ATen/core/functional.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/cuda/comm.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <cstddef>
#include <memory>
#include <vector>
namespace torch {
namespace autograd {
Scatter::Scatter(
std::vector<at::Device> devices,
c10::optional<std::vector<int64_t>> chunk_sizes,
int64_t dim,
c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> streams,
bool unsqueeze_scalars)
: devices_(std::move(devices)),
chunk_sizes_(std::move(chunk_sizes)),
dim_(dim),
streams_(std::move(streams)),
unsqueeze_scalars_(unsqueeze_scalars) {}
// Defaulted destructor, defined out of line in this translation unit; no
// cleanup is performed beyond destroying the members.
Scatter::~Scatter() = default;
// Scatters the single input across `devices_` via torch::cuda::scatter and
// returns the resulting chunks.
//
// Exactly one input is expected (asserted). If that input requires grad, a
// Gather node targeting the input's device is created and attached as the
// history of every returned variable. When `unsqueeze_scalars_` is set, each
// chunk must be a 1-D tensor with a single element and is returned as
// `chunk[0]`.
variable_list Scatter::apply(variable_list&& inputs) {
  AT_ASSERT(inputs.size() == 1);
  auto& source = inputs.front();

  // The backward node reads the input's device and edges, so it is built
  // before `source` is handed to scatter below.
  std::shared_ptr<Node> backward_node;
  if (compute_requires_grad(source)) {
    backward_node =
        std::make_shared<Gather>(/*destination_device=*/source.device(), dim_);
    backward_node->set_next_edges(collect_next_edges(source));
  }

  // torch::cuda::scatter is given device indices rather than Device objects.
  std::vector<int64_t> target_indices;
  target_indices.reserve(devices_.size());
  for (const at::Device& device : devices_) {
    target_indices.push_back(device.index());
  }

  auto tensors = torch::cuda::scatter(
      std::move(source), target_indices, chunk_sizes_, dim_, streams_);

  std::vector<Variable> outputs;
  outputs.reserve(tensors.size());
  for (auto& tensor : tensors) {
    AT_ASSERT(tensor.defined());
    if (!unsqueeze_scalars_) {
      outputs.push_back(std::move(tensor));
      continue;
    }
    // Each chunk is expected to be a one-element vector; index it with [0].
    AT_ASSERT(tensor.dim() == 1 && tensor.numel() == 1);
    outputs.push_back(tensor[0]);
  }

  if (backward_node) {
    set_history(outputs, backward_node);
  }
  return outputs;
}
Gather::Gather(const at::Device& destination_device, int64_t dim)
: destination_device_(destination_device), dim_(dim) {}
// Defaulted destructor, defined out of line in this translation unit; no
// cleanup is performed beyond destroying the members.
Gather::~Gather() = default;
// Gathers the CUDA tensors in `inputs` into a single tensor on
// `destination_device_` by calling torch::cuda::gather along `dim_`.
//
// Every input must be a CUDA tensor (checked with TORCH_CHECK). If all inputs
// are zero-dimensional and `dim_ == 0`, a warning is emitted and each input is
// viewed as a one-element vector before gathering. If any input requires grad,
// a Scatter node that maps the gradient back to the source devices is attached
// as the history of the result.
//
// Returns a one-element list holding the gathered tensor.
variable_list Gather::apply(variable_list&& inputs) {
  bool all_are_zero_dim = true;
  for (const auto& input : inputs) {
    TORCH_CHECK(
        input.is_cuda(),
        "All inputs to Gather must be CUDA tensors, got ",
        input.toString());
    if (input.dim() > 0) {
      all_are_zero_dim = false;
    }
  }

  const bool unsqueeze_scalars = all_are_zero_dim && dim_ == 0;
  if (unsqueeze_scalars) {
    TORCH_WARN(
        "Was asked to gather along dimension 0, but all "
        "input tensors were scalars; will instead unsqueeze "
        "and return a vector.");
  }

  std::shared_ptr<Node> grad_fn;
  // compute this before moving variables from `inputs`
  if (compute_requires_grad(inputs)) {
    std::vector<at::Device> source_devices;
    source_devices.reserve(inputs.size());
    std::vector<int64_t> input_sizes;
    input_sizes.reserve(inputs.size());
    for (const auto& input : inputs) {
      source_devices.push_back(input.device());
      // Scalars are gathered as one-element vectors (see `view(1)` below), so
      // each contributes a chunk of size 1 to the backward Scatter. Querying
      // size(0) on a zero-dim tensor must be avoided: at::Tensor::size does
      // not wrap dimensions for scalars and would raise here.
      input_sizes.push_back(
          unsqueeze_scalars ? int64_t{1} : input.size(dim_));
    }
    grad_fn = std::make_shared<Scatter>(
        std::move(source_devices),
        std::move(input_sizes),
        dim_,
        /*streams=*/c10::nullopt,
        /*unsqueeze_scalars=*/unsqueeze_scalars);
    grad_fn->set_next_edges(collect_next_edges(inputs));
  }

  std::vector<at::Tensor> tensors;
  tensors.reserve(inputs.size());
  for (auto& variable : inputs) {
    if (unsqueeze_scalars) {
      tensors.push_back(variable.view(1));
    } else {
      tensors.push_back(std::move(variable));
    }
  }

  // Disable the autograd during the actual computation
  // torch::cuda::gather does not return a view or change things inplace
  // so no need for extra logic here
  at::Tensor variable;
  {
    at::AutoDispatchBelowAutograd mode;
    // This is special logic for torch::cuda::gather!
    // A CPU destination is passed to it as index -1.
    const auto destination_index =
        destination_device_.is_cpu() ? -1 : destination_device_.index();
    variable = torch::cuda::gather(tensors, dim_, destination_index);
  }
  if (grad_fn) {
    set_history(variable, grad_fn);
  }
  return {variable};
}
} // namespace autograd
} // namespace torch