/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;

class MemorySpaceAssignmentTest : public HloTestBase {
protected:
// We use the following two memory space values to describe the default (slow
// and large) and alternate (fast and small) memory spaces.
const int64 kDefaultMemorySpace = 0;
const int64 kAlternateMemorySpace = 1;
void AssignMemorySpace(HloModule* module) {
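// Buffer sizes are computed assuming an 8-byte pointer for nested shapes.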
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
auto is_allowed_in_alternate_mem = [](const HloValue& value) {
// Exclude entry computation parameters from the alternate memory space.
HloInstruction* instruction = value.instruction();
HloComputation* computation = instruction->parent();
bool in_entry_computation =
(computation == computation->parent()->entry_computation());
if (in_entry_computation &&
instruction->opcode() == HloOpcode::kParameter) {
return false;
}
return true;
};
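// Run the pass with a small (128-byte) alternate memory; the prefetch
// interval bounds control how far ahead of its use an asynchronous copy may
// be scheduled.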
ASSERT_IS_OK(MemorySpaceAssignment::Run(
module, kAlternateMemorySpace, /*max_size_in_bytes=*/128,
/*min_prefetch_interval=*/2,
/*max_prefetch_interval=*/10,
/*alternate_memory_space_alignment_in_bytes=*/8, size_fn,
is_allowed_in_alternate_mem)
.status());
}
};

TEST_F(MemorySpaceAssignmentTest, ParameterOnly) {
// A module consisting of a single parameter. Inputs/outputs are currently
// excluded from memory space assignment.
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
HloInstruction* p0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {p0});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
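// The parameter should keep its original layout in the default memory space.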
EXPECT_THAT(p0, op::ShapeWithLayout(shape));
}

TEST_F(MemorySpaceAssignmentTest, Simple) {
// A simple module with a few simple instructions. Expect this to be
// transformed with CopyStart and CopyDone instructions inserted after inputs
// and before outputs.
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
HloInstruction* p0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
HloInstruction* p1 =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1));
HloInstruction* sub = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
HloInstruction* mul = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {p0, p1, add, sub, mul});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
// Inputs and outputs are currently placed in the default memory. Everything
// else should be in the alternate memory.
Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
F32, {2, 3},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kAlternateMemorySpace);
EXPECT_THAT(p0, op::ShapeWithLayout(shape));
EXPECT_THAT(p1, op::ShapeWithLayout(shape));
EXPECT_THAT(mul, op::ShapeWithLayout(shape));
EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));
}

TEST_F(MemorySpaceAssignmentTest, NegateChain) {
// The negate chain is long enough for asynchronous copy to be inserted
// between p1 and add.
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
HloInstruction* p0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
HloInstruction* p1 =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
HloInstruction* negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
HloInstruction* negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
HloInstruction* negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
HloInstruction* negate3 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
HloInstruction* negate4 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
HloInstruction* negate5 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
HloInstruction* negate6 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
negate3, negate4, negate5, negate6, add});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
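// p1 should be prefetched into the alternate memory through an asynchronous
// copy that overlaps the negate chain.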
EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
kDefaultMemorySpace,
op::Parameter(1))));
// Parameters are in the default memory space.
EXPECT_THAT(p0, op::ShapeWithLayout(shape));
EXPECT_THAT(p1, op::ShapeWithLayout(shape));
// Negate instructions are in the alternate memory space (1).
Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
F32, {2, 3},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kAlternateMemorySpace);
EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
// Ensure CopyStart/CopyDone are scheduled at the expected positions.
const HloInstructionSequence& sequence =
module->schedule().sequence(computation);
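// Expected order: p0, p1, CopyStart, the seven negates, CopyDone, add.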
EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}

TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
HloInstruction* p0 =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
HloInstruction* p1 =
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
HloInstruction* tanh = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
// tanh should be placed in the alternate memory since there isn't much
// contention at the beginning. However, tanh has another consumer at the end,
// so it is expected to be evicted to the default memory and prefetched back
// in. The graph below increases the contention to force the eviction/prefetch
// behavior.
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
HloInstruction* d = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
HloInstruction* e = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
HloInstruction* f = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
HloInstruction* g = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
HloInstruction* h = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
HloInstruction* i = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
HloInstruction* j = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
HloInstruction* k = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
HloInstruction* l = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
HloInstruction* m = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
HloInstruction* n = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
HloInstruction* o = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
// tanh is used at the root instruction; its value should be prefetched back
// into the alternate memory.
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
j, k, l, m, n, o, add});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
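// The nested AsyncCopy matchers below verify that tanh was evicted
// (alternate->default copy) and then prefetched back (default->alternate
// copy) for its use in the final add.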
EXPECT_THAT(
add,
op::Add(op::Add(),
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
op::AsyncCopy(kDefaultMemorySpace,
kAlternateMemorySpace, op::Tanh()))));
}

TEST_F(MemorySpaceAssignmentTest, While) {
auto module = CreateNewVerifiedModule();
Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
auto cond_builder = HloComputation::Builder("WhileCond");
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
HloInstruction* cond_iter = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
HloInstruction* cond_limit = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_limit, ComparisonDirection::kLt));
HloComputation* cond_computation =
module->AddEmbeddedComputation(cond_builder.Build());
auto body_builder = HloComputation::Builder("WhileBody");
// Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
HloInstruction* body_iter = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
HloInstruction* body_data = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
HloInstruction* body_iter_increment = body_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
HloInstruction* body_iter_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
HloInstruction* body_data_increment =
body_builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
HloInstruction* body_data_mul =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kMultiply, body_data, body_data));
HloInstruction* body_data_add =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, body_data, body_data_increment));
HloInstruction* body_data_next =
body_builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, body_data_add, body_data_mul));
HloInstruction* body_out = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_data_next, body_iter_next}));
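// The body returns (data_next, iter_next), where
// data_next = (data + increment) + data * data and iter_next = iter + 1.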
HloComputation* body_computation =
module->AddEmbeddedComputation(body_builder.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* data = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param_data"));
HloInstruction* iter = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, cond_computation, body_computation, tuple));
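// The while loop carries the (data, iter) tuple until the condition
// iter < 50 fails.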
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_limit, cond_lt});
schedule.set_sequence(body_computation,
{body_param, body_iter, body_data, body_iter_increment,
body_iter_next, body_data_increment, body_data_mul,
body_data_add, body_data_next, body_out});
schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
// Ensure the tuple value and buffers used in the while instruction are
// exempted from using the alternate memory. However, body_data_mul is
// independent and can safely be placed in the alternate memory.
EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
EXPECT_THAT(data, op::ShapeWithLayout(shape));
EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
F32, {2, 3},
/*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
kAlternateMemorySpace);
EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}

} // namespace
} // namespace xla