[XLA] [NFC] Do not propagate InstructionCanChangeLayout as a functor, use virtual methods instead.

PiperOrigin-RevId: 388766658
Change-Id: I92ec1eb245adf9977d125f85372936820667cc5f
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 1ba869e..72dc89e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -416,8 +416,7 @@
   // flattened.
   pipeline.AddPass<FlattenCallGraph>();
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->mutable_entry_computation_layout(),
-      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);
+      module->mutable_entry_computation_layout(), target_machine_features);
 
   pipeline.AddPass<CpuInstructionFusion>();
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index f4da35d..3c4fe68 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -30,11 +30,8 @@
  public:
   explicit CpuLayoutAssignment(
       ComputationLayout* entry_computation_layout,
-      std::function<bool(const HloInstruction*)>
-          instruction_can_change_layout_func,
       const TargetMachineFeatures* target_machine_features)
-      : LayoutAssignment(entry_computation_layout,
-                         std::move(instruction_can_change_layout_func)),
+      : LayoutAssignment(entry_computation_layout),
         target_machine_features_(*target_machine_features) {}
   ~CpuLayoutAssignment() override {}
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 14023b3..3ce9aaf 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -54,9 +54,8 @@
         [](int64_t shape_size) {
           return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
         });
-    cpu::CpuLayoutAssignment layout_assignment(
-        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-        &target_machine_features);
+    cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
+                                               &target_machine_features);
     EXPECT_IS_OK(layout_assignment.Run(module).status());
   }
 };
@@ -329,9 +328,8 @@
       [](int64_t shape_size) {
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
-  cpu::CpuLayoutAssignment layout_assignment(
-      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-      &target_machine_features);
+  cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
+                                             &target_machine_features);
   TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                       layout_assignment.Run(module));
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5021cd2..b78c8d8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -419,8 +419,7 @@
     pipeline.AddPass<FlattenCallGraph>();
     ChannelLayoutConstraints layout_constraints;
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_entry_computation_layout(),
-        LayoutAssignment::InstructionCanChangeLayout, stream_exec,
+        hlo_module->mutable_entry_computation_layout(), stream_exec,
         &layout_constraints);
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
@@ -434,8 +433,6 @@
     // We try to split variadic ops with many parameters into several such ops
     // to avoid exceeding the parameter space.
     fusion.AddPass<VariadicOpSplitter>();
-    /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-     * fixing the ticket. */
     fusion.AddInvariantCheckerDebug<HloVerifier>(
         /*layout_sensitive=*/true,
         /*allow_mixed_precision=*/false,
@@ -502,8 +499,6 @@
   // (b/27180329). Therefore, in that case, we set the output to be a copy of
   // the parameter.
   HloPassPipeline pipeline("GPU-ir-emit-prepare");
-  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-   * fixing the ticket. */
   pipeline.AddInvariantCheckerDebug<HloVerifier>(
       /*layout_sensitive=*/true,
       /*allow_mixed_precision=*/false,
@@ -529,8 +524,6 @@
     HloModule* hlo_module, se::StreamExecutor* stream_exec,
     se::DeviceMemoryAllocator* device_allocator) {
   HloPassPipeline pipeline("post-layout_assignment");
-  /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after
-   * fixing the ticket. */
   pipeline.AddInvariantCheckerDebug<HloVerifier>(
       /*layout_sensitive=*/true,
       /*allow_mixed_precision=*/false,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 1e6a180..e750453 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -31,12 +31,9 @@
  public:
   explicit GpuLayoutAssignment(
       ComputationLayout* entry_computation_layout,
-      std::function<bool(const HloInstruction*)>
-          instruction_can_change_layout_func,
       se::StreamExecutor* stream_executor,
       ChannelLayoutConstraints* channel_constraints = nullptr)
       : LayoutAssignment(entry_computation_layout,
-                         std::move(instruction_can_change_layout_func),
                          channel_constraints),
         stream_executor_(stream_executor) {}
   ~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 6ae0f1e..bc9ed9a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -77,8 +77,7 @@
             ShapeLayout(result_shape_with_layout);
 
         GpuLayoutAssignment layout_assignment(
-            &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-            backend().default_stream_executor());
+            &computation_layout, backend().default_stream_executor());
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         for (const HloInstruction* operand : add->operands()) {
@@ -166,8 +165,7 @@
       }
 
       GpuLayoutAssignment layout_assignment(
-          &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-          backend().default_stream_executor());
+          &computation_layout, backend().default_stream_executor());
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -237,8 +235,7 @@
       }
 
       GpuLayoutAssignment layout_assignment(
-          &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-          backend().default_stream_executor());
+          &computation_layout, backend().default_stream_executor());
       EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
       // The first operand to batchnorm should have the same layout as the
@@ -319,8 +316,7 @@
         }
 
         GpuLayoutAssignment layout_assignment(
-            &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-            backend().default_stream_executor());
+            &computation_layout, backend().default_stream_executor());
         EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
         // The first and fourth operands to the batchnorm call should have the
@@ -355,9 +351,8 @@
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
       /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-      backend().default_stream_executor());
+  GpuLayoutAssignment layout_assignment(&computation_layout,
+                                        backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   Shape expected_shape =
@@ -393,9 +388,8 @@
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
       /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-      backend().default_stream_executor());
+  GpuLayoutAssignment layout_assignment(&computation_layout,
+                                        backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0});
@@ -420,9 +414,8 @@
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
       /*ignore_layouts=*/false);
-  GpuLayoutAssignment layout_assignment(
-      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
-      backend().default_stream_executor());
+  GpuLayoutAssignment layout_assignment(&computation_layout,
+                                        backend().default_stream_executor());
   EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
 
   Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 9d7f79b..92f2243 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -89,8 +89,7 @@
   pipeline.AddPass<ComparisonExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->mutable_entry_computation_layout(),
-      LayoutAssignment::InstructionCanChangeLayout);
+      hlo_module->mutable_entry_computation_layout());
 
   return pipeline.Run(hlo_module).status();
 }
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 870ddfa..fe2388b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1111,16 +1111,12 @@
 
 LayoutAssignment::LayoutAssignment(
     ComputationLayout* entry_computation_layout,
-    std::function<bool(const HloInstruction*)>
-        instruction_can_change_layout_func,
     ChannelLayoutConstraints* channel_constraints,
     bool reverse_computation_order)
     : entry_computation_layout_(entry_computation_layout),
       saved_entry_computation_layout_(*entry_computation_layout),
       reverse_computation_order_(reverse_computation_order),
-      channel_layout_constraints_(channel_constraints),
-      instruction_can_change_layout_func_(
-          std::move(instruction_can_change_layout_func)) {
+      channel_layout_constraints_(channel_constraints) {
   if (channel_layout_constraints_ != nullptr) {
     // Save a copy of the input ChannelLayoutConstraints so that we can reset it
     // if we have to undo previous operations (ClearPreviousPassSideEffects()).
@@ -1138,7 +1134,7 @@
   CHECK(operand->shape().IsArray());
   if (!ShapeUtil::IsScalar(operand->shape()) &&
       operand->shape().rank() == instruction->shape().rank() &&
-      !instruction_can_change_layout_func_(instruction)) {
+      !InstructionCanChangeLayoutInstance(instruction)) {
     // Propagate the result layout to the operand layout if the instruction
     // requires the same layout out for the result and the operand.
     //
@@ -1206,7 +1202,7 @@
 
   if (!ShapeUtil::IsScalar(operand->shape()) &&
       operand->shape().rank() == user->shape().rank() &&
-      !instruction_can_change_layout_func_(user)) {
+      !InstructionCanChangeLayoutInstance(user)) {
     // Assign users the same layout as the operand.
     return absl::make_unique<Layout>(operand_layout);
   }
@@ -1431,7 +1427,7 @@
           /*mandatory=*/false));
     }
   }
-  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
+  if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray()) {
     return Status::OK();
   }
 
@@ -1449,7 +1445,7 @@
 
   // Propagate layouts between operands of the same instruction. This is a
   // constraint on non-layout-changing instructions.
-  if (!instruction_can_change_layout_func_(user)) {
+  if (!InstructionCanChangeLayoutInstance(user)) {
     // Only propgate the layout of the largest concatenate operand.
     if (user->opcode() == HloOpcode::kConcatenate) {
       for (int64_t operand_no = 0; operand_no < user->operand_count();
@@ -1597,7 +1593,7 @@
     if (IsAtMostRank1(operand->shape())) {
       continue;
     }
-    if (!instruction_can_change_layout_func_(instruction)) {
+    if (!InstructionCanChangeLayoutInstance(instruction)) {
       // Copy the layout to the operand.
       if (buffer.IsArray() && operand->shape().IsArray() &&
           operand->shape().rank() ==
@@ -2466,6 +2462,11 @@
   }
 }
 
+bool LayoutAssignment::InstructionCanChangeLayoutInstance(
+    const HloInstruction* instruction) {
+  return InstructionCanChangeLayout(instruction);
+}
+
 /* static */
 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
   if (shape.IsArray()) {
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index b62eec3..178b2c3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -308,11 +308,6 @@
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
   //
-  // instruction_can_change_layout_func is a function object that determines
-  // whether an instruction can change layouts. An instruction not being able to
-  // change layout means that it requires operands with the same rank as the
-  // output to have the same layout as the output.
-  //
   // channel_constraints is both an input and output. Any sends or recvs that
   // are present in channel_constraints will be laid out as constrained. Any
   // unconstrained sends or recvs will be laid out as locally optimal and their
@@ -322,8 +317,6 @@
   // within any module passed to `Run`.
   explicit LayoutAssignment(
       ComputationLayout* entry_computation_layout,
-      std::function<bool(const HloInstruction*)>
-          instruction_can_change_layout_func = InstructionCanChangeLayout,
       ChannelLayoutConstraints* channel_constraints = nullptr,
       bool reverse_computation_order = false);
   ~LayoutAssignment() override {}
@@ -400,6 +393,11 @@
       const Layout& operand_layout, const HloInstruction* user,
       int64_t operand_no);
 
+  // Convenient wrapper for InstructionCanChangeLayout which can be overridden
+  // in subclasses.
+  virtual bool InstructionCanChangeLayoutInstance(
+      const HloInstruction* instruction);
+
  private:
   // Initializes the layout assignment object for a new Run() call.
   Status Init();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 374af0a..f0874bc 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -54,7 +54,7 @@
   void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                      ChannelLayoutConstraints* channel_constraints = nullptr) {
     LayoutAssignment layout_assignment(
-        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+        entry_computation_layout,
         /*channel_constraints=*/channel_constraints);
     EXPECT_IS_OK(layout_assignment.Run(m).status());
   }
@@ -1329,9 +1329,8 @@
         ->mutable_result_layout()
         ->SetToDefaultLayout();
   }
-  LayoutAssignment layout_assignment(
-      m->mutable_entry_computation_layout(),
-      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
+  LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(),
+                                     channel_constraints);
   return layout_assignment.Run(m).status();
 }
 
@@ -1524,7 +1523,7 @@
   std::cerr << computation_layout.ToString();
   ChannelLayoutConstraints channel_constraints;
   LayoutAssignment layout_assignment(
-      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      &computation_layout,
       /*channel_constraints=*/&channel_constraints,
       /* reverse_computation_order = */ true);
   EXPECT_IS_OK(layout_assignment.Run(m.get()).status());