blob: 372a3664e2d35a6c381cb7fbe069c3e0741569d3 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
// Tests cross-GPU operatons.
//
// This test requires at least four GPUs. For instructions on running this
// within Google, see go/multi-gpu-unit-test.
namespace xla {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
class CollectiveOpsTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> MakeCrsModule(
int64 num_elems, std::vector<std::vector<int64>> replica_groups,
const HloModuleConfig& config, std::string op = "add",
std::string datatype = "f32") {
const char* kTemplate = R"(
HloModule test
apply_op {
x = DATATYPE[] parameter(0)
y = DATATYPE[] parameter(1)
ROOT apply_op = DATATYPE[] OP(x, y)
}
ENTRY test_computation {
p = DATATYPE[NUM_ELEMS] parameter(0)
ROOT crs = DATATYPE[NUM_ELEMS] all-reduce(p), replica_groups=REPLICA_GROUPS, to_apply=apply_op
}
)";
std::vector<string> replica_group_strs;
for (const auto& g : replica_groups) {
replica_group_strs.push_back(
absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
}
return ParseAndReturnVerifiedModule(
absl::StrReplaceAll(
kTemplate,
{{"NUM_ELEMS", absl::StrCat(num_elems)},
{"REPLICA_GROUPS",
absl::StrFormat("{%s}",
absl::StrJoin(replica_group_strs, ", "))},
{"OP", op},
{"DATATYPE", datatype}}),
config)
.ValueOrDie();
}
template <typename LiteralType>
void TestTwoReplicasOneOperand(std::string op,
std::vector<LiteralType> input_value,
std::vector<LiteralType> expected_value) {
std::string dtype = primitive_util::LowercasePrimitiveTypeName(
primitive_util::NativeToPrimitiveType<LiteralType>());
auto config = GetModuleConfigForTest();
config.set_replica_count(2);
auto module = MakeCrsModule(/*num_elems=*/3, /*replica_groups=*/{}, config,
/*op=*/op, /*datatype=*/dtype);
auto literal = LiteralUtil::CreateR1<LiteralType>(input_value);
auto expected = LiteralUtil::CreateR1<LiteralType>(expected_value);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&literal}, /*num_replicas=*/2,
/*use_threads=*/true));
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected, results[0],
ErrorSpec{1e-5, 1e-5}));
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected, results[1],
ErrorSpec{1e-5, 1e-5}));
}
template <typename LiteralType>
void TestAllOps() {
auto cast = [&](int value) { return static_cast<LiteralType>(value); };
std::vector<LiteralType> input_value = {cast(1), cast(2), cast(3)};
TestTwoReplicasOneOperand<LiteralType>(
"add",
/*input_value=*/input_value,
/*expected_value=*/{cast(2), cast(4), cast(6)});
TestTwoReplicasOneOperand<LiteralType>(
"multiply",
/*input_value=*/input_value,
/*expected_value=*/{cast(1), cast(4), cast(9)});
TestTwoReplicasOneOperand<LiteralType>(
"maximum",
/*input_value=*/input_value,
/*expected_value=*/{cast(1), cast(2), cast(3)});
TestTwoReplicasOneOperand<LiteralType>(
"minimum",
/*input_value=*/input_value,
/*expected_value=*/{cast(1), cast(2), cast(3)});
}
};
// Returns the non-empty subsets of {0, 1, ..., n}. For example,
// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
std::vector<std::vector<int64>> power_set;
for (int64 i = 1; i < (1 << n); ++i) {
power_set.emplace_back();
for (int64 j = 0; j < n; ++j) {
if (i & (1 << j)) {
power_set.back().push_back(j);
}
}
}
return power_set;
}
// Makes a DeviceAssignment assigning replica-id i to devices[i].
DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
DeviceAssignment assn(/*replica_count=*/devices.size(),
/*computation_count=*/1);
for (int64 i = 0; i < devices.size(); ++i) {
assn(i, 0) = devices[i];
}
return assn;
}
// Shorter alias for this function.
absl::flat_hash_set<int> OpenNcclChannels() {
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
}
template <typename T>
static Eigen::half ToHalf(T value) {
return static_cast<Eigen::half>(value);
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
TestAllOps<int8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
TestAllOps<uint8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
TestAllOps<uint32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
TestAllOps<int32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
TestAllOps<int64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
TestAllOps<uint64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
TestAllOps<float>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
TestAllOps<double>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
TestAllOps<Eigen::half>();
}
// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
// devices in sequence.
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
const int64 kNumDevices = 4;
const int64 kNumElems = 1024;
for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
absl::StrJoin(devices, ", ")));
DeviceAssignment device_assn = MakeDeviceAssn(devices);
auto config = GetModuleConfigForTest();
config.set_replica_count(devices.size());
config.set_static_device_assignment(device_assn);
auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config);
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input_literal},
/*num_replicas=*/devices.size(), &device_assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
}
}
// Check that the NCCL data structures in our all-reduce implementation are
// cached as we expect.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
const int64 kNumElems = 1024;
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
// Initially no NCCL channels should be open.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
struct ExecutableInfo {
std::unique_ptr<Executable> executable;
DeviceAssignment device_assn;
HloRunner::ReplicatedExecuteOptions opts;
};
std::vector<ExecutableInfo> executables;
for (const auto& devices :
std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
executables.emplace_back();
auto& e = executables.back();
e.device_assn = MakeDeviceAssn(devices);
auto config = GetModuleConfigForTest();
config.set_replica_count(devices.size());
config.set_static_device_assignment(e.device_assn);
auto module = MakeCrsModule(kNumElems, /*replica_groups=*/{}, config);
e.executable =
test_runner_
.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
.ValueOrDie();
e.opts.num_replicas = devices.size();
e.opts.use_threads = true;
e.opts.arguments.push_back(&input_literal);
}
auto run_executable = [&](int64 i) {
auto& e = executables[i];
TF_ASSERT_OK(
test_runner_
.ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
.status());
};
// Compiling executables above shouldn't cause us to open any channels.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Run the executables and check that channels are opened as we expect.
run_executable(0);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
run_executable(2);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
// Tear down the executables and check that channels are closed as we expect.
// Note that after we tear down an executable *all* the nccl channels may go
// away, so we rerun all of the executables that haven't been torn down.
executables[2].executable.reset();
run_executable(0);
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
executables[0].executable.reset();
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
executables[1].executable.reset();
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
}
// Runs the same executable many times concurrently. The all-reduces should not
// conflict with one another.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
const int64 kNumElems = 1024;
const int64 kNumThreads = 200;
const int64 kRunsPerThread = 10;
auto config = GetModuleConfigForTest();
config.set_replica_count(2);
auto executable =
test_runner_
.CreateExecutable(
MakeCrsModule(kNumElems, /*replica_groups=*/{}, config),
/*run_hlo_passes=*/true)
.ValueOrDie();
std::vector<int64> devices = {0, 1};
auto device_assn = MakeDeviceAssn(devices);
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
HloRunner::ReplicatedExecuteOptions opts;
opts.num_replicas = devices.size();
opts.use_threads = true;
opts.arguments.push_back(&input_literal);
tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
kNumThreads);
for (int64 i = 0; i < kNumThreads * kRunsPerThread; ++i) {
pool.Schedule([&] {
TF_ASSERT_OK(
test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
.status());
done.DecrementCount();
});
}
done.Wait();
}
// Runs an all-reduce with three partitions:
// {0}, {1,2}, {3}
// meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and
// 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
// Test a prime number so it's not all powers of 2.
const int64 kNumElems = 137;
auto config = GetModuleConfigForTest();
config.set_replica_count(4);
auto module = MakeCrsModule(/*num_elems=*/kNumElems,
/*replica_groups=*/{{0}, {1, 2}, {3}}, config);
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
/*use_threads=*/true));
ASSERT_EQ(results.size(), 4);
std::vector<float> input_vec_doubled;
for (float n : input_vec) {
input_vec_doubled.push_back(n * 2);
}
auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);
EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
}
XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
ROOT id = u32[] replica-id()
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (uint32 i = 0; i < kNumReplicas; ++i) {
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
}
}
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
replica = u32[] replica-id()
ten = u32[] constant(10)
sum = u32[] add(replica, ten)
p = u32[2] broadcast(sum), dimensions={}
ROOT permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({11, 11}),
results[0]));
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({10, 10}),
results[1]));
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({12, 12}),
results[2]));
// Nothing writes to replica 3, so it is memzero'ed.
EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({0, 0}),
results[3]));
}
} // namespace
} // namespace xla