blob: b991937cf407553a260c682d2b85e07a8762a9c2 [file] [log] [blame]
#pragma once
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <c10d/ProcessGroup.hpp>
namespace c10d {
// Broadcast many tensors to all processes in the process group.
void broadcast_coalesced(
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
at::TensorList tensors,
size_t buffer_size,
int rank = 0);
// This class passes bucket contents tensor to DDP communication hook.
class GradBucket {
public:
explicit GradBucket(
size_t index,
const at::Tensor& tensor,
const std::vector<size_t>& offsets,
const std::vector<size_t>& lengths,
const std::vector<c10::IntArrayRef>& sizes_vec)
: index_(index),
tensor_(tensor),
offsets_(offsets),
lengths_(lengths),
sizes_vec_(sizes_vec) {}
// Returns the index of the bucket, which is unique across all the buckets.
size_t getIndex() const {
return index_;
}
const at::Tensor& getTensor() const {
return tensor_;
}
// Returns a mutable tensor compared with the above method.
at::Tensor& getTensorRef() {
return tensor_;
}
// Overwrites tensors at a specific index.
void setTensor(at::Tensor& tensor) {
tensor_ = tensor;
}
// Each tensor in the list that getPerParameterTensors corresponds to a
// parameter.
std::vector<at::Tensor> getPerParameterTensors() const;
// Returns whther this bucket is the last bucket to allreduce in an iteration.
bool isTheLastBucketToAllreduce() const {
return index_ == 0;
}
private:
size_t index_;
at::Tensor tensor_;
// Per-variable info in tensors_[0].
std::vector<size_t> offsets_;
std::vector<size_t> lengths_;
std::vector<c10::IntArrayRef> sizes_vec_;
};
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing 1) `runHook` method that communicates gradients
// asynchronously, and 2) `parseHookResult` method that converts the hook
// result into a tensor vector.
class TORCH_PYTHON_API CommHookInterface {
public:
virtual ~CommHookInterface() {}
// Passes the input grad bucket to the registered communication hook.
// Once the tensors in the bucket are ready, kicks off the hook asynchronously
// and returns a future that holds the communication results.
virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
GradBucket& bucket) = 0;
// Returns the resulting tensors once the communication hook result is
// ready. The resulting tensors will then be copied to the grads of
// individual parameters.
virtual std::vector<at::Tensor> parseHookResult(
const c10::IValue& result) = 0;
};
// This CppCommHook interface only requires implementing runHook method that
// potentially uses a state.
// Still need TORCH_PYTHON_API instead of TORCH_API to support Windows platform.
template <typename T>
class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface {
public:
explicit CppCommHookInterface(T& state) : state_(state) {}
virtual ~CppCommHookInterface() {}
std::vector<at::Tensor> parseHookResult(const c10::IValue& result) override {
TORCH_INTERNAL_ASSERT(
result.isTensor() || result.isTensorList(),
"expected the hook result is either a Tensor or a TensorList");
if (result.isTensor()) {
return {result.toTensor()};
}
return result.toTensorVector();
}
protected:
T state_; // Not owned.
};
} // namespace c10d