Change rank type: int -> std::uint32_t; Minor fixes
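
Ranks and world sizes are now carried as thd::rank_type (std::uint32_t)
and listening ports as thd::port_type (std::uint16_t), so each integer's
meaning is explicit and the range checks live in one place
(ChannelUtils.hpp). A minimal sketch of the pattern, mirroring the new
header (assumes only the standard library):

    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    using rank_type = std::uint32_t;   // process ranks and world size
    using port_type = std::uint16_t;   // TCP listening ports

    // Reject out-of-range input before narrowing, as ChannelUtils.hpp does.
    inline rank_type convertToRank(long rank, long min = 0) {
      if ((rank < min) || (rank >= std::numeric_limits<rank_type>::max()))
        throw std::domain_error("invalid rank (value out of range)");
      return static_cast<rank_type>(rank);
    }

The int-based THD C API (Collectives.cpp) keeps its signatures and
rejects negative ranks before casting to rank_type. Other fixes: the MPI
receive/broadcast paths now write directly into the destination buffer
when sizes match, clone() results in DataChannelTCP are owned by
std::unique_ptr, leftover merge-conflict markers are removed from
DataChannelTCP.cpp, and the tests use the ChannelEnvVars constants
instead of hard-coded environment variable names.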
diff --git a/torch/lib/THD/CMakeLists.txt b/torch/lib/THD/CMakeLists.txt
index 30cf269..01d5aa1 100644
--- a/torch/lib/THD/CMakeLists.txt
+++ b/torch/lib/THD/CMakeLists.txt
@@ -72,7 +72,7 @@
FILE(GLOB_RECURSE test_cpp "test/*.cpp")
IF(NOT MPI_FOUND)
- LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/channels/data_channel/DataChannelMPI.cpp")
+ LIST(REMOVE_ITEM base_cpp "${CMAKE_CURRENT_SOURCE_DIR}/base/data_channels/DataChannelMPI.cpp")
LIST(REMOVE_ITEM test_cpp "${CMAKE_CURRENT_SOURCE_DIR}/test/data_channel_mpi_smoke.cpp")
ENDIF()
diff --git a/torch/lib/THD/base/ChannelUtils.cpp b/torch/lib/THD/base/ChannelUtils.cpp
index de7c83c..f24c8a3 100644
--- a/torch/lib/THD/base/ChannelUtils.cpp
+++ b/torch/lib/THD/base/ChannelUtils.cpp
@@ -7,6 +7,7 @@
#include <netinet/tcp.h>
#include <sys/poll.h>
#include <unistd.h>
+#include <cstring>
#include <memory>
#include <string>
#include <thread>
@@ -34,10 +35,10 @@
} // anonymous namespace
-std::tuple<int, std::uint16_t> listen(std::uint16_t port) {
+std::tuple<int, port_type> listen(port_type port) {
struct addrinfo hints, *res = NULL;
- memset(&hints, 0x00, sizeof(hints));
+ std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_PASSIVE;
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
hints.ai_socktype = SOCK_STREAM; // TCP
@@ -80,16 +81,16 @@
struct sockaddr_in addr;
socklen_t addr_len = sizeof(addr);
SYSCHECK(::getsockname(socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
- std::uint16_t listen_port = ntohs(addr.sin_port);
+ port_type listen_port = ntohs(addr.sin_port);
return std::make_tuple(socket, listen_port);
}
-int connect(const std::string& address, std::uint16_t port, bool wait) {
+int connect(const std::string& address, port_type port, bool wait) {
struct addrinfo hints, *res = NULL;
- memset(&hints, 0x00, sizeof(hints));
+ std::memset(&hints, 0x00, sizeof(hints));
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
hints.ai_socktype = SOCK_STREAM; // TCP
@@ -179,10 +180,10 @@
return std::make_tuple(socket, std::string(address));
}
-std::tuple<std::uint16_t, std::uint32_t> load_master_env() {
- std::uint16_t port = convertToPort(std::stoul(getenv(MASTER_PORT_ENV)));
+std::tuple<port_type, rank_type> load_master_env() {
+ auto port = convertToPort(std::stoul(getenv(MASTER_PORT_ENV)));
- std::uint32_t world_size = std::stoul(getenv(WORLD_SIZE_ENV));
+ rank_type world_size = std::stoul(getenv(WORLD_SIZE_ENV));
if (world_size == 0)
throw std::domain_error(std::string(WORLD_SIZE_ENV) + " env variable cannot be 0");
@@ -190,18 +191,18 @@
}
-std::tuple<std::string, std::uint16_t> load_worker_env() {
+std::tuple<std::string, port_type> load_worker_env() {
std::string full_address = std::string(getenv(MASTER_ADDR_ENV));
auto found_pos = full_address.rfind(":");
if (found_pos == std::string::npos)
throw std::domain_error("invalid master address, usage: IP:PORT | HOSTNAME:PORT");
std::string str_port = full_address.substr(found_pos + 1);
- std::uint16_t port = convertToPort(std::stoul(str_port));
+ auto port = convertToPort(std::stoul(str_port));
return std::make_tuple(full_address.substr(0, found_pos), port);
}
-std::uint32_t load_rank_env() {
+rank_type load_rank_env() {
return convertToRank(std::stol(getenv(RANK_ENV)));
}
diff --git a/torch/lib/THD/base/ChannelUtils.hpp b/torch/lib/THD/base/ChannelUtils.hpp
index 5e4bf95..de412ef 100644
--- a/torch/lib/THD/base/ChannelUtils.hpp
+++ b/torch/lib/THD/base/ChannelUtils.hpp
@@ -4,19 +4,23 @@
#include <sys/types.h>
#include <cstdlib>
#include <cstdint>
+#include <limits>
#include <string>
#include <system_error>
#include <tuple>
namespace thd {
+using rank_type = std::uint32_t;
+using port_type = std::uint16_t;
+
#define SYSCHECK(expr) { \
errno = 0; (expr); \
if (errno != 0) throw std::system_error(errno, std::system_category()); \
}
template<typename T>
-void send_bytes(int socket, const T* buffer, std::size_t length)
+void send_bytes(int socket, const T* buffer, std::size_t length, bool more_data = false)
{
std::size_t bytes_to_send = sizeof(T) * length;
if (bytes_to_send == 0)
@@ -25,9 +29,16 @@
auto bytes = reinterpret_cast<const std::uint8_t*>(buffer);
std::uint8_t *current_bytes = const_cast<std::uint8_t*>(bytes);
+ int flags = 0;
+#ifdef MSG_MORE
+ if (more_data) { // caller will send more data soon; let the kernel coalesce writes
+ flags |= MSG_MORE;
+ }
+#endif
+
while (bytes_to_send > 0) {
ssize_t bytes_sent;
- SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, 0))
+ SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, flags))
if (bytes_sent == 0)
throw std::system_error(EBADMSG, std::system_category());
@@ -58,26 +69,26 @@
}
}
-inline std::uint16_t convertToPort(long port) {
- if ((port < 0) || (port >= UINT16_MAX))
+inline port_type convertToPort(long port) {
+ if ((port < 0) || (port >= std::numeric_limits<port_type>::max()))
throw std::domain_error("invalid port (value out of range)");
- return static_cast<std::uint16_t>(port);
+ return static_cast<port_type>(port);
}
-inline std::uint32_t convertToRank(long rank, long min = 0) {
- if ((rank < min) || (rank >= UINT32_MAX))
+inline rank_type convertToRank(long rank, long min = 0) {
+ if ((rank < min) || (rank >= std::numeric_limits<rank_type>::max()))
throw std::domain_error("invalid rank (value out of range)");
- return static_cast<std::uint32_t>(rank);
+ return static_cast<rank_type>(rank);
}
-std::tuple<int, std::uint16_t> listen(std::uint16_t port = 0);
-int connect(const std::string& address, std::uint16_t port, bool wait = true);
+std::tuple<int, port_type> listen(port_type port = 0);
+int connect(const std::string& address, port_type port, bool wait = true);
std::tuple<int, std::string> accept(int listen_socket, int timeout = -1);
-std::tuple<std::uint16_t, std::uint32_t> load_master_env();
-std::tuple<std::string, std::uint16_t> load_worker_env();
-std::uint32_t load_rank_env();
+std::tuple<port_type, rank_type> load_master_env();
+std::tuple<std::string, port_type> load_worker_env();
+rank_type load_rank_env();
} // namespace thd
diff --git a/torch/lib/THD/base/DataChannel.cpp b/torch/lib/THD/base/DataChannel.cpp
index 3790fde..a6a9c0a 100644
--- a/torch/lib/THD/base/DataChannel.cpp
+++ b/torch/lib/THD/base/DataChannel.cpp
@@ -31,7 +31,7 @@
throw std::logic_error("cannot create empty group");
sort(ranks.begin(), ranks.end());
- if (ranks.front() < 0 || ranks.back() > max_rank) {
+ if (ranks.back() > max_rank) {
throw std::out_of_range(
"array of ranks contains invalid rank, "
"all ranks should be in range: [0, " + std::to_string(max_rank) + "]"
@@ -96,7 +96,7 @@
auto DataChannel::Group::getGlobalRank(rank_type group_rank) const -> std::pair<rank_type, bool> {
- if (group_rank < 0 || group_rank >= _new2old.size())
+ if (group_rank >= _new2old.size())
return std::make_pair(0, false);
return std::make_pair(_new2old[group_rank], true);
diff --git a/torch/lib/THD/base/DataChannel.hpp b/torch/lib/THD/base/DataChannel.hpp
index a925f26..eb8f0f7 100644
--- a/torch/lib/THD/base/DataChannel.hpp
+++ b/torch/lib/THD/base/DataChannel.hpp
@@ -1,6 +1,7 @@
#pragma once
#include "ChannelType.h"
+#include "ChannelUtils.hpp"
#include "DataChannel.h"
#include "Scalar.hpp"
@@ -13,6 +14,7 @@
namespace thd {
struct DataChannel {
+
struct Request {
Request() {};
virtual ~Request() {};
@@ -28,39 +30,37 @@
virtual bool init() = 0;
- virtual int getRank() = 0;
- virtual int getNumProcesses() = 0;
+ virtual rank_type getRank() = 0;
+ virtual rank_type getNumProcesses() = 0;
virtual void allGather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
THDGroup group_id = THDGroupWORLD) = 0;
virtual void gather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
- int dst_rank, THDGroup group_id = THDGroupWORLD) = 0;
+ rank_type dst_rank, THDGroup group_id = THDGroupWORLD) = 0;
virtual void scatter(std::vector<thpp::Tensor*>& input, thpp::Tensor& output,
- int src_rank, THDGroup group_id = THDGroupWORLD) = 0;
+ rank_type src_rank, THDGroup group_id = THDGroupWORLD) = 0;
virtual void allReduce(thpp::Tensor& data, THDReduceOp operation,
THDGroup group_id = THDGroupWORLD) = 0;
virtual void reduce(thpp::Tensor& data, THDReduceOp operation,
- int dst_rank, THDGroup group_id = THDGroupWORLD) = 0;
- virtual void broadcast(thpp::Tensor& data, int src_rank,
+ rank_type dst_rank, THDGroup group_id = THDGroupWORLD) = 0;
+ virtual void broadcast(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id = THDGroupWORLD) = 0;
- virtual void send(const Scalar& value, int src_rank) = 0;
- virtual void send(thpp::Tensor& data, int dst_rank) = 0;
- virtual void receive(Scalar& value, int src_rank) = 0;
+ virtual void send(const Scalar& value, rank_type src_rank) = 0;
+ virtual void send(thpp::Tensor& data, rank_type dst_rank) = 0;
+ virtual void receive(Scalar& value, rank_type src_rank) = 0;
virtual void receive(thpp::Tensor& data) = 0; // receive from any source
- virtual void receive(thpp::Tensor& data, int src_rank) = 0;
- virtual Request* isend(thpp::Tensor& data, int dst_rank) = 0;
- virtual Request* ireceive(thpp::Tensor& data, int src_rank) = 0;
+ virtual void receive(thpp::Tensor& data, rank_type src_rank) = 0;
+ virtual Request* isend(thpp::Tensor& data, rank_type dst_rank) = 0;
+ virtual Request* ireceive(thpp::Tensor& data, rank_type src_rank) = 0;
virtual void barrier(THDGroup group_id = THDGroupWORLD) = 0;
- virtual THDGroup newGroup(const std::vector<int>& ranks) = 0;
+ virtual THDGroup newGroup(const std::vector<rank_type>& ranks) = 0;
static DataChannel* newChannel(THDChannelType type);
protected:
struct Group {
- using rank_type = int;
-
Group();
/*
* Constructs `Group` from provided `ranks` and checks if all ranks are
diff --git a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
index fb6c99e..3d9d0ec 100644
--- a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
@@ -112,12 +112,16 @@
bool DataChannelMPI::init() {
MPI_Init(NULL, NULL);
- MPI_Comm_size(MPI_COMM_WORLD, &_num_processes);
- MPI_Comm_rank(MPI_COMM_WORLD, &_rank);
+ int rank, num_processes;
+ MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- std::vector<int> ranks;
+ _rank = convertToRank(rank);
+ _num_processes = convertToRank(num_processes);
+
+ std::vector<rank_type> ranks;
ranks.reserve(_num_processes);
- for (size_t rank = 0; rank < _num_processes; ++rank)
+ for (rank_type rank = 0; rank < _num_processes; ++rank)
ranks.push_back(rank);
_groups.insert({
@@ -128,12 +132,12 @@
}
-int DataChannelMPI::getRank() {
+rank_type DataChannelMPI::getRank() {
return _rank;
}
-int DataChannelMPI::getNumProcesses() {
+rank_type DataChannelMPI::getNumProcesses() {
return _num_processes;
}
@@ -167,7 +171,7 @@
void DataChannelMPI::gather(std::vector<thpp::Tensor*>& output,
- thpp::Tensor& input, int dst_rank,
+ thpp::Tensor& input, rank_type dst_rank,
THDGroup group_id) {
/*
* Output vector size is 0 for _rank != dst_rank.
@@ -189,7 +193,7 @@
assertTensorEqual(*out_tensor, input, "gather");
}
- auto group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
+ rank_type group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
std::uint64_t tensor_bytes = input.elementSize() * input.numel();
std::uint64_t all_tensors_bytes = tensor_bytes * output.size();
std::unique_ptr<std::uint8_t[]> tmp_data(new std::uint8_t[all_tensors_bytes]);
@@ -207,7 +211,7 @@
void DataChannelMPI::scatter(std::vector<thpp::Tensor*>& input,
thpp::Tensor& output,
- int src_rank, THDGroup group_id) {
+ rank_type src_rank, THDGroup group_id) {
/*
* Input vector size is 0 for _rank != dst_rank.
*/
@@ -228,7 +232,7 @@
assertTensorEqual(*in_tensor, output, "scatter");
}
- auto group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
+ rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
std::uint64_t tensor_bytes = output.elementSize() * output.numel();
std::uint64_t all_tensors_bytes = tensor_bytes * input.size();
std::unique_ptr<std::uint8_t[]> tmp_data(new std::uint8_t[all_tensors_bytes]);
@@ -259,8 +263,8 @@
}
-void DataChannelMPI::reduce(thpp::Tensor& data, THDReduceOp operation, int dst_rank,
- THDGroup group_id) {
+void DataChannelMPI::reduce(thpp::Tensor& data, THDReduceOp operation,
+ rank_type dst_rank, THDGroup group_id) {
const auto& group_pair = _groups.at(group_id);
const auto& comm = group_pair.first;
if (comm == MPI_COMM_NULL)
@@ -279,7 +283,7 @@
}
-void DataChannelMPI::_broadcastPack(thpp::Tensor& data, int src_rank,
+void DataChannelMPI::_broadcastPack(thpp::Tensor& data, rank_type src_rank,
MPI_Comm comm) const {
std::uint64_t tensor_bytes = data.elementSize() * data.numel();
MPI_Bcast(&tensor_bytes, 1, MPI_UINT64_T, src_rank, comm);
@@ -288,31 +292,33 @@
}
-void DataChannelMPI::_broadcastUnpack(thpp::Tensor& data, int src_rank,
+void DataChannelMPI::_broadcastUnpack(thpp::Tensor& data, rank_type src_rank,
MPI_Comm comm) const {
std::uint64_t tensor_bytes;
MPI_Bcast(&tensor_bytes, 1, MPI_UINT64_T, src_rank, comm);
- std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
- MPI_Bcast(bytes.get(), tensor_bytes, MPI_UINT8_T, src_rank, comm);
-
std::uint64_t actual_tensor_bytes = data.elementSize() * data.numel();
- if (actual_tensor_bytes != tensor_bytes) {
+ if (actual_tensor_bytes == tensor_bytes) {
+ MPI_Bcast(data.data(), tensor_bytes, MPI_UINT8_T, src_rank, comm);
+ } else {
+ // sizes do not match: complete the broadcast into a scratch buffer, then fail
+ std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
+ MPI_Bcast(bytes.get(), tensor_bytes, MPI_UINT8_T, src_rank, comm);
throw std::logic_error("tensor sizes does not match");
}
- std::memcpy(data.data(), bytes.get(), tensor_bytes);
+
}
-void DataChannelMPI::broadcast(thpp::Tensor& data, int src_rank,
+void DataChannelMPI::broadcast(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id) {
const auto& group_pair = _groups.at(group_id);
const auto& comm = group_pair.first;
if (comm == MPI_COMM_NULL)
return;
- auto group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
+ rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
if (src_rank == _rank) {
_broadcastPack(data, group_src_rank, comm);
} else {
@@ -321,7 +327,7 @@
}
-void DataChannelMPI::send(const Scalar& data, int dst_rank) {
+void DataChannelMPI::send(const Scalar& data, rank_type dst_rank) {
std::uint64_t scalar_bytes = data.elementSize();
MPI_Send(&scalar_bytes, 1, MPI_UINT64_T, dst_rank, 0, MPI_COMM_WORLD);
MPI_Send(reinterpret_cast<const std::uint8_t*>(data.data()), scalar_bytes,
@@ -329,7 +335,7 @@
}
-void DataChannelMPI::send(thpp::Tensor& data, int dst_rank) {
+void DataChannelMPI::send(thpp::Tensor& data, rank_type dst_rank) {
if (!data.isContiguous())
throw std::logic_error("tensor to send is not contiguous");
@@ -340,20 +346,22 @@
}
-void DataChannelMPI::receive(Scalar& data, int src_rank) {
+void DataChannelMPI::receive(Scalar& data, rank_type src_rank) {
std::uint64_t scalar_bytes;
MPI_Recv(&scalar_bytes, 1, MPI_UINT64_T, src_rank, 0, MPI_COMM_WORLD,
- MPI_STATUS_IGNORE);
-
- std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[scalar_bytes]);
- MPI_Recv(bytes.get(), scalar_bytes, MPI_UINT8_T, src_rank, 0,
- MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+ MPI_STATUS_IGNORE);
std::uint64_t actual_scalar_bytes = data.elementSize();
- if (actual_scalar_bytes != scalar_bytes)
- throw std::logic_error("scalar sizes does not match");
-
- memcpy(data.data(), bytes.get(), scalar_bytes);
+ if (actual_scalar_bytes == scalar_bytes) {
+ MPI_Recv(data.data(), scalar_bytes, MPI_UINT8_T, src_rank, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
+ } else {
+ // sizes do not match: complete the receive into a scratch buffer, then fail
+ std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[scalar_bytes]);
+ MPI_Recv(bytes.get(), scalar_bytes, MPI_UINT8_T, src_rank, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
+ throw std::logic_error("scalar sizes does not match");
+ }
}
@@ -362,39 +370,42 @@
throw std::logic_error("tensor to receive is not contiguous");
std::uint64_t tensor_bytes;
- MPI_Status status;
MPI_Recv(&tensor_bytes, 1, MPI_UINT64_T, MPI_ANY_SOURCE, 0,
- MPI_COMM_WORLD, &status);
-
- std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
- MPI_Recv(bytes.get(), tensor_bytes, MPI_UINT8_T, status.MPI_SOURCE, 0,
- MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+ MPI_COMM_WORLD, MPI_STATUS_IGNORE);
std::uint64_t actual_tensor_bytes = data.elementSize() * data.numel();
- if (actual_tensor_bytes != tensor_bytes)
+ if (actual_tensor_bytes == tensor_bytes) {
+ MPI_Recv(data.data(), tensor_bytes, MPI_UINT8_T, MPI_ANY_SOURCE, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
+ } else {
+ // sizes do not match: complete the receive into a scratch buffer, then fail
+ std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
+ MPI_Recv(bytes.get(), tensor_bytes, MPI_UINT8_T, MPI_ANY_SOURCE, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
throw std::logic_error("tensor sizes does not match");
-
- memcpy(data.data(), bytes.get(), tensor_bytes);
+ }
}
-void DataChannelMPI::receive(thpp::Tensor& data, int src_rank) {
+void DataChannelMPI::receive(thpp::Tensor& data, rank_type src_rank) {
if (!data.isContiguous())
throw std::logic_error("tensor to receive is not contiguous");
std::uint64_t tensor_bytes;
MPI_Recv(&tensor_bytes, 1, MPI_UINT64_T, src_rank, 0, MPI_COMM_WORLD,
- MPI_STATUS_IGNORE);
-
- std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
- MPI_Recv(bytes.get(), tensor_bytes, MPI_UINT8_T, src_rank, 0, MPI_COMM_WORLD,
- MPI_STATUS_IGNORE);
+ MPI_STATUS_IGNORE);
std::uint64_t actual_tensor_bytes = data.elementSize() * data.numel();
- if (actual_tensor_bytes != tensor_bytes)
+ if (actual_tensor_bytes == tensor_bytes) {
+ MPI_Recv(data.data(), tensor_bytes, MPI_UINT8_T, src_rank, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
+ } else {
+ // sizes do not match: complete the receive into a scratch buffer, then fail
+ std::unique_ptr<std::uint8_t[]> bytes(new std::uint8_t[tensor_bytes]);
+ MPI_Recv(bytes.get(), tensor_bytes, MPI_UINT8_T, src_rank, 0, MPI_COMM_WORLD,
+ MPI_STATUS_IGNORE);
throw std::logic_error("tensor sizes does not match");
-
- memcpy(data.data(), bytes.get(), tensor_bytes);
+ }
}
@@ -408,7 +419,7 @@
DataChannelMPI::RequestMPI* DataChannelMPI::isend(thpp::Tensor& data,
- int dst_rank) {
+ rank_type dst_rank) {
if (!data.isContiguous())
throw std::logic_error("tensor to send is not contiguous");
@@ -436,7 +447,7 @@
DataChannelMPI::RequestMPI* DataChannelMPI::ireceive(thpp::Tensor& data,
- int src_rank) {
+ rank_type src_rank) {
/*
* This function does **NOT** perform length and size checking. It assumes that
* someone is using this very carefully.
@@ -467,12 +478,13 @@
return request;
}
-THDGroup DataChannelMPI::newGroup(const std::vector<int>& ranks) {
+THDGroup DataChannelMPI::newGroup(const std::vector<rank_type>& ranks) {
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
MPI_Group ranks_group;
- MPI_Group_incl(world_group, ranks.size(), ranks.data(), &ranks_group);
+ std::vector<int> int_ranks(ranks.begin(), ranks.end());
+ MPI_Group_incl(world_group, int_ranks.size(), int_ranks.data(), &ranks_group);
MPI_Comm new_comm;
MPI_Comm_create_group(MPI_COMM_WORLD, ranks_group, 0, &new_comm);
@@ -489,11 +501,11 @@
std::unique_ptr<int[]> all_mapping_ranks(new int[2 * size]);
MPI_Allgather(&mapping_ranks, 2, MPI_INT, all_mapping_ranks.get(), 2,
- MPI_INT, new_comm);
+ MPI_INT, new_comm);
// this vector maps new ranks to ranks in COMM_WORLD (global ranks)
- std::vector<int> new_ranks(size);
- for (size_t i = 0; i < 2 * size; i += 2)
+ std::vector<rank_type> new_ranks(size);
+ for (std::size_t i = 0; i < 2 * size; i += 2)
new_ranks[all_mapping_ranks[i]] = all_mapping_ranks[i + 1];
new_group = DataChannel::Group(new_ranks, _num_processes - 1);
diff --git a/torch/lib/THD/base/data_channels/DataChannelMPI.hpp b/torch/lib/THD/base/data_channels/DataChannelMPI.hpp
index df5d547..0f333b7 100644
--- a/torch/lib/THD/base/data_channels/DataChannelMPI.hpp
+++ b/torch/lib/THD/base/data_channels/DataChannelMPI.hpp
@@ -34,38 +34,38 @@
bool init() override;
- int getRank() override;
- int getNumProcesses() override;
+ rank_type getRank() override;
+ rank_type getNumProcesses() override;
void allGather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
THDGroup group_id = THDGroupWORLD) override;
void gather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
- int dst_rank, THDGroup group_id = THDGroupWORLD) override;
+ rank_type dst_rank, THDGroup group_id = THDGroupWORLD) override;
void scatter(std::vector<thpp::Tensor*>& input, thpp::Tensor& output,
- int src_rank, THDGroup group_id = THDGroupWORLD) override;
+ rank_type src_rank, THDGroup group_id = THDGroupWORLD) override;
void allReduce(thpp::Tensor& data, THDReduceOp operation,
THDGroup group_id = THDGroupWORLD) override;
- void reduce(thpp::Tensor& data, THDReduceOp operation, int dst_rank,
+ void reduce(thpp::Tensor& data, THDReduceOp operation, rank_type dst_rank,
THDGroup group_id = THDGroupWORLD) override;
- void broadcast(thpp::Tensor& data, int src_rank,
+ void broadcast(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id = THDGroupWORLD) override;
- void send(const Scalar& data, int dst_rank) override;
- void send(thpp::Tensor& data, int dst_rank) override;
- void receive(Scalar& data, int src_rank) override;
+ void send(const Scalar& data, rank_type dst_rank) override;
+ void send(thpp::Tensor& data, rank_type dst_rank) override;
+ void receive(Scalar& data, rank_type src_rank) override;
void receive(thpp::Tensor& data) override;
- void receive(thpp::Tensor& data, int src_rank) override;
- RequestMPI* isend(thpp::Tensor& data, int dst_rank) override;
- RequestMPI* ireceive(thpp::Tensor& data, int src_rank) override;
+ void receive(thpp::Tensor& data, rank_type src_rank) override;
+ RequestMPI* isend(thpp::Tensor& data, rank_type dst_rank) override;
+ RequestMPI* ireceive(thpp::Tensor& data, rank_type src_rank) override;
void barrier(THDGroup group_id = THDGroupWORLD) override;
- THDGroup newGroup(const std::vector<int>& ranks) override;
+ THDGroup newGroup(const std::vector<rank_type>& ranks) override;
private:
- void _broadcastPack(thpp::Tensor& data, int src_rank, MPI_Comm comm) const;
- void _broadcastUnpack(thpp::Tensor& data, int src_rank, MPI_Comm comm) const;
+ void _broadcastPack(thpp::Tensor& data, rank_type src_rank, MPI_Comm comm) const;
+ void _broadcastUnpack(thpp::Tensor& data, rank_type src_rank, MPI_Comm comm) const;
- int _rank; // Current process' rank
- int _num_processes; // Number of processes in network
+ rank_type _rank; // Current process' rank
+ rank_type _num_processes; // Number of processes in network
// Existing groups of processes with assigned MPI communicator
// and corresponding group ids
diff --git a/torch/lib/THD/base/data_channels/DataChannelTCP.cpp b/torch/lib/THD/base/data_channels/DataChannelTCP.cpp
index e7c6bc3..b558d85 100644
--- a/torch/lib/THD/base/data_channels/DataChannelTCP.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelTCP.cpp
@@ -1,5 +1,4 @@
#include "DataChannelTCP.hpp"
-#include "../ChannelUtils.hpp"
#include <sys/poll.h>
#include <unistd.h>
@@ -14,113 +13,32 @@
#include <system_error>
-<<<<<<< HEAD:torch/lib/THD/base/data_channels/DataChannelTCP.cpp
-=======
-
-#ifndef MSG_MORE // OS X does not have this optimalization option
-#define MSG_MORE 0
-#endif
-
-#define SYSCHECK(expr) { \
- errno = 0; (expr); \
- if (errno != 0) throw std::system_error(errno, std::system_category()); \
-}
-
->>>>>>> Tweaks, fixes, cleanup in DataChannelTCP:torch/lib/THD/base/channels/DataChannelTCP.cpp
namespace thd {
namespace {
-constexpr int MASTER_RANK = 0;
-<<<<<<< HEAD:torch/lib/THD/base/data_channels/DataChannelTCP.cpp
-=======
-constexpr int LISTEN_QUEUE_SIZE = 64;
+constexpr rank_type MASTER_RANK = 0;
-template<typename T>
-void send_bytes(int socket, const T* buffer, std::size_t length,
- bool more_data = false)
-{
- std::size_t bytes_to_send = sizeof(T) * length;
- if (bytes_to_send == 0)
- return;
-
- int flags = 0;
- if (more_data) { // there is more data to send
- flags |= MSG_MORE;
- }
-
- auto bytes = reinterpret_cast<const std::uint8_t*>(buffer);
- std::uint8_t *current_bytes = const_cast<std::uint8_t*>(bytes);
-
- while (bytes_to_send > 0) {
- ssize_t bytes_sent;
- SYSCHECK(bytes_sent = ::send(socket, current_bytes, bytes_to_send, flags))
- if (bytes_sent == 0)
- throw std::system_error(EBADMSG, std::system_category());
-
- bytes_to_send -= bytes_sent;
- current_bytes += bytes_sent;
- }
-}
-
-
-template<typename T>
-void recv_bytes(int socket, T* buffer, std::size_t length)
-{
- std::size_t bytes_to_receive = sizeof(T) * length;
- if (bytes_to_receive == 0)
- return;
-
- auto bytes = reinterpret_cast<std::uint8_t*>(buffer);
- std::uint8_t *current_bytes = bytes;
-
- while (bytes_to_receive > 0) {
- ssize_t bytes_received;
- SYSCHECK(bytes_received = ::recv(socket, current_bytes, bytes_to_receive, 0))
- if (bytes_received == 0)
- throw std::system_error(EBADMSG, std::system_category());
-
- bytes_to_receive -= bytes_received;
- current_bytes += bytes_received;
- }
-}
-
-
-inline bool validatePort(int port) {
- return (port > 0 && port < 65536);
-}
->>>>>>> Tweaks, fixes, cleanup in DataChannelTCP:torch/lib/THD/base/channels/DataChannelTCP.cpp
-
-
-inline int log2ceil(std::uint32_t value) {
- int dim = 0;
+inline std::uint32_t log2ceil(std::uint32_t value) {
+ std::uint32_t dim = 0;
#if defined(__GNUC__)
if (value <= 1)
return 0;
dim = 32 - __builtin_clz(value - 1);
#else
- for (int size = 1; size < value; ++dim, size <<= 1) /* empty */;
+ for (std::uint32_t size = 1; size < value; ++dim, size <<= 1) /* empty */;
#endif // defined(__GNUC__)
return dim;
}
-<<<<<<< HEAD:torch/lib/THD/base/data_channels/DataChannelTCP.cpp
-=======
// Finds nearest power-of-two less than or equal to `value`.
template<typename T>
-inline int pow2(T value) {
- T pof2 = 1;
+inline std::uint64_t pow2(T value) {
+ std::uint64_t pof2 = 1;
while (pof2 <= value) { pof2 <<= 1; }
pof2 >>= 1;
return pof2;
}
-void setSocketNoDelay(int socket) {
- int flag = 1;
- socklen_t optlen = sizeof(flag);
- SYSCHECK(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
-}
-
->>>>>>> Tweaks, fixes, cleanup in DataChannelTCP:torch/lib/THD/base/channels/DataChannelTCP.cpp
} // namespace
@@ -156,7 +74,7 @@
_rank = load_rank_env();
if (_rank == MASTER_RANK) { // MASTER
- std::uint32_t processes_number;
+ rank_type processes_number;
std::tie(_port, processes_number) = load_master_env();
_processes.resize(processes_number);
@@ -168,7 +86,7 @@
};
} else { // WORKER
std::string address;
- std::uint16_t port;
+ port_type port;
std::tie(address, port) = load_worker_env();
// add master
@@ -202,34 +120,28 @@
std::tie(_socket, _port) = listen();
-<<<<<<< HEAD:torch/lib/THD/base/data_channels/DataChannelTCP.cpp
- send_bytes<std::uint32_t>(master_socket, &_rank, 1);
- send_bytes<std::uint16_t>(master_socket, &_port, 1); // send listening port to master
-=======
- std::uint32_t p_rank = (std::uint32_t)_rank;
- std::uint16_t p_port = (std::uint16_t)_port;
- send_bytes<std::uint32_t>(master_socket, &p_rank, 1, true);
- send_bytes<std::uint16_t>(master_socket, &p_port, 1); // send listening port to master
->>>>>>> Tweaks, fixes, cleanup in DataChannelTCP:torch/lib/THD/base/channels/DataChannelTCP.cpp
+ send_bytes<rank_type>(master_socket, &_rank, 1, true);
+ send_bytes<port_type>(master_socket, &_port, 1); // send listening port to master
- std::uint32_t processes_number;
- recv_bytes<std::uint32_t>(master_socket, &processes_number, 1);
+ rank_type processes_number;
+ recv_bytes<rank_type>(master_socket, &processes_number, 1);
_processes.resize(processes_number);
// get all metadata of other processes in network
processes_number--; // exclude master
while (processes_number > 0) {
- std::uint32_t p_rank, p_address_len;
- std::uint16_t p_port;
+ std::uint32_t p_address_len;
+ rank_type p_rank;
+ port_type p_port;
- recv_bytes<std::uint32_t>(master_socket, &p_rank, 1); // get process rank
+ recv_bytes<rank_type>(master_socket, &p_rank, 1); // get process rank
recv_bytes<std::uint32_t>(master_socket, &p_address_len, 1); // get process address length
// get process address
std::unique_ptr<char[]> tmp_address(new char[p_address_len + 1]);
recv_bytes<char>(master_socket, tmp_address.get(), p_address_len);
- recv_bytes<std::uint16_t>(master_socket, &p_port, 1); // get process port
+ recv_bytes<port_type>(master_socket, &p_port, 1); // get process port
_processes[p_rank] = {
.rank = p_rank,
@@ -249,22 +161,21 @@
* trying to connect.
*/
- for (std::uint32_t r = 1; r < _rank; ++r) {
+ for (rank_type r = 1; r < _rank; ++r) {
auto& process = _processes[r];
process.socket = connect(process.address, process.port);
// send rank to tell to the accepting process who we are
- std::uint32_t p_rank = static_cast<std::uint32_t>(_rank);
- send_bytes<std::uint32_t>(process.socket, &p_rank, 1);
+ send_bytes<rank_type>(process.socket, &_rank, 1);
}
- for (std::uint32_t i = _rank + 1; i < _processes.size(); ++i) {
+ for (rank_type i = _rank + 1; i < _processes.size(); ++i) {
int socket;
std::tie(socket, std::ignore) = accept(_socket, _timeout);
// get rank of process we have just accepted
- std::uint32_t p_rank;
- recv_bytes<std::uint32_t>(socket, &p_rank, 1);
+ rank_type p_rank;
+ recv_bytes<rank_type>(socket, &p_rank, 1);
_processes[p_rank].socket = socket;
}
@@ -281,16 +192,16 @@
std::tie(_socket, std::ignore) = listen(_port);
// wait for all workers to connect
- int workers = _processes.size() - 1;
+ std::size_t workers = _processes.size() - 1;
while (workers > 0) {
std::string p_address;
int p_socket;
std::tie(p_socket, p_address) = accept(_socket, _timeout);
- std::uint32_t p_rank;
- std::uint16_t p_port;
- recv_bytes<std::uint32_t>(p_socket, &p_rank, 1);
- recv_bytes<std::uint16_t>(p_socket, &p_port, 1);
+ rank_type p_rank;
+ port_type p_port;
+ recv_bytes<rank_type>(p_socket, &p_rank, 1);
+ recv_bytes<port_type>(p_socket, &p_port, 1);
if (p_rank >= _processes.size()) {
throw std::out_of_range(
@@ -320,17 +231,17 @@
for (const auto& worker : _processes) {
if (worker.rank == _rank) continue;
- std::uint32_t processes_number = _processes.size();
- send_bytes<std::uint32_t>(worker.socket, &processes_number, 1, true);
+ rank_type processes_number = _processes.size();
+ send_bytes<rank_type>(worker.socket, &processes_number, 1, true);
for (auto& process : _processes) {
if (process.rank == _rank) continue;
std::uint32_t proc_address_length = process.address.size();
- send_bytes<std::uint32_t>(worker.socket, &process.rank, 1, true);
+ send_bytes<rank_type>(worker.socket, &process.rank, 1, true);
send_bytes<std::uint32_t>(worker.socket, &proc_address_length, 1, true);
send_bytes<char>(worker.socket, process.address.data(), proc_address_length, true);
- send_bytes<std::uint16_t>(worker.socket, &(process.port), 1);
+ send_bytes<port_type>(worker.socket, &process.port, 1);
}
}
@@ -345,9 +256,9 @@
bool DataChannelTCP::init() {
bool ok = (_rank == MASTER_RANK ? initMaster() : initWorker());
if (ok) {
- std::vector<int> ranks;
+ std::vector<rank_type> ranks;
ranks.reserve(_processes.size());
- for (size_t rank = 0; rank < _processes.size(); ++rank)
+ for (rank_type rank = 0; rank < _processes.size(); ++rank)
ranks.push_back(rank);
_groups.insert({
@@ -360,12 +271,12 @@
}
-int DataChannelTCP::getRank() {
+rank_type DataChannelTCP::getRank() {
return _rank;
}
-int DataChannelTCP::getNumProcesses() {
+rank_type DataChannelTCP::getNumProcesses() {
return _processes.size();
}
@@ -413,7 +324,7 @@
void DataChannelTCP::gather(std::vector<thpp::Tensor*>& output,
- thpp::Tensor& input, int dst_rank, THDGroup group_id) {
+ thpp::Tensor& input, rank_type dst_rank, THDGroup group_id) {
const auto& group = _groups.at(group_id);
bool exists;
@@ -445,7 +356,7 @@
void DataChannelTCP::scatter(std::vector<thpp::Tensor*>& input,
- thpp::Tensor& output, int src_rank,
+ thpp::Tensor& output, rank_type src_rank,
THDGroup group_id) {
const auto& group = _groups.at(group_id);
bool exists;
@@ -502,7 +413,7 @@
return;
std::uint64_t tensor_bytes = data.elementSize() * data.numel();
- auto tmp_tensor = data.clone();
+ auto tmp_tensor = std::unique_ptr<thpp::Tensor>(data.clone());
auto pof2 = pow2(group.size());
int rem = group.size() - pof2;
@@ -543,8 +454,6 @@
}
}
- delete tmp_tensor;
-
if (group_rank < 2 * rem) {
if (group_rank % 2) {
send(data, group.mustGetGlobalRank(group_rank - 1));
@@ -556,7 +465,7 @@
void DataChannelTCP::reduce(thpp::Tensor& data, THDReduceOp operation,
- int dst_rank, THDGroup group_id) {
+ rank_type dst_rank, THDGroup group_id) {
/*
* Idea of this algorithm is similar to broadcast but with reversed
* order and direction of communication.
@@ -572,9 +481,9 @@
auto group_dst_rank = group.mustGetGroupRank(dst_rank);
int dim = log2ceil(group.size());
- rank_type virtual_rank = ((group.size() - group_dst_rank) + group_rank) % group.size();
+ rank_type virtual_rank = (group_rank + group.size() - group_dst_rank) % group.size();
long long mask = 0;
- auto result_tensor = data.clone();
+ auto result_tensor = std::unique_ptr<thpp::Tensor>(data.clone());
for (int k = 0; k <= dim - 1; mask ^= (1 << k), ++k) {
if ((virtual_rank & mask) == 0) {
@@ -594,12 +503,10 @@
if (_rank == dst_rank)
std::memcpy(data.data(), result_tensor->data(), data.elementSize() * data.numel());
-
- delete result_tensor;
}
-void DataChannelTCP::broadcast(thpp::Tensor& data, int src_rank,
+void DataChannelTCP::broadcast(thpp::Tensor& data, rank_type src_rank,
THDGroup group_id) {
/*
* General idea of this algorithm is to send data in `d` dimensional
@@ -621,7 +528,7 @@
auto group_src_rank = group.mustGetGroupRank(src_rank);
int dim = log2ceil(group.size());
- rank_type virtual_rank = ((group.size() - group_src_rank) + group_rank) % group.size();
+ rank_type virtual_rank = (group_rank + group.size() - group_src_rank) % group.size();
long long mask = (1 << dim) - 1;
for (int k = dim - 1; k >= 0; --k) {
@@ -642,7 +549,7 @@
}
-void DataChannelTCP::send(const Scalar& data, int dst_rank) {
+void DataChannelTCP::send(const Scalar& data, rank_type dst_rank) {
auto request = _send_worker.push([this, &data, dst_rank]{
this->_send(data, dst_rank);
});
@@ -650,7 +557,7 @@
}
-void DataChannelTCP::send(thpp::Tensor& data, int dst_rank) {
+void DataChannelTCP::send(thpp::Tensor& data, rank_type dst_rank) {
auto request = _send_worker.push([this, &data, dst_rank]{
this->_send(data, dst_rank);
});
@@ -658,7 +565,7 @@
}
-void DataChannelTCP::receive(Scalar& data, int src_rank) {
+void DataChannelTCP::receive(Scalar& data, rank_type src_rank) {
auto request = _receive_worker.push([this, &data, src_rank]{
this->_receive(data, src_rank);
});
@@ -701,7 +608,7 @@
}
-void DataChannelTCP::receive(thpp::Tensor& data, int src_rank) {
+void DataChannelTCP::receive(thpp::Tensor& data, rank_type src_rank) {
auto request = _receive_worker.push([this, &data, src_rank]{
this->_receive(data, src_rank);
});
@@ -710,7 +617,7 @@
DataChannelTCP::RequestTCP* DataChannelTCP::isend(thpp::Tensor& data,
- int dst_rank) {
+ rank_type dst_rank) {
std::shared_ptr<thpp::Tensor> copy_tensor(data.clone_shallow());
auto request = _send_worker.push([this, copy_tensor, dst_rank]{
this->_send(*copy_tensor, dst_rank);
@@ -720,7 +627,7 @@
DataChannelTCP::RequestTCP* DataChannelTCP::ireceive(thpp::Tensor& data,
- int src_rank) {
+ rank_type src_rank) {
std::shared_ptr<thpp::Tensor> copy_tensor(data.clone_shallow());
auto request = _receive_worker.push([this, copy_tensor, src_rank]{
this->_receive(*copy_tensor, src_rank);
@@ -765,7 +672,7 @@
}
-THDGroup DataChannelTCP::newGroup(const std::vector<int>& ranks) {
+THDGroup DataChannelTCP::newGroup(const std::vector<rank_type>& ranks) {
auto new_group = DataChannel::Group(ranks, _processes.size() - 1);
THDGroup new_group_id = static_cast<THDGroup>(_groups.size());
@@ -774,15 +681,12 @@
}
-void DataChannelTCP::_send(const Scalar& data, int dst_rank) {
+void DataChannelTCP::_send(const Scalar& data, rank_type dst_rank) {
/*
* We have to check if dst_rank is positive to properly use `.at` function in vector.
* Not checking that can result in int overflow and strange errors.
*/
- if (dst_rank < 0)
- throw std::out_of_range("destination rank is invalid (< 0)");
-
const auto& process_dst = _processes.at(dst_rank);
if (process_dst.rank == _rank)
throw std::logic_error("cannot send scalar to process with same rank");
@@ -800,15 +704,12 @@
}
-void DataChannelTCP::_send(thpp::Tensor& data, int dst_rank) {
+void DataChannelTCP::_send(thpp::Tensor& data, rank_type dst_rank) {
/*
* We have to check if dst_rank is positive to properly use `.at` function in vector.
* Not checking that can result in int overflow and strange errors.
*/
- if (dst_rank < 0)
- throw std::out_of_range("destination rank is invalid (< 0)");
-
const auto& process_dst = _processes.at(dst_rank);
if (process_dst.rank == _rank)
throw std::logic_error("cannot send tensor to process with same rank");
@@ -829,15 +730,12 @@
}
-void DataChannelTCP::_receive(Scalar& data, int src_rank) {
+void DataChannelTCP::_receive(Scalar& data, rank_type src_rank) {
/*
* We have to check if src_rank is positive to properly use `.at` function in vector.
* Not checking that can result in int overflow and strange errors.
*/
- if (src_rank < 0)
- throw std::out_of_range("source rank is invalid (< 0)");
-
const auto& process_src = _processes.at(src_rank);
if (process_src.rank == _rank)
throw std::logic_error("cannot receive scalar from process with same rank");
@@ -862,15 +760,12 @@
}
-void DataChannelTCP::_receive(thpp::Tensor& data, int src_rank) {
+void DataChannelTCP::_receive(thpp::Tensor& data, rank_type src_rank) {
/*
* We have to check if src_rank is positive to properly use `.at` function in vector.
* Not checking that can result in int overflow and strange errors.
*/
- if (src_rank < 0)
- throw std::out_of_range("source rank is invalid (< 0)");
-
const auto& process_src = _processes.at(src_rank);
if (process_src.rank == _rank)
throw std::logic_error("cannot receive tensor from process with same rank");
diff --git a/torch/lib/THD/base/data_channels/DataChannelTCP.hpp b/torch/lib/THD/base/data_channels/DataChannelTCP.hpp
index 432d376..8df2870 100644
--- a/torch/lib/THD/base/data_channels/DataChannelTCP.hpp
+++ b/torch/lib/THD/base/data_channels/DataChannelTCP.hpp
@@ -14,7 +14,6 @@
namespace thd {
struct DataChannelTCP : DataChannel {
- using rank_type = DataChannel::Group::rank_type;
struct RequestTCP : DataChannel::Request {
RequestTCP(QueueWorker::Request&& request);
@@ -33,56 +32,56 @@
bool init() override;
- int getRank() override;
- int getNumProcesses() override;
+ rank_type getRank() override;
+ rank_type getNumProcesses() override;
void allGather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
THDGroup group_id = THDGroupWORLD) override;
void gather(std::vector<thpp::Tensor*>& output, thpp::Tensor& input,
- int dst_rank, THDGroup group_id = THDGroupWORLD) override;
+ rank_type dst_rank, THDGroup group_id = THDGroupWORLD) override;
void scatter(std::vector<thpp::Tensor*>& input, thpp::Tensor& output,
- int src_rank, THDGroup group_id = THDGroupWORLD) override;
+ rank_type src_rank, THDGroup group_id = THDGroupWORLD) override;
void allReduce(thpp::Tensor& data, THDReduceOp operation,
THDGroup group_id = THDGroupWORLD) override;
- void reduce(thpp::Tensor& data, THDReduceOp operation, int dst_rank,
+ void reduce(thpp::Tensor& data, THDReduceOp operation, rank_type dst_rank,
THDGroup group_id = THDGroupWORLD) override;
- void broadcast(thpp::Tensor& data, int src_id,
+ void broadcast(thpp::Tensor& data, rank_type src_id,
THDGroup group_id = THDGroupWORLD) override;
- void send(const Scalar& data, int dst_id) override;
- void send(thpp::Tensor& data, int dst_id) override;
- void receive(Scalar& data, int src_id) override;
+ void send(const Scalar& data, rank_type dst_id) override;
+ void send(thpp::Tensor& data, rank_type dst_id) override;
+ void receive(Scalar& data, rank_type src_id) override;
void receive(thpp::Tensor& data) override;
- void receive(thpp::Tensor& data, int src_id) override;
- RequestTCP* isend(thpp::Tensor& data, int dst_rank) override;
- RequestTCP* ireceive(thpp::Tensor& data, int src_rank) override;
+ void receive(thpp::Tensor& data, rank_type src_id) override;
+ RequestTCP* isend(thpp::Tensor& data, rank_type dst_rank) override;
+ RequestTCP* ireceive(thpp::Tensor& data, rank_type src_rank) override;
void barrier(THDGroup group_id = THDGroupWORLD) override;
- THDGroup newGroup(const std::vector<int>& ranks) override;
+ THDGroup newGroup(const std::vector<rank_type>& ranks) override;
private:
// Defines process to which master or worker is connected
struct Process {
- std::uint32_t rank;
+ rank_type rank;
std::string address;
- std::uint16_t port;
+ port_type port;
int socket;
};
bool initMaster();
bool initWorker();
- void _send(const Scalar& data, int dst_id);
- void _send(thpp::Tensor& data, int dst_id);
- void _receive(Scalar& data, int src_id);
- void _receive(thpp::Tensor& data, int src_id);
+ void _send(const Scalar& data, rank_type dst_id);
+ void _send(thpp::Tensor& data, rank_type dst_id);
+ void _receive(Scalar& data, rank_type src_id);
+ void _receive(thpp::Tensor& data, rank_type src_id);
void _reduce(thpp::Tensor& result, thpp::Tensor& data,
THDReduceOp operation) const;
- std::uint32_t _rank; // Rank of current process, range: [0.._processes.size()-1]
+ rank_type _rank; // Rank of current process, range: [0.._processes.size()-1]
int _socket; // Socket on which process is listening
- std::uint16_t _port; // Port on which process is listening
+ port_type _port; // Port on which process is listening
int _timeout; // Accept waiting timeout in milliseconds (it is optional, default = infinity)
std::vector<Process> _processes; // Other processes in network
diff --git a/torch/lib/THD/master_worker/common/CommandChannel.cpp b/torch/lib/THD/master_worker/common/CommandChannel.cpp
index d07151b..4fb5ac1 100644
--- a/torch/lib/THD/master_worker/common/CommandChannel.cpp
+++ b/torch/lib/THD/master_worker/common/CommandChannel.cpp
@@ -19,7 +19,7 @@
auto& bytes = msg.get()->bytes();
std::uint64_t msg_length = static_cast<std::uint64_t>(bytes.length());
- send_bytes<std::uint64_t>(socket, &msg_length, 1);
+ send_bytes<std::uint64_t>(socket, &msg_length, 1, true);
send_bytes<std::uint8_t>(
socket,
reinterpret_cast<const std::uint8_t*>(bytes.data()),
@@ -44,7 +44,7 @@
MasterCommandChannel::MasterCommandChannel()
: _rank(0)
{
- std::uint32_t world_size;
+ rank_type world_size;
std::tie(_port, world_size) = load_master_env();
_sockets.resize(world_size);
@@ -62,10 +62,10 @@
std::tie(_sockets[0], std::ignore) = listen(_port);
int socket;
- std::uint32_t rank;
+ rank_type rank;
for (std::size_t i = 1; i < _sockets.size(); ++i) {
std::tie(socket, std::ignore) = accept(_sockets[0]);
- recv_bytes<std::uint32_t>(socket, &rank, 1);
+ recv_bytes<rank_type>(socket, &rank, 1);
_sockets.at(rank) = socket;
}
@@ -114,7 +114,7 @@
bool WorkerCommandChannel::init() {
_socket = connect(_master_addr, _master_port);
- send_bytes<std::uint32_t>(_socket, &_rank, 1); // send rank
+ send_bytes<rank_type>(_socket, &_rank, 1); // send rank
std::uint8_t confirm_byte;
recv_bytes<std::uint8_t>(_socket, &confirm_byte, 1);
diff --git a/torch/lib/THD/master_worker/common/CommandChannel.hpp b/torch/lib/THD/master_worker/common/CommandChannel.hpp
index 212f8fa..3d24ea5 100644
--- a/torch/lib/THD/master_worker/common/CommandChannel.hpp
+++ b/torch/lib/THD/master_worker/common/CommandChannel.hpp
@@ -18,10 +18,10 @@
void sendMessage(std::unique_ptr<rpc::RPCMessage> msg, int rank);
private:
- std::uint32_t _rank;
+ rank_type _rank;
std::vector<int> _sockets;
- std::uint16_t _port;
+ port_type _port;
};
struct WorkerCommandChannel {
@@ -34,11 +34,11 @@
void sendMessage(std::unique_ptr<rpc::RPCMessage> msg);
private:
- std::uint32_t _rank;
+ rank_type _rank;
int _socket;
std::string _master_addr;
- std::uint16_t _master_port;
+ port_type _master_port;
};
} // namespace thd
diff --git a/torch/lib/THD/master_worker/master/THDStorage.cpp b/torch/lib/THD/master_worker/master/THDStorage.cpp
index 57c68d4..09812ff 100644
--- a/torch/lib/THD/master_worker/master/THDStorage.cpp
+++ b/torch/lib/THD/master_worker/master/THDStorage.cpp
@@ -8,6 +8,7 @@
#include <THPP/Traits.hpp>
+#include <cstring>
#include <memory>
#include "master_worker/master/generic/THDStorage.cpp"
diff --git a/torch/lib/THD/master_worker/master/THDTensor.cpp b/torch/lib/THD/master_worker/master/THDTensor.cpp
index f30f5bf..691f453 100644
--- a/torch/lib/THD/master_worker/master/THDTensor.cpp
+++ b/torch/lib/THD/master_worker/master/THDTensor.cpp
@@ -8,6 +8,7 @@
#include <THPP/Traits.hpp>
+#include <cstring>
#include <memory>
#include "master_worker/master/generic/THDTensor.cpp"
diff --git a/torch/lib/THD/process_group/Collectives.cpp b/torch/lib/THD/process_group/Collectives.cpp
index cdfc9e8..984bb7a 100644
--- a/torch/lib/THD/process_group/Collectives.cpp
+++ b/torch/lib/THD/process_group/Collectives.cpp
@@ -1,16 +1,17 @@
#include "Collectives.hpp"
#include "General.hpp"
+#include "../base/ChannelUtils.hpp"
#include <vector>
using namespace thd;
int THDGetRank() {
- return dataChannel->getRank();
+ return static_cast<int>(dataChannel->getRank());
}
int THDGetNumProcesses() {
- return dataChannel->getNumProcesses();
+ return static_cast<int>(dataChannel->getNumProcesses());
}
void THDAllReduce(THDTensorDescriptor* desc, THDReduceOp operation, THDGroup group) {
@@ -19,23 +20,38 @@
void THDReduce(THDTensorDescriptor* desc, THDReduceOp operation,
int dst_rank, THDGroup group) {
- dataChannel->reduce(*desc, operation, dst_rank, group);
+ if (dst_rank < 0)
+ throw std::domain_error("dst_rank should not be negative");
+
+ dataChannel->reduce(*desc, operation, static_cast<rank_type>(dst_rank), group);
}
void THDBroadcast(THDTensorDescriptor* desc, int src_rank, THDGroup group) {
- dataChannel->broadcast(*desc, src_rank, group);
+ if (src_rank < 0)
+ throw std::domain_error("src_rank should not be negative");
+
+ dataChannel->broadcast(*desc, static_cast<rank_type>(src_rank), group);
}
THDRequest* THDIsend(THDTensorDescriptor* desc, int dst_rank) {
- return dataChannel->isend(*desc, dst_rank);
+ if (dst_rank < 0)
+ throw std::domain_error("dst_rank should not be negative");
+
+ return dataChannel->isend(*desc, static_cast<rank_type>(dst_rank));
}
THDRequest* THDIrecv(THDTensorDescriptor* desc, int src_rank) {
- return dataChannel->ireceive(*desc, src_rank);
+ if (src_rank < 0)
+ throw std::domain_error("src_rank should not be negative");
+
+ return dataChannel->ireceive(*desc, static_cast<rank_type>(src_rank));
}
void THDSend(THDTensorDescriptor* desc, int dst_rank) {
- dataChannel->send(*desc, dst_rank);
+ if (dst_rank < 0)
+ throw std::domain_error("dst_rank should not be negative");
+
+ dataChannel->send(*desc, static_cast<rank_type>(dst_rank));
}
void THDRecvAnySource(THDTensorDescriptor* desc) {
@@ -43,7 +59,10 @@
}
void THDRecv(THDTensorDescriptor* desc, int src_rank) {
- dataChannel->receive(*desc, src_rank);
+ if (src_rank < 0)
+ throw std::domain_error("src_rank should not be negative");
+
+ dataChannel->receive(*desc, static_cast<rank_type>(src_rank));
}
void THDAllGather(THDTensorDescriptor** output, size_t len,
@@ -53,8 +72,11 @@
}
void THDGatherSend(THDTensorDescriptor* input, int dst_rank, THDGroup group) {
+ if (dst_rank < 0)
+ throw std::domain_error("dst_rank should not be negative");
+
std::vector<thpp::Tensor*> v_output;
- dataChannel->gather(v_output, *input, dst_rank, group);
+ dataChannel->gather(v_output, *input, static_cast<rank_type>(dst_rank), group);
}
void THDGatherRecv(THDTensorDescriptor** output, size_t len,
@@ -70,8 +92,11 @@
}
void THDScatterRecv(THDTensorDescriptor* output, int src_rank, THDGroup group) {
+ if (src_rank < 0)
+ throw std::domain_error("src_rank should not be negative");
+
std::vector<thpp::Tensor*> v_input;
- dataChannel->scatter(v_input, *output, src_rank, group);
+ dataChannel->scatter(v_input, *output, static_cast<rank_type>(src_rank), group);
}
void THDBarrier(THDGroup group) {
@@ -79,7 +104,14 @@
}
THDGroup THDNewGroup(const int *ranks, size_t len) {
- std::vector<int> v_ranks(ranks, ranks + len);
+ std::vector<rank_type> v_ranks(len);
+ for (std::size_t i = 0; i < len; ++i) {
+ if (ranks[i] < 0)
+ throw std::domain_error("ranks should not be negative");
+
+ v_ranks[i] = ranks[i];
+ }
+
return dataChannel->newGroup(v_ranks);
}
diff --git a/torch/lib/THD/test/command_channel_smoke.cpp b/torch/lib/THD/test/command_channel_smoke.cpp
index e31e453..47e1a19 100644
--- a/torch/lib/THD/test/command_channel_smoke.cpp
+++ b/torch/lib/THD/test/command_channel_smoke.cpp
@@ -5,6 +5,7 @@
#include <cerrno>
#include <cstdlib>
#include <exception>
+#include <mutex>
#include <string>
#include <system_error>
#include <thread>
@@ -16,8 +17,8 @@
void init_worker(const int& rank, const std::string& master_addr) {
g_mutex.lock();
- setenv("RANK", std::to_string(rank).data(), 1);
- setenv("MASTER_ADDR", master_addr.data(), 1);
+ setenv(RANK_ENV, std::to_string(rank).data(), 1);
+ setenv(MASTER_ADDR_ENV, master_addr.data(), 1);
auto channel = std::make_shared<thd::WorkerCommandChannel>(); // reads all env variable
g_mutex.unlock();
@@ -47,9 +48,9 @@
void init_master(int world_size, const std::string& master_port) {
g_mutex.lock();
- setenv("WORLD_SIZE", std::to_string(world_size).data(), 1);
- setenv("RANK", "0", 1);
- setenv("MASTER_PORT", master_port.data(), 1);
+ setenv(WORLD_SIZE_ENV, std::to_string(world_size).data(), 1);
+ setenv(RANK_ENV, "0", 1);
+ setenv(MASTER_PORT_ENV, master_port.data(), 1);
auto channel = std::make_shared<thd::MasterCommandChannel>(); // reads all env variable
g_mutex.unlock();
diff --git a/torch/lib/THD/test/data_channel_collectives.cpp b/torch/lib/THD/test/data_channel_collectives.cpp
index 652b9fc..9f9f533 100644
--- a/torch/lib/THD/test/data_channel_collectives.cpp
+++ b/torch/lib/THD/test/data_channel_collectives.cpp
@@ -2,6 +2,7 @@
#ifdef WITH_MPI
#include "../base/data_channels/DataChannelMPI.hpp"
#endif // WITH_MPI
+#include "../base/ChannelEnvVars.hpp"
#include "TestUtils.hpp"
#include <THPP/tensors/THTensor.hpp>
@@ -80,12 +81,12 @@
void _test_reduce_helper(std::shared_ptr<thd::DataChannel> data_channel,
THDReduceOp op_type, long init_value, long expected_value) {
if (data_channel->getRank() == 0) {
- auto long_tensor = buildTensor<long>({1, 2, 3, 4, 5}, init_value);
- data_channel->reduce(*long_tensor, op_type, 0);
- ASSERT_TENSOR_VALUE(long, *long_tensor, expected_value)
+ auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, init_value);
+ data_channel->reduce(*int_tensor, op_type, 0);
+ ASSERT_TENSOR_VALUE(int, *int_tensor, expected_value)
} else {
- auto long_tensor = buildTensor<long>({1, 2, 3, 4, 5}, data_channel->getRank());
- data_channel->reduce(*long_tensor, op_type, 0);
+ auto int_tensor = buildTensor<int>({1, 2, 3, 4, 5}, data_channel->getRank());
+ data_channel->reduce(*int_tensor, op_type, 0);
}
}
@@ -275,7 +276,7 @@
////////////
void test_broadcast_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
if (contains(group_ranks, data_channel->getRank())) {
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, -1);
if (data_channel->getRank() == group_ranks[0])
@@ -291,7 +292,7 @@
}
void test_reduce_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
if (contains(group_ranks, data_channel->getRank())) {
auto int_tensor = buildTensor({1, 2, 3, 4, 5}, 10);
data_channel->reduce(*int_tensor, THDReduceOp::THDReduceSUM, group_ranks[0], group);
@@ -308,7 +309,7 @@
}
void test_allReduce_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
if (contains(group_ranks, data_channel->getRank())) {
auto int_tensor = buildTensor({1, 2, 3, 4, 5, 6, 7, 100}, 10);
data_channel->allReduce(*int_tensor, THDReduceOp::THDReduceSUM, group);
@@ -321,7 +322,7 @@
}
void test_scatter_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
std::vector<thpp::Tensor*> raw_tensors;
if (contains(group_ranks, data_channel->getRank())) {
@@ -344,7 +345,7 @@
void test_gather_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
std::vector<thpp::Tensor*> raw_tensors;
if (contains(group_ranks, data_channel->getRank())) {
@@ -369,7 +370,7 @@
}
void test_allGather_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
std::vector<std::shared_ptr<thpp::IntTensor>> tensors;
std::vector<thpp::Tensor*> raw_tensors;
if (contains(group_ranks, data_channel->getRank())) {
@@ -390,7 +391,7 @@
}
void test_barrier_group(std::shared_ptr<thd::DataChannel> data_channel,
- THDGroup group, std::vector<int> group_ranks) {
+ THDGroup group, std::vector<thd::rank_type> group_ranks) {
if (contains(group_ranks, data_channel->getRank())) {
for (int i = 0; i < group_ranks.size(); ++i) {
if (data_channel->getRank() == group_ranks[i]) {
@@ -555,7 +556,7 @@
test_irecv(data_channel);
test_interlaces(data_channel);
- std::vector<int> group_ranks = {1, 2};
+ std::vector<thd::rank_type> group_ranks = {1, 2};
THDGroup group = data_channel->newGroup(group_ranks);
test_broadcast_group(data_channel, group, group_ranks);
test_reduce_group(data_channel, group, group_ranks);
@@ -575,9 +576,9 @@
void init_tcp_master(int workers) {
g_mutex.lock();
- setenv("WORLD_SIZE", std::to_string((workers + 1)).data(), 1);
- setenv("RANK", "0", 1);
- setenv("MASTER_PORT", std::to_string(MASTER_PORT).data(), 1);
+ setenv(thd::WORLD_SIZE_ENV, std::to_string((workers + 1)).data(), 1);
+ setenv(thd::RANK_ENV, "0", 1);
+ setenv(thd::MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
auto masterChannel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
@@ -593,8 +594,8 @@
void init_tcp_worker(unsigned int id, int workers) {
g_mutex.lock();
- setenv("RANK", std::to_string(id).data(), 1);
- setenv("MASTER_ADDR", std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
+ setenv(thd::RANK_ENV, std::to_string(id).data(), 1);
+ setenv(thd::MASTER_ADDR_ENV, std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
auto worker_channel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
diff --git a/torch/lib/THD/test/data_channel_tcp_accept_timeout.cpp b/torch/lib/THD/test/data_channel_tcp_accept_timeout.cpp
index d1c44e5..f85eac0 100644
--- a/torch/lib/THD/test/data_channel_tcp_accept_timeout.cpp
+++ b/torch/lib/THD/test/data_channel_tcp_accept_timeout.cpp
@@ -1,4 +1,5 @@
#include "../base/data_channels/DataChannelTCP.hpp"
+#include "../base/ChannelEnvVars.hpp"
#include "TestUtils.hpp"
#include <cassert>
@@ -11,9 +12,9 @@
void master()
{
- setenv("WORLD_SIZE", std::to_string((WORKERS_NUM + 1)).data(), 1);
- setenv("RANK", "0", 1);
- setenv("MASTER_PORT", std::to_string(MASTER_PORT).data(), 1);
+ setenv(thd::WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
+ setenv(thd::RANK_ENV, "0", 1);
+ setenv(thd::MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
auto masterChannel = std::make_shared<thd::DataChannelTCP>(2000); // timeout after 2s
ASSERT_THROWS(std::exception, masterChannel->init())
diff --git a/torch/lib/THD/test/data_channel_tcp_slow_master.cpp b/torch/lib/THD/test/data_channel_tcp_slow_master.cpp
index a4d2b97..efadc37 100644
--- a/torch/lib/THD/test/data_channel_tcp_slow_master.cpp
+++ b/torch/lib/THD/test/data_channel_tcp_slow_master.cpp
@@ -1,4 +1,5 @@
#include "../base/data_channels/DataChannelTCP.hpp"
+#include "../base/ChannelEnvVars.hpp"
#include "TestUtils.hpp"
#include <THPP/tensors/THTensor.hpp>
@@ -18,9 +19,9 @@
void master()
{
g_mutex.lock();
- setenv("WORLD_SIZE", std::to_string((WORKERS_NUM + 1)).data(), 1);
- setenv("RANK", "0", 1);
- setenv("MASTER_PORT", std::to_string(MASTER_PORT).data(), 1);
+ setenv(thd::WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
+ setenv(thd::RANK_ENV, "0", 1);
+ setenv(thd::MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
auto masterChannel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
@@ -41,8 +42,8 @@
void worker(int id)
{
g_mutex.lock();
- setenv("RANK", std::to_string(id).data(), 1);
- setenv("MASTER_ADDR", std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
+ setenv(thd::RANK_ENV, std::to_string(id).data(), 1);
+ setenv(thd::MASTER_ADDR_ENV, std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
auto workerChannel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
diff --git a/torch/lib/THD/test/data_channel_tcp_smoke.cpp b/torch/lib/THD/test/data_channel_tcp_smoke.cpp
index 258cbed..348f690 100644
--- a/torch/lib/THD/test/data_channel_tcp_smoke.cpp
+++ b/torch/lib/THD/test/data_channel_tcp_smoke.cpp
@@ -1,4 +1,5 @@
#include "../base/data_channels/DataChannelTCP.hpp"
+#include "../base/ChannelEnvVars.hpp"
#include <THPP/tensors/THTensor.hpp>
@@ -17,9 +18,9 @@
void master()
{
g_mutex.lock();
- setenv("WORLD_SIZE", std::to_string((WORKERS_NUM + 1)).data(), 1);
- setenv("RANK", "0", 1);
- setenv("MASTER_PORT", std::to_string(MASTER_PORT).data(), 1);
+ setenv(thd::WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
+ setenv(thd::RANK_ENV, "0", 1);
+ setenv(thd::MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
auto masterChannel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
@@ -36,8 +37,8 @@
void worker(int id)
{
g_mutex.lock();
- setenv("RANK", std::to_string(id).data(), 1);
- setenv("MASTER_ADDR", std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
+ setenv(thd::RANK_ENV, std::to_string(id).data(), 1);
+ setenv(thd::MASTER_ADDR_ENV, std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(), 1);
auto workerChannel = std::make_shared<thd::DataChannelTCP>(); // reads all env variable
g_mutex.unlock();
diff --git a/torch/lib/THD/test/tensor_smoke.cpp b/torch/lib/THD/test/tensor_smoke.cpp
index 69d0576..3d6ea9e 100644
--- a/torch/lib/THD/test/tensor_smoke.cpp
+++ b/torch/lib/THD/test/tensor_smoke.cpp
@@ -2,14 +2,13 @@
#include <cassert>
#include <typeinfo>
-// #include "../base/tensors/THTensor.hpp"
#include <THPP/tensors/THTensor.hpp>
using namespace std;
int main() {
- thpp::FloatTensor *tensor = new thpp::THTensor<float>();
+ thpp::FloatTensor *tensor = new thpp::THTensor<float>();
thpp::FloatTensor *tensor2 = new thpp::THTensor<float>();
assert(tensor->nDim() == 0);
@@ -32,7 +31,7 @@
bool thrown = false;
try {
- thpp::IntTensor &a = dynamic_cast<thpp::IntTensor&>(*tensor);
+ thpp::IntTensor &a = dynamic_cast<thpp::IntTensor&>(*tensor);
} catch(std::bad_cast &e) {
thrown = true;
}