// blob: 8b324b58e1e0579d8ce862294d15f04b9b1dc6e1 [file] [log] [blame]
#pragma once
#include <c10d/ProcessGroup.hpp>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <cstddef>
#include <memory>
#include <tuple>
#include <vector>
namespace c10d {
// Namespace-scope helper functions operating on at::Tensor collections and a
// ProcessGroup. Only declarations are visible here, so every behavioral note
// below is inferred from names and signatures and is marked as such; confirm
// against the definitions before relying on it.

// Builds and returns a list of tensor lists from `tensors`.
//   tensors     - input tensors, taken by non-const reference (whether they
//                 are modified cannot be told from this declaration).
//   bucketSize  - NOTE(review): presumably the size cap for each returned
//                 inner list ("bucket"); the unit (bytes vs. elements) is not
//                 visible here -- TODO confirm.
//   fineGrained - optional, defaults to false. NOTE(review): presumably
//                 selects a finer grouping criterion -- TODO confirm.
std::vector<std::vector<at::Tensor>> bucketTensors(
std::vector<at::Tensor>& tensors,
int64_t bucketSize,
bool fineGrained = false);
// Operates on `tensors` through `processGroup`; returns nothing.
//   processGroup - process group used for the operation.
//   tensors      - tensors to operate on, taken by non-const reference.
//   bufferSize   - NOTE(review): presumably the maximum size of each
//                  coalesced broadcast buffer; unit not visible here --
//                  TODO confirm.
//   fineGrained  - optional, defaults to false. NOTE(review): meaning is
//                  presumably the same as in bucketTensors -- TODO confirm.
// NOTE(review): judging by the name, this broadcasts the tensors in coalesced
// chunks; the source rank and in-place semantics need to be verified against
// the definition.
void distBroadcastCoalesced(
ProcessGroup& processGroup,
std::vector<at::Tensor>& tensors,
int64_t bufferSize,
bool fineGrained = false);
// Synchronizes parameter and buffer tensors via `processGroup`; returns
// nothing.
//   processGroup        - process group used for the operation.
//   parameterData       - nested list of parameter tensors. NOTE(review):
//                         presumably one inner list per entry of `devices`
//                         -- TODO confirm.
//   bufferData          - nested list of buffer tensors, same nesting as
//                         `parameterData` by type.
//   devices             - list of int64_t device identifiers (read-only).
//   broadcastBucketSize - NOTE(review): presumably forwarded as the
//                         coalescing size limit for broadcasts -- TODO
//                         confirm.
//   broadcastBuffers    - NOTE(review): presumably controls whether
//                         `bufferData` is broadcast at all -- TODO confirm.
void syncParams(
ProcessGroup& processGroup,
std::vector<std::vector<at::Tensor>>& parameterData,
std::vector<std::vector<at::Tensor>>& bufferData,
const std::vector<int64_t>& devices,
int64_t broadcastBucketSize,
bool broadcastBuffers);
// Returns a tuple of (shared pointer to a ProcessGroup::Work, at::Tensor).
//   processGroup - process group used for the operation.
//   gradsBatch   - nested list of gradient tensors, taken by non-const
//                  reference.
//   devices      - list of int64_t device identifiers (read-only).
// NOTE(review): judging by the name and return type, this enqueues a
// reduction of the gradients and hands back the pending work handle together
// with the coalesced gradient tensor; whether the call is asynchronous must
// be confirmed against the definition.
std::tuple<std::shared_ptr<ProcessGroup::Work>, at::Tensor> queueReduction(
ProcessGroup& processGroup,
std::vector<std::vector<at::Tensor>>& gradsBatch,
const std::vector<int64_t>& devices);
// Counterpart of queueReduction by name and parameter types; returns nothing.
//   reductionWork       - shared pointer to a ProcessGroup::Work, taken by
//                         non-const reference.
//   gradsBatch          - flat list of gradient tensors (note: not nested,
//                         unlike the queueReduction parameter of the same
//                         name).
//   gradsBatchCoalesced - single tensor, taken by non-const reference.
// NOTE(review): presumably waits for `reductionWork` to finish and then
// copies the reduced values from `gradsBatchCoalesced` back into
// `gradsBatch` -- TODO confirm against the definition.
void syncReduction(
std::shared_ptr<ProcessGroup::Work>& reductionWork,
std::vector<at::Tensor>& gradsBatch,
at::Tensor& gradsBatchCoalesced);
} // namespace c10d