AllGather algorithm [CPU]

Summary: Allgather ring CPU implementation. It does |buffers| x |contextSize| passes.

Reviewed By: pietern

Differential Revision: D4723809

fbshipit-source-id: ffd8366ac7e1746555474e173143d33cee497822
diff --git a/gloo/allgather_ring.h b/gloo/allgather_ring.h
new file mode 100644
index 0000000..f2f2faa
--- /dev/null
+++ b/gloo/allgather_ring.h
@@ -0,0 +1,124 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#pragma once
+
+#include <stddef.h>
+#include <string.h>
+
+#include "gloo/algorithm.h"
+#include "gloo/context.h"
+
+namespace gloo {
+
+// AllgatherRing is similar to MPI_Allgather: every process ends up with the
+// input buffers (inPtrs) of every process in the context.
+//
+// outPtrs must hold one pointer per rank. The inputs of rank k are written
+// back to back into outPtrs[k], so each entry needs room for
+// inPtrs.size() * count elements.
+template <typename T>
+class AllgatherRing : public Algorithm {
+ public:
+  // context: supplies this rank's left and right pairs and the slots used.
+  // inPtrs:  local input buffers, each read as `count` elements.
+  // outPtrs: per-rank destination buffers (indexed by rank).
+  // count:   number of elements of type T in every input buffer.
+  AllgatherRing(
+      const std::shared_ptr<Context>& context,
+      const std::vector<T*>& inPtrs,
+      std::vector<T*> outPtrs,
+      int count)
+      : Algorithm(context),
+        inPtrs_(inPtrs),
+        outPtrs_(outPtrs),
+        count_(count),
+        bytes_(count * sizeof(T)),
+        leftPair_(this->getLeftPair()),
+        rightPair_(this->getRightPair()),
+        inbox_(new T[count]),
+        outbox_(new T[count]) {
+    // Payload is sent on the right pair and received on the left pair.
+    auto slot = this->context_->nextSlot();
+    sendDataBuf_ = rightPair_->createSendBuffer(slot, outbox_.get(), bytes_);
+    recvDataBuf_ = leftPair_->createRecvBuffer(slot, inbox_.get(), bytes_);
+
+    // Notifications go the other way; dummy_ is only a placeholder payload.
+    auto notificationSlot = this->context_->nextSlot();
+    sendNotificationBuf_ =
+        leftPair_->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
+    recvNotificationBuf_ =
+        rightPair_->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
+  }
+
+  // inbox_/outbox_ are declared before the transport buffers, so the buffers
+  // are destroyed first and the staging memory is released after them.
+  virtual ~AllgatherRing() = default;
+
+  // Gathers the inputs of all ranks into outPtrs, one input buffer at a time.
+  void run() {
+    const int rank = this->contextRank_;
+    const int numRounds = this->contextSize_ - 1;
+    const size_t numInputs = inPtrs_.size();
+
+    // Local inputs are copied straight into this rank's output region.
+    for (size_t i = 0; i < numInputs; i++) {
+      memcpy(outPtrs_[rank] + i * count_, inPtrs_[i], bytes_);
+    }
+
+    // Circulate the input buffers in order.
+    for (size_t i = 0; i < numInputs; i++) {
+      memcpy(outbox_.get(), inPtrs_[i], bytes_);
+      for (int round = 0; round < numRounds; round++) {
+        // Send the outbox to the right and wait for data from the left.
+        sendDataBuf_->send();
+        recvDataBuf_->waitRecv();
+
+        // The data received in this round originated round + 1 ranks to the
+        // left of this one (modulo the context size).
+        const int inRank = (numRounds - round + rank) % this->contextSize_;
+        memcpy(outPtrs_[inRank] + i * count_, inbox_.get(), bytes_);
+
+        // Stage the received data for forwarding, except in the last round.
+        // NOTE(review): the outbox is rewritten before the right neighbor's
+        // notification is awaited; confirm send() is done with it by now.
+        if (round < (numRounds - 1)) {
+          memcpy(outbox_.get(), inbox_.get(), bytes_);
+        }
+
+        // Notify the left neighbor that the inbox has been consumed, then
+        // wait for the matching notification from the right neighbor.
+        sendNotificationBuf_->send();
+        recvNotificationBuf_->waitRecv();
+      }
+    }
+  }
+
+ private:
+  const std::vector<T*> inPtrs_;
+  std::vector<T*> outPtrs_;
+  const int count_;
+  const size_t bytes_;
+
+  std::unique_ptr<transport::Pair>& leftPair_;
+  std::unique_ptr<transport::Pair>& rightPair_;
+
+  // Staging memory registered with the transport buffers declared below.
+  std::unique_ptr<T[]> inbox_;
+  std::unique_ptr<T[]> outbox_;
+
+  std::unique_ptr<transport::Buffer> sendDataBuf_;
+  std::unique_ptr<transport::Buffer> recvDataBuf_;
+
+  int dummy_ = 0;
+  std::unique_ptr<transport::Buffer> sendNotificationBuf_;
+  std::unique_ptr<transport::Buffer> recvNotificationBuf_;
+};
+
+}  // namespace gloo
diff --git a/gloo/test/allgather_test.cc b/gloo/test/allgather_test.cc
new file mode 100644
index 0000000..8c7af4f
--- /dev/null
+++ b/gloo/test/allgather_test.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#include <functional>
+#include <thread>
+#include <vector>
+
+#include "gloo/allgather_ring.h"
+#include "gloo/test/base_test.h"
+
+namespace gloo {
+namespace test {
+namespace {
+
+// Test parameterization: (context size, element count per input buffer).
+using Param = std::tuple<int, int>;
+
+// Fixture for the allgather tests, parameterized over Param.
+class AllgatherTest : public BaseTest,
+                      public ::testing::WithParamInterface<Param> {};
+
+TEST_P(AllgatherTest, TwoPointer) {
+  const auto contextSize = std::get<0>(GetParam());
+  const auto dataSize = std::get<1>(GetParam());
+
+  // Every rank contributes two inputs of dataSize floats; each of the
+  // contextSize output buffers holds 2 * dataSize floats.
+  spawn(contextSize, [&](std::shared_ptr<Context> context) {
+    Fixture inFixture(context, 2, dataSize);
+    inFixture.assignValues();
+    Fixture outFixture(context, contextSize, 2 * dataSize);
+
+    AllgatherRing<float> algorithm(
+        context,
+        inFixture.getFloatPointers(),
+        outFixture.getFloatPointers(),
+        dataSize);
+    algorithm.run();
+
+    // Presumably mirrors the pattern written by assignValues() - verify there.
+    const auto stride = contextSize * 2;
+    for (int rank = 0; rank < contextSize; ++rank) {
+      const auto base = rank * 2;
+      for (int elem = 0; elem < dataSize; ++elem) {
+        const float expected = elem * stride + base;
+        ASSERT_EQ(outFixture.getFloatPointers()[rank][elem], expected)
+            << "Mismatch at index [" << rank << ", " << elem << "]";
+        ASSERT_EQ(outFixture.getFloatPointers()[rank][elem + dataSize], expected + 1)
+            << "Mismatch at index [" << rank << ", " << elem + dataSize << "]";
+      }
+    }
+  });
+}
+
+// Element counts the test is instantiated with (used as dataSize).
+std::vector<int> genMemorySizes() {
+  return std::vector<int>{
+      static_cast<int>(sizeof(float)),
+      100,
+      1000,
+      10000};
+}
+
+// Context sizes from Range(2, 16) crossed with every genMemorySizes() entry.
+INSTANTIATE_TEST_CASE_P(
+    AllgatherRing,
+    AllgatherTest,
+    ::testing::Combine(
+        ::testing::Range(2, 16), ::testing::ValuesIn(genMemorySizes())));
+
+} // namespace
+} // namespace test
+} // namespace gloo