blob: b5d8cb40916194e864befa35e9ddbe48b37bb402 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
// Test fixture shared by every test in this file. It provides the two memory
// space identifiers the tests compare against and a helper that runs
// MemorySpaceAssignment with one fixed configuration.
class MemorySpaceAssignmentTest : public HloTestBase {
 protected:
  // We use the following two memory space values to describe the default (slow
  // and large) and alternate (fast and small) memory spaces.
  const int64 kDefaultMemorySpace = 0;
  const int64 kAlternateMemorySpace = 1;

  // Runs MemorySpaceAssignment::Run on `module` (which must already carry a
  // schedule set by the caller) and returns the preset assignments it
  // produced. The result is unwrapped with ValueOrDie(), so an error status is
  // not surfaced to the caller as a Status.
  std::unique_ptr<PresetAssignments> AssignMemorySpace(HloModule* module) {
    // Size of a buffer in bytes, computed with an 8-byte pointer size.
    auto size_fn = [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
    };

    // Every value is allowed in the alternate memory except those whose
    // instruction is a parameter of the entry computation.
    auto is_allowed_in_alternate_mem = [](const HloValue& value) {
      // Check if the value belongs to the entry computation.
      HloInstruction* instruction = value.instruction();
      HloComputation* computation = instruction->parent();
      bool in_entry_computation =
          (computation == computation->parent()->entry_computation());
      if (in_entry_computation &&
          instruction->opcode() == HloOpcode::kParameter) {
        return false;
      }
      return true;
    };

    // ValueOrDie() is invoked on the temporary returned by Run(), so the
    // expression is already an rvalue; wrapping it in std::move is redundant.
    return MemorySpaceAssignment::Run(
               module, kAlternateMemorySpace,
               /*max_size_in_bytes=*/128,
               /*min_prefetch_interval=*/2,
               /*max_prefetch_interval=*/10,
               /*alternate_memory_space_alignment_in_bytes=*/8, size_fn,
               is_allowed_in_alternate_mem)
        .ValueOrDie();
  }
};
TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
  // Degenerate module: the entry computation holds nothing but one parameter.
  // Inputs/outputs are currently excluded from memory space assignment, so the
  // parameter is expected to keep its original shape and layout.
  const Shape param_shape = ShapeUtil::MakeShape(F32, {2, 3});

  HloComputation::Builder builder(TestName());
  auto* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());

  // The pass works on a scheduled module; the schedule is just the parameter.
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_THAT(param, op::ShapeWithLayout(param_shape));
}
TEST_F(MemorySpaceAssignmentTest, Simple) {
  // Small module computing (p0 + p1) * (p0 - p1). After assignment the two
  // intermediate values are expected in the alternate memory while the
  // parameters and the root stay in the default memory.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});

  // Adds an elementwise binary instruction of `shape` to the builder.
  auto emit_binary = [&builder, &shape](HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
  };

  auto* lhs_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  auto* rhs_param =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  auto* sum = emit_binary(HloOpcode::kAdd, lhs_param, rhs_param);
  auto* difference = emit_binary(HloOpcode::kSubtract, lhs_param, rhs_param);
  auto* product = emit_binary(HloOpcode::kMultiply, sum, difference);

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry,
                        {lhs_param, rhs_param, sum, difference, product});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Same dimensions and minor-to-major order as `shape`, but tagged with the
  // alternate memory space.
  const Shape alternate_shape = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);

  // Inputs and outputs are currently placed in the default memory.
  EXPECT_THAT(lhs_param, op::ShapeWithLayout(shape));
  EXPECT_THAT(rhs_param, op::ShapeWithLayout(shape));
  EXPECT_THAT(product, op::ShapeWithLayout(shape));

  // Everything else should be in the alternate memory.
  EXPECT_THAT(sum, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(difference, op::ShapeWithLayout(alternate_shape));

  // Sanity-check the preset assignments: two chunks and one size entry.
  EXPECT_THAT(preset_assignments->chunks().size(), 2);
  EXPECT_THAT(preset_assignments->sizes().size(), 1);
}
TEST_F(MemorySpaceAssignmentTest, NegateChain) {
  // Seven chained negates of p0 separate the definition of p1 from its only
  // use in the final add. The chain is long enough for an asynchronous copy of
  // p1 to be inserted between p1 and the add.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});

  // Appends a negate of `operand` to the computation being built.
  auto negate_of = [&builder, &shape](HloInstruction* operand) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, operand));
  };

  auto* first_param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  auto* second_param =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  auto* neg0 = negate_of(first_param);
  auto* neg1 = negate_of(neg0);
  auto* neg2 = negate_of(neg1);
  auto* neg3 = negate_of(neg2);
  auto* neg4 = negate_of(neg3);
  auto* neg5 = negate_of(neg4);
  auto* neg6 = negate_of(neg5);
  auto* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, neg6, second_param));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {first_param, second_param, neg0, neg1, neg2,
                                neg3, neg4, neg5, neg6, sum});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // The second operand of the add must now be an async copy of parameter 1
  // into the alternate memory space.
  EXPECT_THAT(sum, op::Add(op::Negate(),
                           op::AsyncCopy(kAlternateMemorySpace,
                                         kDefaultMemorySpace,
                                         op::Parameter(1))));

  // Parameters are in the default memory space.
  EXPECT_THAT(first_param, op::ShapeWithLayout(shape));
  EXPECT_THAT(second_param, op::ShapeWithLayout(shape));

  // Negate instructions are in the alternate memory space (1).
  const Shape alternate_shape = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(neg0, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg1, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg2, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg3, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg4, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg5, op::ShapeWithLayout(alternate_shape));
  EXPECT_THAT(neg6, op::ShapeWithLayout(alternate_shape));

  // Check where CopyStart/CopyDone landed in the updated schedule.
  const auto& scheduled = module->schedule().sequence(entry).instructions();
  EXPECT_THAT(scheduled[0], op::Parameter(0));
  EXPECT_THAT(scheduled[1], op::Parameter(1));
  EXPECT_THAT(scheduled[2], op::CopyStart());
  EXPECT_THAT(scheduled[10], op::CopyDone());
}
TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
  // tanh(p0) is consumed right after it is produced and once more by the root
  // instruction. The instructions in between exist to raise contention for the
  // alternate memory, so the tanh value is expected to be copied out to the
  // default memory and then copied back in for its final use.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});

  // Adds an elementwise binary instruction of `shape` to the builder.
  auto emit = [&builder, &shape](HloOpcode opcode, HloInstruction* lhs,
                                 HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
  };

  auto* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  auto* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  auto* tanh_value = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));

  // First use of the tanh value.
  auto* a = emit(HloOpcode::kAdd, p0, tanh_value);

  // Contention-raising graph.
  auto* b = emit(HloOpcode::kSubtract, p0, p1);
  auto* c = emit(HloOpcode::kMultiply, p0, p1);
  auto* d = emit(HloOpcode::kSubtract, p0, p1);
  auto* e = emit(HloOpcode::kMultiply, a, b);
  auto* f = emit(HloOpcode::kMultiply, a, c);
  auto* g = emit(HloOpcode::kMultiply, a, d);
  auto* h = emit(HloOpcode::kMultiply, b, c);
  auto* i = emit(HloOpcode::kMultiply, b, d);
  auto* j = emit(HloOpcode::kMultiply, c, d);
  auto* k = emit(HloOpcode::kAdd, e, f);
  auto* l = emit(HloOpcode::kAdd, g, h);
  auto* m = emit(HloOpcode::kAdd, i, j);
  auto* n = emit(HloOpcode::kAdd, k, l);
  auto* o = emit(HloOpcode::kAdd, n, m);

  // Second use of the tanh value, at the root; this one should be prefetched.
  auto* root = emit(HloOpcode::kAdd, o, tanh_value);

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {p0, p1, tanh_value, a, b, c, d, e, f, g, h, i,
                                j, k, l, m, n, o, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // The root's second operand must be a copy into the alternate memory of a
  // copy into the default memory of the tanh.
  EXPECT_THAT(
      root,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
TEST_F(MemorySpaceAssignmentTest, While) {
  // Builds a while loop whose state is a (data, iteration counter) tuple and
  // checks that the tuple and the buffers flowing through the while
  // instruction stay in the default memory, while a value that is purely local
  // to the loop body is placed in the alternate memory.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  // Condition computation: keep looping while the counter is below 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Body computation: increments the counter by one and produces
  // (data + constant) + (data * data) as the next data value.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: packs the two parameters into the loop state and runs
  // the while. The parameter names match the variables that hold them.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Every computation in the module needs its own sequence in the schedule.
  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Ensure the tuple value and buffers used in the while instruction are
  // exempted from using the alternate memory. However, body_data_mul is
  // independent and can safely be placed in the alternate memory.
  EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(data, op::ShapeWithLayout(shape));
  EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
  EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
  EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
  EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}
TEST_F(MemorySpaceAssignmentTest, Tuple) {
  // The single entry parameter is a tuple (array, array, (array)). Element 0
  // feeds a chain of seven negates; element 1 and the nested element 2.0 are
  // only read after that chain. Both late reads are expected to go through an
  // async copy into the alternate memory.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});

  // Appends get-tuple-element(`source`, `index`) with result shape
  // `element_shape`.
  auto element_of = [&builder](const Shape& element_shape,
                               HloInstruction* source, int64 index) {
    return builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(element_shape, source, index));
  };
  // Appends a negate of `operand`.
  auto negate_of = [&builder, &shape](HloInstruction* operand) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, operand));
  };

  auto* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  auto* elem0 = element_of(shape, param, 0);
  auto* neg0 = negate_of(elem0);
  auto* neg1 = negate_of(neg0);
  auto* neg2 = negate_of(neg1);
  auto* neg3 = negate_of(neg2);
  auto* neg4 = negate_of(neg3);
  auto* neg5 = negate_of(neg4);
  auto* neg6 = negate_of(neg5);
  auto* elem1 = element_of(shape, param, 1);
  auto* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, neg6, elem1));
  auto* elem2 = element_of(inner_tuple_shape, param, 2);
  auto* elem2_0 = element_of(shape, elem2, 0);
  auto* product = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, sum, elem2_0));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry,
                        {param, elem0, neg0, neg1, neg2, neg3, neg4, neg5, neg6,
                         elem1, sum, elem2, elem2_0, product});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Both tuple elements consumed after the negate chain must reach their user
  // through an async copy into the alternate memory space.
  EXPECT_THAT(
      product,
      op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::GetTupleElement())),
                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                 op::GetTupleElement(op::GetTupleElement()))));
}
TEST_F(MemorySpaceAssignmentTest, Bitcast) {
  // Bitcasts can cause the position in the alternate memory to appear multiple
  // times in the preset assignments. This test ensures the preset assignments
  // refer to unique positions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, negate));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate, bitcast, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Ensure the positions are unique. Note that we're using a std::set instead
  // of absl::flat_hash_set because we can make use of HloPosition's comparator
  // logic instead of providing a hasher.
  std::set<HloPosition> positions_in_preset_assignments;
  for (const auto& position_and_chunk : preset_assignments->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    // insert() reports whether the element was newly added, so one lookup both
    // detects a duplicate and records the position.
    EXPECT_TRUE(positions_in_preset_assignments.insert(position).second)
        << "preset assignments contain a duplicate position";
  }
}
} // namespace
} // namespace xla