[XLA] Allow CustomCall to specify aliasing buffers

PiperOrigin-RevId: 333445659
Change-Id: I444cc5f33793de365ad66cfd828f6e4eeb489065
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index ac5e01a..3da8935 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -135,12 +135,16 @@
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-    bool has_side_effect) {
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   if (operand_shapes_with_layout.has_value())
     return Unimplemented(
         "CustomCall doesn't support operands shapes with layout");
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                          shape, builder_));
+  TF_RET_CHECK(output_operand_aliasing.empty())
+      << "MLIR CustomCallOp does not support output_operand_aliasing yet";
   auto op = builder_.create<mlir::mhlo::CustomCallOp>(
       loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
       /*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 00b7aa4..59b4bc7 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -135,7 +135,9 @@
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const Shape& shape, const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-      bool has_side_effect) override;
+      bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing) override;
 
   StatusOr<XlaOp> ReduceInternal(
       const Shape& shape, absl::Span<const XlaOp> all_operands,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 3e2a4eb..c7bbf9f 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1707,7 +1707,9 @@
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-    bool has_side_effect) {
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     if (absl::StartsWith(call_target_name, "$")) {
       return InvalidArgument(
@@ -1739,7 +1741,8 @@
       }
     }
     return CustomCallInternal(call_target_name, operands, shape, opaque,
-                              operand_shapes_with_layout, has_side_effect);
+                              operand_shapes_with_layout, has_side_effect,
+                              output_operand_aliasing);
   });
 }
 
@@ -1747,7 +1750,9 @@
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-    bool has_side_effect) {
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   HloInstructionProto instr;
   *instr.mutable_shape() = shape.ToProto();
   instr.set_custom_call_target(call_target_name);
@@ -1759,6 +1764,16 @@
     }
   }
   instr.set_custom_call_has_side_effect(has_side_effect);
+  for (const auto& pair : output_operand_aliasing) {
+    auto aliasing = instr.add_custom_call_output_operand_aliasing();
+    aliasing->set_operand_index(pair.second.first);
+    for (int64 index : pair.second.second) {
+      aliasing->add_operand_shape_index(index);
+    }
+    for (int64 index : pair.first) {
+      aliasing->add_output_shape_index(index);
+    }
+  }
   return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
 }
 
@@ -1766,7 +1781,9 @@
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const XlaComputation& computation, const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-    bool has_side_effect) {
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     if (absl::StartsWith(call_target_name, "$")) {
@@ -1804,6 +1821,16 @@
       }
     }
     AddCalledComputation(computation, &instr);
+    for (const auto& pair : output_operand_aliasing) {
+      auto aliasing = instr.add_custom_call_output_operand_aliasing();
+      aliasing->set_operand_index(pair.second.first);
+      for (int64 index : pair.second.second) {
+        aliasing->add_operand_shape_index(index);
+      }
+      for (int64 index : pair.first) {
+        aliasing->add_output_shape_index(index);
+      }
+    }
     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
   });
 }
@@ -3861,31 +3888,39 @@
   return builder->Call(computation, operands);
 }
 
-XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
-                 absl::Span<const XlaOp> operands, const Shape& shape,
-                 const string& opaque, bool has_side_effect) {
+XlaOp CustomCall(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   return builder->CustomCall(call_target_name, operands, shape, opaque,
                              /*operand_shapes_with_layout=*/absl::nullopt,
-                             has_side_effect);
+                             has_side_effect, output_operand_aliasing);
 }
 
-XlaOp CustomCallWithComputation(XlaBuilder* builder,
-                                const string& call_target_name,
-                                absl::Span<const XlaOp> operands,
-                                const XlaComputation& computation,
-                                const Shape& shape, const string& opaque,
-                                bool has_side_effect) {
-  return builder->CustomCall(
-      call_target_name, operands, computation, shape, opaque,
-      /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect);
+XlaOp CustomCallWithComputation(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const XlaComputation& computation,
+    const Shape& shape, const string& opaque, bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
+  return builder->CustomCall(call_target_name, operands, computation, shape,
+                             opaque,
+                             /*operand_shapes_with_layout=*/absl::nullopt,
+                             has_side_effect, output_operand_aliasing);
 }
 
-XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
-                           absl::Span<const XlaOp> operands, const Shape& shape,
-                           absl::Span<const Shape> operand_shapes_with_layout,
-                           const string& opaque, bool has_side_effect) {
+XlaOp CustomCallWithLayout(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const Shape& shape,
+    absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
+    bool has_side_effect,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing) {
   return builder->CustomCall(call_target_name, operands, shape, opaque,
-                             operand_shapes_with_layout, has_side_effect);
+                             operand_shapes_with_layout, has_side_effect,
+                             output_operand_aliasing);
 }
 
 XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cd9809c..55bcd86 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -593,7 +593,9 @@
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const Shape& shape_with_layout, const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-      bool has_side_effect);
+      bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
 
   // Internal version of CustomCall without computation that doesn't do op
   // specific error handling and expects arguments to be legal. CustomCall
@@ -602,14 +604,18 @@
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const Shape& shape_with_layout, const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-      bool has_side_effect);
+      bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
 
   XlaOp CustomCall(
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const XlaComputation& computation, const Shape& shape_with_layout,
       const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
-      bool has_side_effect);
+      bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
 
   XlaOp Reduce(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation,
@@ -1058,18 +1064,25 @@
                       const string& outfeed_config);
   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                     absl::Span<const XlaOp> operands);
-  friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
-                          absl::Span<const XlaOp> operands, const Shape& shape,
-                          const string& opaque, bool has_side_effect);
+  friend XlaOp CustomCall(
+      XlaBuilder* builder, const string& call_target_name,
+      absl::Span<const XlaOp> operands, const Shape& shape,
+      const string& opaque, bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
   friend XlaOp CustomCallWithComputation(
       XlaBuilder* builder, const string& call_target_name,
       absl::Span<const XlaOp> operands, const XlaComputation& computation,
-      const Shape& shape, const string& opaque, bool has_side_effect);
+      const Shape& shape, const string& opaque, bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
   friend XlaOp CustomCallWithLayout(
       XlaBuilder* builder, const string& call_target_name,
       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
-      bool has_side_effect);
+      bool has_side_effect,
+      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_operand_aliasing);
   friend XlaOp Complex(XlaOp real, XlaOp imag,
                        absl::Span<const int64> broadcast_dimensions);
   friend XlaOp Conj(XlaOp operand);
@@ -1805,30 +1818,39 @@
 // backend, a call instruction is emitted which targets a symbol with the name
 // |call_target_name|.  |call_target_name| and |opaque| can arbitrary strings,
 // but |call_target_name| should be short as it may be used in labels. |opaque|
-// can encode arbitrarily large amounts of information.
-XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
-                 absl::Span<const XlaOp> operands, const Shape& shape,
-                 const string& opaque = "", bool has_side_effect = false);
+// can encode arbitrarily large amounts of information. |has_side_effect|
+// specifies whether the instruction can have side effects.
+// |output_operand_aliasing| specifies a list of output/operand buffer pairs
+// that alias each other, where the output buffer is represented as a
+// ShapeIndex, and the operand buffer is represented as the operand index and
+// the ShapeIndex.
+XlaOp CustomCall(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const Shape& shape,
+    const string& opaque = "", bool has_side_effect = false,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing = {});
 
 // Overload which constructs a custom call that applies an Xla computation.
-XlaOp CustomCallWithComputation(XlaBuilder* builder,
-                                const string& call_target_name,
-                                absl::Span<const XlaOp> operands,
-                                const XlaComputation& computation,
-                                const Shape& shape, const string& opaque = "",
-                                bool has_side_effect = false);
+XlaOp CustomCallWithComputation(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const XlaComputation& computation,
+    const Shape& shape, const string& opaque = "", bool has_side_effect = false,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing = {});
 
 // Overload which constructs a custom call with fixed layouts. The operands will
 // have the layouts specified by |operand_shapes_with_layout| when provided to
 // external code, and the external code is expected to produce a result with the
 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
 // and |operand_shapes_with_layout| must have layouts.
-XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
-                           absl::Span<const XlaOp> operands,
-                           const Shape& shape_with_layout,
-                           absl::Span<const Shape> operand_shapes_with_layout,
-                           const string& opaque = "",
-                           bool has_side_effect = false);
+XlaOp CustomCallWithLayout(
+    XlaBuilder* builder, const string& call_target_name,
+    absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+    absl::Span<const Shape> operand_shapes_with_layout,
+    const string& opaque = "", bool has_side_effect = false,
+    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+        output_operand_aliasing = {});
 
 // The following methods enqueue element-wise binary arithmetic operations
 // onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 170f774..b9aff29 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3492,6 +3492,8 @@
     hdrs = ["memory_space_assignment_utils.h"],
     deps = [
         ":heap_simulator",
+        ":hlo",
+        ":hlo_casting_utils",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index c3a7b3a..ac94b2e 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -35,7 +35,7 @@
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 74
+// Next ID: 75
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -232,6 +232,11 @@
   // kCustomCall.
   bool custom_call_has_side_effect = 65;
 
+  // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
+  // buffers between output and operands for kCustomCall.
+  repeated xla.CustomCallOutputOperandAliasing
+      custom_call_output_operand_aliasing = 74;
+
   // The delta value for kRngGetAndUpdateState.
   int64 delta = 66;
 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 72899ff..bc1063f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -432,6 +432,23 @@
   return changed;
 }
 
+bool HloDataflowAnalysis::UpdateCustomCallValueSet(
+    HloInstruction* custom_call) {
+  CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
+  const auto* call = Cast<HloCustomCallInstruction>(custom_call);
+  bool any_update = false;  // Becomes true once any aliased output set changes.
+  for (const auto& alias : call->output_to_operand_aliasing()) {
+    const HloInstruction* src = custom_call->operand(alias.second.first);
+    const HloValueSet& src_values = GetValueSet(src, alias.second.second);
+    HloValueSet& dst_values = GetValueSet(custom_call, alias.first);
+    if (dst_values != src_values) {
+      dst_values = src_values;  // Aliased output takes the operand's values.
+      any_update = true;
+    }
+  }
+  return any_update;
+}
+
 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
   CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
   bool changed = false;
@@ -757,6 +774,8 @@
       return UpdateAddDependencyValueSet(instruction);
     case HloOpcode::kBitcast:
       return UpdateBitcastValueSet(instruction);
+    case HloOpcode::kCustomCall:
+      return UpdateCustomCallValueSet(instruction);
     case HloOpcode::kSetDimensionSize:
       return UpdateSetDimensionSizeValueSet(instruction);
     case HloOpcode::kDomain:
@@ -1018,6 +1037,22 @@
           define_value_at(/*index=*/{1});
           define_value_at(/*index=*/{2});
           break;
+        case HloOpcode::kCustomCall: {
+          absl::flat_hash_set<ShapeIndex> aliasing_indices;
+          for (const auto& aliasing :
+               Cast<HloCustomCallInstruction>(instruction)
+                   ->output_to_operand_aliasing()) {
+            aliasing_indices.insert(aliasing.first);
+          }
+          ShapeUtil::ForEachSubshape(
+              instruction->shape(),
+              [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+                if (!aliasing_indices.contains(index)) {
+                  define_value_at(index);
+                }
+              });
+          break;
+        }
         default:
           define_all_values();
           break;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index ffa307d..c3aad04 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -216,6 +216,7 @@
   bool UpdateCallValueSet(HloInstruction* call);
   bool UpdateConditionalValueSet(HloInstruction* conditional);
   bool UpdateCopyValueSet(HloInstruction* copy);
+  bool UpdateCustomCallValueSet(HloInstruction* custom_call);
   bool UpdateDomainValueSet(HloInstruction* domain);
   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
   bool UpdateParameterValueSet(HloInstruction* parameter);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 251261a..41488dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -568,6 +568,19 @@
           std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
       custom_call_instr->set_custom_call_has_side_effect(
           proto.custom_call_has_side_effect());
+      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          output_to_operand_aliasing;
+      for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
+        output_to_operand_aliasing.emplace_back(
+            ShapeIndex(aliasing.output_shape_index().begin(),
+                       aliasing.output_shape_index().end()),
+            std::pair<int64, ShapeIndex>{
+                aliasing.operand_index(),
+                ShapeIndex(aliasing.operand_shape_index().begin(),
+                           aliasing.operand_shape_index().end())});
+      }
+      custom_call_instr->set_output_to_operand_aliasing(
+          std::move(output_to_operand_aliasing));
       break;
     }
     case HloOpcode::kPad:
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index c4c31db..45b2d88 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2395,6 +2395,16 @@
     }
   }
   proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
+  for (const auto& pair : output_to_operand_aliasing_) {
+    auto aliasing = proto.add_custom_call_output_operand_aliasing();
+    aliasing->set_operand_index(pair.second.first);
+    for (int64 index : pair.first) {
+      aliasing->add_output_shape_index(index);
+    }
+    for (int64 index : pair.second.second) {
+      aliasing->add_operand_shape_index(index);
+    }
+  }
   return proto;
 }
 
@@ -2432,6 +2442,16 @@
   if (custom_call_has_side_effect_) {
     extra.push_back("custom_call_has_side_effect=true");
   }
+  if (!output_to_operand_aliasing_.empty()) {
+    std::vector<string> pair_strings;
+    for (const auto& pair : output_to_operand_aliasing_) {
+      pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
+                                    pair.second.first, ", ",
+                                    pair.second.second.ToString(), ")"));
+    }
+    extra.push_back(StrCat("output_to_operand_aliasing={",
+                           StrJoin(pair_strings, ", "), "}"));
+  }
   return extra;
 }
 
@@ -2475,6 +2495,10 @@
       casted_other.custom_call_has_side_effect()) {
     return false;
   }
+  if (output_to_operand_aliasing_ !=
+      casted_other.output_to_operand_aliasing()) {
+    return false;
+  }
   // Note: backend_config comparison is done in Identical, which is the
   // intended/exposed way to compare computations, and so not repeated here.
   return custom_call_target_ == casted_other.custom_call_target_;
@@ -2499,6 +2523,7 @@
   cloned->set_feature_group_count(feature_group_count_);
   cloned->set_batch_group_count(batch_group_count_);
   cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
+  cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
   return std::move(cloned);
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 821849b..88e8743 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1430,6 +1430,20 @@
     CHECK(layout_constrained());
     return operand_shapes_with_layout_;
   }
+  // Returns the output/operand buffer pairs that alias each other. Per entry,
+  // `first` is the output buffer's ShapeIndex and `second` is the aliased
+  // operand buffer as {operand index, ShapeIndex within that operand}. The
+  // list is empty by default.
+  const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
+  output_to_operand_aliasing() const {
+    return output_to_operand_aliasing_;
+  }
+  // Replaces the aliasing list; `aliasing` is taken by value and moved in.
+  void set_output_to_operand_aliasing(
+      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+          aliasing) {
+    output_to_operand_aliasing_ = std::move(aliasing);
+  }
 
  private:
   std::vector<string> ExtraAttributesToStringImpl(
@@ -1458,6 +1472,10 @@
   std::vector<Shape> operand_shapes_with_layout_;
   // Whether this custom call has a side-effect.
   bool custom_call_has_side_effect_;
+  // A list of output/operand buffer pairs that alias each other. See comment of
+  // output_to_operand_aliasing().
+  std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+      output_to_operand_aliasing_;
 };
 
 class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index e2bbda3..37bdeaa 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -212,6 +212,7 @@
     kEnum,
     kRandomAlgorithm,
     kAliasing,
+    kInstructionAliasing,
   };
 
   struct AttrConfig {
@@ -346,6 +347,12 @@
   // fails.
   bool ParseAliasing(AliasingData* data);
 
+  // Parses the per-instruction aliasing information from string `s`, returns
+  // `false` if it fails.
+  bool ParseInstructionOutputOperandAliasing(
+      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
+          aliasing_output_operand_pairs);
+
   bool ParseShapeIndex(ShapeIndex* out);
 
   // Returns true if the current token is the beginning of a shape.
@@ -598,6 +605,58 @@
   return true;
 }
 
+// Parses a brace-enclosed, comma-separated list of entries of the form
+// `<output_shape_index>: (<operand_index>, <operand_shape_index>)`, appending
+// each to `aliasing_output_operand_pairs`. Returns false on a malformed entry.
+bool HloParserImpl::ParseInstructionOutputOperandAliasing(
+    std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
+        aliasing_output_operand_pairs) {
+  if (!ParseToken(
+          TokKind::kLbrace,
+          "Expects '{' at the start of instruction aliasing description")) {
+    return false;
+  }
+
+  const std::string errmsg =
+      "Expected format: <output_shape_index>: (<operand_index>, "
+      "<operand_shape_index>)";
+  while (lexer_.GetKind() != TokKind::kRbrace) {
+    ShapeIndex out;
+    if (!ParseShapeIndex(&out)) {
+      return false;
+    }
+    if (!ParseToken(TokKind::kColon, errmsg) ||
+        !ParseToken(TokKind::kLparen, errmsg)) {
+      return false;
+    }
+    // Stop on a missing operand index: carrying on could let an uninitialized
+    // `operand_index` reach the aliasing list when the next token is ','.
+    int64 operand_index;
+    if (!ParseInt64(&operand_index) ||
+        !ParseToken(TokKind::kComma, errmsg)) {
+      return false;
+    }
+    ShapeIndex operand_shape_index;
+    if (!ParseShapeIndex(&operand_shape_index)) {
+      return false;
+    }
+
+    aliasing_output_operand_pairs->emplace_back(
+        out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
+    if (!ParseToken(TokKind::kRparen, errmsg)) {
+      return false;
+    }
+
+    // No ',' means the list is over; a ',' directly before '}' is tolerated.
+    if (!EatIfPresent(TokKind::kComma)) {
+      break;
+    }
+  }
+  return ParseToken(
+      TokKind::kRbrace,
+      "Expects '}' at the end of instruction aliasing description");
+}
+
 // ::= 'HloModule' name computations
 bool HloParserImpl::ParseHloModule(HloModule* module) {
   if (lexer_.GetKind() != TokKind::kw_HloModule) {
@@ -1777,6 +1836,8 @@
       optional<std::vector<Shape>> operand_layout_constraints;
       optional<bool> custom_call_has_side_effect;
       optional<HloComputation*> to_apply;
+      optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
+          output_to_operand_aliasing;
       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                      &custom_call_target};
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@@ -1792,6 +1853,9 @@
                                               &custom_call_has_side_effect};
       attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
                            &to_apply};
+      attrs["output_to_operand_aliasing"] = {/*required=*/false,
+                                             AttrTy::kInstructionAliasing,
+                                             &output_to_operand_aliasing};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
@@ -1861,6 +1925,10 @@
         custom_call_instr->set_custom_call_has_side_effect(
             *custom_call_has_side_effect);
       }
+      if (output_to_operand_aliasing.has_value()) {
+        custom_call_instr->set_output_to_operand_aliasing(
+            std::move(*output_to_operand_aliasing));
+      }
       break;
     }
     case HloOpcode::kDot: {
@@ -3223,6 +3291,19 @@
             ->emplace(aliasing_data);
         return true;
       }
+      case AttrTy::kInstructionAliasing: {
+        std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+            aliasing_output_operand_pairs;
+        if (!ParseInstructionOutputOperandAliasing(
+                &aliasing_output_operand_pairs)) {
+          return false;
+        }
+        static_cast<optional<
+            std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
+            attr_out_ptr)
+            ->emplace(std::move(aliasing_output_operand_pairs));
+        return true;
+      }
     }
   }();
   if (!success) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 620e67c..d220d73 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -993,6 +993,19 @@
 
 )"
 },
+// CustomCallWithAliasing
+{
+"CustomCallWithAliasing",
+R"(HloModule CustomCallWithAliasing
+
+ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
+  %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
+  %p1 = f32[123,4]{0,1} parameter(1)
+  ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
+}
+
+)"
+},
 // Parse c64 literal
 {
 "ParseC64Literal",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index b3603e4..4be0c52 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -801,6 +801,28 @@
       TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
     }
   }
+  for (const auto& pair : custom_call->output_to_operand_aliasing()) {
+    TF_RET_CHECK(pair.second.first < custom_call->operand_count())
+        << "Invalid aliasing operand index.";
+    TF_RET_CHECK(ShapeUtil::IndexIsValid(
+        custom_call->operand(pair.second.first)->shape(), pair.second.second))
+        << "Invalid aliasing operand shape index.";
+    TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
+        << "Invalid aliasing output shape index.";
+    const Shape& output_subshape =
+        ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
+    const Shape& operand_subshape = ShapeUtil::GetSubshape(
+        custom_call->operand(pair.second.first)->shape(), pair.second.second);
+    if (layout_sensitive_) {
+      TF_RET_CHECK(operand_subshape == output_subshape)
+          << "Different aliasing shapes: " << operand_subshape.ToString()
+          << " vs " << output_subshape.ToString();
+    } else {
+      TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
+          << "Different aliasing shapes: " << operand_subshape.ToString()
+          << " vs " << output_subshape.ToString();
+    }
+  }
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
index 0c44ae0..aad943aa 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
@@ -15,6 +15,9 @@
 
 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
 
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+
 namespace xla {
 
 bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
@@ -87,6 +90,17 @@
         return false;
       }
     }
+    if (auto* custom_call =
+            DynCast<HloCustomCallInstruction>(position.instruction)) {
+      for (const auto& pair : custom_call->output_to_operand_aliasing()) {
+        if (position.index == pair.first) {
+          VLOG(4) << "Keeping value " << value->ToShortString()
+                  << " in default mem because it is a custom-call output that "
+                     "aliases an operand buffer.";
+          return false;
+        }
+      }
+    }
   }
 
   return true;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index d334f87..2d311dd 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -691,3 +691,11 @@
   // unknown-trip-count.
   KnownTripCount known_trip_count = 1;
 }
+
+// Specifies a pair of output/operand buffers for kCustomCall that alias each
+// other.
+message CustomCallOutputOperandAliasing {
+  repeated int64 output_shape_index = 1;   // ShapeIndex of the output buffer.
+  int64 operand_index = 2;                 // Which operand holds the buffer.
+  repeated int64 operand_shape_index = 3;  // ShapeIndex within that operand.
+}