[XLA] [NFC] Do not propagate InstructionCanChangeLayout as a functor; use a virtual method instead.
PiperOrigin-RevId: 388766658
Change-Id: I92ec1eb245adf9977d125f85372936820667cc5f
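
For reviewers, a minimal before/after sketch of the pattern (MyLayoutAssignment
is a hypothetical subclass used only for illustration; the real constructor and
method signatures are in the diff below):

    // Before: call sites injected the predicate as a functor.
    LayoutAssignment pass(entry_computation_layout,
                          LayoutAssignment::InstructionCanChangeLayout);

    // After: the pass queries the predicate through a virtual wrapper whose
    // default implementation forwards to the static predicate. A backend
    // that needs different behavior overrides the wrapper instead of
    // threading a std::function through every constructor.
    class MyLayoutAssignment : public LayoutAssignment {
     public:
      using LayoutAssignment::LayoutAssignment;

      bool InstructionCanChangeLayoutInstance(
          const HloInstruction* instruction) override {
        // Hypothetical: a backend could add its own logic here before
        // falling back to the base-class default.
        return LayoutAssignment::InstructionCanChangeLayoutInstance(
            instruction);
      }
    };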
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 1ba869e..72dc89e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -416,8 +416,7 @@
// flattened.
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(),
- LayoutAssignment::InstructionCanChangeLayout, target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
pipeline.AddPass<CpuInstructionFusion>();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index f4da35d..3c4fe68 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -30,11 +30,8 @@
public:
explicit CpuLayoutAssignment(
ComputationLayout* entry_computation_layout,
- std::function<bool(const HloInstruction*)>
- instruction_can_change_layout_func,
const TargetMachineFeatures* target_machine_features)
- : LayoutAssignment(entry_computation_layout,
- std::move(instruction_can_change_layout_func)),
+ : LayoutAssignment(entry_computation_layout),
target_machine_features_(*target_machine_features) {}
~CpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 14023b3..3ce9aaf 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -54,9 +54,8 @@
[](int64_t shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(
- entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
+ &target_machine_features);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -329,9 +328,8 @@
[](int64_t shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
+ &target_machine_features);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5021cd2..b78c8d8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -419,8 +419,7 @@
pipeline.AddPass<FlattenCallGraph>();
ChannelLayoutConstraints layout_constraints;
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout(),
- LayoutAssignment::InstructionCanChangeLayout, stream_exec,
+ hlo_module->mutable_entry_computation_layout(), stream_exec,
&layout_constraints);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
@@ -434,8 +433,6 @@
// We try to split variadic ops with many parameters into several such ops
// to avoid exceeding the parameter space.
fusion.AddPass<VariadicOpSplitter>();
- /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
- * fixing the ticket. */
fusion.AddInvariantCheckerDebug<HloVerifier>(
/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false,
@@ -502,8 +499,6 @@
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
- * fixing the ticket. */
pipeline.AddInvariantCheckerDebug<HloVerifier>(
/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false,
@@ -529,8 +524,6 @@
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
HloPassPipeline pipeline("post-layout_assignment");
- /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
- * fixing the ticket. */
pipeline.AddInvariantCheckerDebug<HloVerifier>(
/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 1e6a180..e750453 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -31,12 +31,9 @@
public:
explicit GpuLayoutAssignment(
ComputationLayout* entry_computation_layout,
- std::function<bool(const HloInstruction*)>
- instruction_can_change_layout_func,
se::StreamExecutor* stream_executor,
ChannelLayoutConstraints* channel_constraints = nullptr)
: LayoutAssignment(entry_computation_layout,
- std::move(instruction_can_change_layout_func),
channel_constraints),
stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 6ae0f1e..bc9ed9a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -77,8 +77,7 @@
ShapeLayout(result_shape_with_layout);
GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -166,8 +165,7 @@
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -237,8 +235,7 @@
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -319,8 +316,7 @@
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
@@ -355,9 +351,8 @@
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape(),
/*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(&computation_layout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape =
@@ -393,9 +388,8 @@
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape(),
/*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(&computation_layout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0});
@@ -420,9 +414,8 @@
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape(),
/*ignore_layouts=*/false);
- GpuLayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(&computation_layout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 9d7f79b..92f2243 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -89,8 +89,7 @@
pipeline.AddPass<ComparisonExpander>();
pipeline.AddPass<TriangularSolveExpander>();
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_entry_computation_layout(),
- LayoutAssignment::InstructionCanChangeLayout);
+ hlo_module->mutable_entry_computation_layout());
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 870ddfa..fe2388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1111,16 +1111,12 @@
LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
- std::function<bool(const HloInstruction*)>
- instruction_can_change_layout_func,
ChannelLayoutConstraints* channel_constraints,
bool reverse_computation_order)
: entry_computation_layout_(entry_computation_layout),
saved_entry_computation_layout_(*entry_computation_layout),
reverse_computation_order_(reverse_computation_order),
- channel_layout_constraints_(channel_constraints),
- instruction_can_change_layout_func_(
- std::move(instruction_can_change_layout_func)) {
+ channel_layout_constraints_(channel_constraints) {
if (channel_layout_constraints_ != nullptr) {
// Save a copy of the input ChannelLayoutConstraints so that we can reset it
// if we have to undo previous operations (ClearPreviousPassSideEffects()).
@@ -1138,7 +1134,7 @@
CHECK(operand->shape().IsArray());
if (!ShapeUtil::IsScalar(operand->shape()) &&
operand->shape().rank() == instruction->shape().rank() &&
- !instruction_can_change_layout_func_(instruction)) {
+ !InstructionCanChangeLayoutInstance(instruction)) {
// Propagate the result layout to the operand layout if the instruction
// requires the same layout for the result and the operand.
//
@@ -1206,7 +1202,7 @@
if (!ShapeUtil::IsScalar(operand->shape()) &&
operand->shape().rank() == user->shape().rank() &&
- !instruction_can_change_layout_func_(user)) {
+ !InstructionCanChangeLayoutInstance(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
}
@@ -1431,7 +1427,7 @@
/*mandatory=*/false));
}
}
- if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
+ if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray()) {
return Status::OK();
}
@@ -1449,7 +1445,7 @@
// Propagate layouts between operands of the same instruction. This is a
// constraint on non-layout-changing instructions.
- if (!instruction_can_change_layout_func_(user)) {
+ if (!InstructionCanChangeLayoutInstance(user)) {
// Only propagate the layout of the largest concatenate operand.
if (user->opcode() == HloOpcode::kConcatenate) {
for (int64_t operand_no = 0; operand_no < user->operand_count();
@@ -1597,7 +1593,7 @@
if (IsAtMostRank1(operand->shape())) {
continue;
}
- if (!instruction_can_change_layout_func_(instruction)) {
+ if (!InstructionCanChangeLayoutInstance(instruction)) {
// Copy the layout to the operand.
if (buffer.IsArray() && operand->shape().IsArray() &&
operand->shape().rank() ==
@@ -2466,6 +2462,11 @@
}
}
+bool LayoutAssignment::InstructionCanChangeLayoutInstance(
+ const HloInstruction* instruction) {
+ return InstructionCanChangeLayout(instruction);
+}
+
/* static */
bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
if (shape.IsArray()) {
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index b62eec3..178b2c3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -308,11 +308,6 @@
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
//
- // instruction_can_change_layout_func is a function object that determines
- // whether an instruction can change layouts. An instruction not being able to
- // change layout means that it requires operands with the same rank as the
- // output to have the same layout as the output.
- //
// channel_constraints is both an input and output. Any sends or recvs that
// are present in channel_constraints will be laid out as constrained. Any
// unconstrained sends or recvs will be laid out as locally optimal and their
@@ -322,8 +317,6 @@
// within any module passed to `Run`.
explicit LayoutAssignment(
ComputationLayout* entry_computation_layout,
- std::function<bool(const HloInstruction*)>
- instruction_can_change_layout_func = InstructionCanChangeLayout,
ChannelLayoutConstraints* channel_constraints = nullptr,
bool reverse_computation_order = false);
~LayoutAssignment() override {}
@@ -400,6 +393,11 @@
const Layout& operand_layout, const HloInstruction* user,
int64_t operand_no);
+ // Convenience wrapper around InstructionCanChangeLayout that can be
+ // overridden in subclasses.
+ virtual bool InstructionCanChangeLayoutInstance(
+ const HloInstruction* instruction);
+
private:
// Initializes the layout assignment object for a new Run() call.
Status Init();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 374af0a..f0874bc 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -54,7 +54,7 @@
void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr) {
LayoutAssignment layout_assignment(
- entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ entry_computation_layout,
/*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(m).status());
}
@@ -1329,9 +1329,8 @@
->mutable_result_layout()
->SetToDefaultLayout();
}
- LayoutAssignment layout_assignment(
- m->mutable_entry_computation_layout(),
- LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
+ LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(),
+ channel_constraints);
return layout_assignment.Run(m).status();
}
@@ -1524,7 +1523,7 @@
std::cerr << computation_layout.ToString();
ChannelLayoutConstraints channel_constraints;
LayoutAssignment layout_assignment(
- &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ &computation_layout,
/*channel_constraints=*/&channel_constraints,
/*reverse_computation_order=*/true);
EXPECT_IS_OK(layout_assignment.Run(m.get()).status());