#pragma once
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/distributed/c10d/Work.hpp>
// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************
constexpr auto kProcessGroupDefaultTimeout =
std::chrono::milliseconds(30 * 60 * 1000);
namespace c10d {
// ProcessGroup is a base class that captures collective and point-to-point
// communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). Each such function
// returns an object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. Members of the process
// group find each other (a step referred to as rendezvous from hereon)
// through the Store passed at construction.
//
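// A minimal usage sketch (assuming a concrete ProcessGroup instance `pg`
// obtained from the frontend; tensor shapes and variable names below are
// illustrative only):
//
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   c10d::BroadcastOptions opts;
//   opts.rootRank = 0;
//   auto work = pg->broadcast(tensors, opts);
//   work->wait(); // blocks until the collective completes or errors
//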
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
public:
// ProcessGroup Options is a base struct that defines the basic options
// used when constructing a ProcessGroup. Each ProcessGroup subclass should
// extend this struct and define its own options if it wants to provide
// additional configuration options (beyond the basic ones defined here) to
// end users.
struct TORCH_API Options : torch::CustomClassHolder {
explicit Options(
std::string backend,
std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
: timeout(timeout), backend(std::move(backend)) {}
~Options() override = default;
std::chrono::milliseconds timeout;
// backend name
const std::string backend;
};
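// A sketch of how a backend-specific subclass might extend Options; the
// struct name and the extra field below are hypothetical and only
// illustrate the intended extension pattern:
//
//   struct MyBackendOptions : ProcessGroup::Options {
//     explicit MyBackendOptions(
//         std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
//         : ProcessGroup::Options("my_backend", timeout) {}
//     bool useHighPriorityStream = false; // hypothetical extra knob
//   };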
enum BackendType {
UNDEFINED = 0,
GLOO = 1,
NCCL = 2,
UCC = 3,
MPI = 4,
CUSTOM = 5,
};
// Not used; kept for backwards compatibility and only used for the TypeDef
// in Ops.cpp.
explicit ProcessGroup(int rank, int size);
explicit ProcessGroup(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options);
~ProcessGroup() override;
int getRank() const {
return rank_;
}
int getSize() const {
return size_;
}
// Returns a unique opaque ID of this process group object.
int64_t getID() const {
return reinterpret_cast<std::intptr_t>(this);
}
// Returns a unique opaque ID of the backend for the given backend type,
// which can be correlated with this process group's collectives.
int64_t getBackendID(BackendType backend_type) const {
return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
}
virtual const std::string getBackendName() const {
return options_->backend;
}
BackendType getBackendType() const {
return backendType_;
}
virtual void startCoalescing(c10::DeviceType deviceType) {
// only NCCL currently implements startCoalescing, so this only has an
// effect for NCCL backends
auto backend = getBackend(deviceType);
backend->startCoalescing();
}
virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
// only NCCL currently implements endCoalescing, so this only has an
// effect for NCCL backends
auto backend = getBackend(deviceType);
auto work = backend->endCoalescing();
return work;
}
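// A sketch of the coalescing pattern (assuming an NCCL-backed `pg` and CUDA
// tensor vectors `bucket0` / `bucket1`; all names are illustrative):
//
//   pg->startCoalescing(c10::DeviceType::CUDA);
//   pg->allreduce(bucket0); // batched together with the call below
//   pg->allreduce(bucket1);
//   auto work = pg->endCoalescing(c10::DeviceType::CUDA);
//   work->wait(); // waits on the single coalesced operation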
virtual c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::broadcast_", "")
.typed<
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t,
int64_t)>();
// It's awkward to unbox the opts here and box them again in the custom C++
// op. But it's also complicated to make opts a CustomClassHolder. Leave
// it as it is for now.
return std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.rootTensor,
opts.timeout.count()));
}
virtual c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allreduce_", "")
.typed<
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const c10::optional<at::Tensor>& sparse_indices,
int64_t)>();
return std::get<1>(op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.timeout.count()));
}
virtual c10::intrusive_ptr<Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allreduce_coalesced_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.timeout.count());
}
virtual c10::intrusive_ptr<Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::reduce_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t,
int64_t,
int64_t)>();
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.timeout.count());
}
virtual c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_", "")
.typed<std::tuple<
std::vector<std::vector<at::Tensor>>,
c10::intrusive_ptr<Work>>(
const std::vector<std::vector<at::Tensor>>&,
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
}
// Gathers a single tensor inputBuffer into a single, larger outputBuffer
// that is interpreted as a contiguous collection of WORLD_SIZE chunks, each
// with the shape of inputBuffer.
// For implementers of the ProcessGroup API and advanced users only.
// Note: this function will be deprecated in the near future.
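// A shape sketch (illustrative; `pg` and `N` are placeholders): with
// world_size ranks and an inputBuffer of shape {N}, outputBuffer must have
// shape {world_size * N}; rank r's contribution ends up in
// outputBuffer.narrow(0, r * N, N).
//
//   at::Tensor input = at::ones({N});
//   at::Tensor output = at::empty({pg->getSize() * N});
//   pg->_allgather_base(output, input)->wait();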
virtual c10::intrusive_ptr<Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::_allgather_base_", "")
.typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
at::Tensor&,
at::Tensor&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this)));
}
// This function is deprecated and will be moved out of ProcessGroup to comms:
// * do not add dependencies on this function,
// * do not implement it in your ProcessGroup, implement _allgather_base
// instead.
virtual c10::intrusive_ptr<Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
}
// This function is a coalesced version of `allgather_into_tensor` (currently
// still named as `_allgather_base`). Each tensor in the vector corresponds to
// an input/output of one `allgather_into_tensor` operation.
virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_into_tensor_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
}
virtual c10::intrusive_ptr<Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::gather_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.timeout.count());
}
virtual c10::intrusive_ptr<Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::scatter_", "")
.typed<
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
const at::TensorList&,
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.timeout.count()));
}
virtual c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::reduce_scatter_", "")
.typed<
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
const at::TensorList&,
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count()));
}
virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const ReduceScatterOptions& opts = ReduceScatterOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
.typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
at::Tensor&,
at::Tensor&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return std::get<1>(op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count()));
}
// This function is a coalesced version of `reduce_scatter_tensor` (currently
// still named as `_reduce_scatter_base`). Each tensor in the vector
// corresponds to an input/output of one `reduce_scatter_tensor` operation.
virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::reduce_scatter_tensor_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t)>();
return op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.timeout.count());
}
virtual c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::alltoall_base_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::Tensor&,
at::Tensor&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
std::vector<int64_t>,
std::vector<int64_t>,
int64_t)>();
return op.call(
outputBuffer,
inputBuffer,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.timeout.count());
}
virtual c10::intrusive_ptr<Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::alltoall_", "")
.typed<
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
const at::TensorList&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.timeout.count()));
}
virtual void monitoredBarrier(
const BarrierOptions& opts,
bool wait_all_ranks = false) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::monitored_barrier_", "")
.typed<void(
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
int64_t,
bool)>();
// Default to the CPU implementation; monitored barrier is only for GLOO.
at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.timeout.count(),
wait_all_ranks);
}
// Agrees on an initial sequence number for the whole group by having rank 0
// create it and broadcast it to the other ranks using the store. Currently
// only implemented for the GLOO, NCCL, and UCC backend types (see the check
// below).
virtual void setSequenceNumberForGroup() {
auto backendType = getBackendType();
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::UCC) {
getDefaultBackend()->setSequenceNumberForGroup();
} else {
TORCH_CHECK(
false,
c10::str(
"ProcessGroup ",
getBackendName(),
" does not yet support sequence numbers."));
}
}
// Retrieves the current sequence number for the whole group, which should be
// in sync. If the returned number is not consistent across the group, it
// may indicate that there is some sort of collective desynchronization.
virtual uint64_t getSequenceNumberForGroup() {
auto backendType = getBackendType();
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::UCC) {
return getDefaultBackend()->getSequenceNumberForGroup();
} else {
TORCH_CHECK(
false,
c10::str(
"ProcessGroup ",
getBackendName(),
" does not yet support sequence numbers."));
}
}
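// A usage sketch (illustrative): seed the sequence number once after
// construction, then compare the returned value across ranks to detect
// collective desynchronization.
//
//   pg->setSequenceNumberForGroup();
//   // ... run some collectives ...
//   uint64_t seq = pg->getSequenceNumberForGroup();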
virtual c10::intrusive_ptr<Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::send", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
dstRank,
tag);
}
virtual c10::intrusive_ptr<Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::recv_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
int64_t)>();
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
srcRank,
tag);
}
virtual c10::intrusive_ptr<Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::recv_any_source_", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t)>();
return op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
tag);
}
virtual c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) {
at::Tensor tensor;
// TODO: if nccl was specified then use it
auto device = opts.device;
if (device.has_value()) {
// set device tensor from argument
tensor = at::empty(
{1}, at::TensorOptions().device(device.value()).dtype(at::kByte));
} else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
// set cuda tensor
tensor = at::empty(
{1},
at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
} else {
// Default to using cpu implementation
tensor = at::empty(
{1},
at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte));
}
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::barrier", "")
.typed<c10::intrusive_ptr<::c10d::Work>(
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
int64_t)>();
return op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.timeout.count());
}
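// A usage sketch (illustrative): block every rank until all ranks arrive.
//
//   c10d::BarrierOptions barrierOpts;
//   barrierOpts.device_ids = {0}; // optional hint for device-backed barriers
//   pg->barrier(barrierOpts)->wait();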
c10::intrusive_ptr<Options> getOptions() {
return options_;
}
bool hasBackends() {
return !deviceTypeToBackendType_.empty();
}
void setBackend(
c10::DeviceType deviceType,
BackendType backendType,
const c10::optional<c10::intrusive_ptr<Backend>>& backend) {
// TODO: should we add these entries after the backend setting succeeds?
deviceTypeToBackendType_[deviceType] = backendType;
deviceTypes_.insert(deviceType);
// if the backendType is already set then reuse it for this device
if (backendTypeToBackend_.find(backendType) !=
backendTypeToBackend_.end()) {
auto existingBackend = backendTypeToBackend_.at(backendType);
deviceTypeToBackend_[deviceType] = existingBackend;
} else {
// check if backend has value
if (backend.has_value()) {
deviceTypeToBackend_[deviceType] = backend.value();
backendTypeToBackend_[backendType] = backend.value();
}
}
}
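// A sketch (illustrative; `glooBackend` stands for a
// c10::intrusive_ptr<Backend> created elsewhere): registering the same
// backend object for two device types reuses it via the
// backendTypeToBackend_ map populated on the first call.
//
//   pg->setBackend(c10::DeviceType::CPU, ProcessGroup::BackendType::GLOO, glooBackend);
//   pg->setBackend(c10::DeviceType::CUDA, ProcessGroup::BackendType::GLOO, glooBackend);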
c10::intrusive_ptr<Backend> getDefaultBackend() const {
TORCH_CHECK(
backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(),
"Could not find the default backend type ",
backendType_,
" for Process Group with name ",
getBackendName(),
".");
return backendTypeToBackend_.at(backendType_);
}
c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);
c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
TORCH_CHECK(
backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(),
"Could not find backend type ",
backendType,
".");
return backendTypeToBackend_.at(backendType);
}
// Return device types supported by this ProcessGroup.
// Note: the return type is `Device` rather than `DeviceType` for the purpose
// of easy comparison at Python level. The `Device` will have default index
// (-1).
std::vector<c10::Device> getDeviceTypes() const {
std::vector<c10::Device> devices;
devices.reserve(deviceTypes_.size());
for (auto& dt : deviceTypes_) {
devices.push_back(c10::Device(dt));
}
return devices;
}
void registerOnCompletionHook(
std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
getDefaultBackend()->registerOnCompletionHook(std::move(hook));
}
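// A sketch of registering a completion hook (the lambda body is
// illustrative; WorkInfo is the type already used in the hook signature
// above):
//
//   pg->registerOnCompletionHook(
//       [](std::shared_ptr<c10d::WorkInfo> workInfo) {
//         // e.g. log or profile the completed work here
//       });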
void waitForPendingWorks() {
getDefaultBackend()->waitForPendingWorks();
}
bool hasHooks() const {
return getDefaultBackend()->hasHooks();
}
protected:
// Implementations of this interface need to call this to set up
// appropriate logging etc.
void init();
const c10::intrusive_ptr<c10d::Store> store_;
const int rank_;
const int size_;
const c10::intrusive_ptr<Options> options_;
const BackendType backendType_;
// Optional sequence number structure for matching collectives.
c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
// Debug level setting. It is parsed once when the ProcessGroup is
// constructed and remains the same for the lifetime of this process group.
DebugLevel dist_debug_level_;
// Backend classes for this ProcessGroup
std::unordered_set<c10::DeviceType> deviceTypes_;
std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
deviceTypeToBackend_;
std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
backendTypeToBackend_;
};
} // namespace c10d