[XLA] Allow CustomCall to specify aliasing buffers
PiperOrigin-RevId: 333445659
Change-Id: I444cc5f33793de365ad66cfd828f6e4eeb489065
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index ac5e01a..3da8935 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -135,12 +135,16 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect) {
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
+ TF_RET_CHECK(output_operand_aliasing.empty())
+ << "MLIR CustomCallOp does not support output_operand_aliasing yet";
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 00b7aa4..59b4bc7 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -135,7 +135,9 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect) override;
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 3e2a4eb..c7bbf9f 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1707,7 +1707,9 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect) {
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1739,7 +1741,8 @@
}
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
- operand_shapes_with_layout, has_side_effect);
+ operand_shapes_with_layout, has_side_effect,
+ output_operand_aliasing);
});
}
@@ -1747,7 +1750,9 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect) {
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@@ -1759,6 +1764,16 @@
}
}
instr.set_custom_call_has_side_effect(has_side_effect);
+ for (const auto& pair : output_operand_aliasing) {
+ auto aliasing = instr.add_custom_call_output_operand_aliasing();
+ aliasing->set_operand_index(pair.second.first);
+ for (int64 index : pair.second.second) {
+ aliasing->add_operand_shape_index(index);
+ }
+ for (int64 index : pair.first) {
+ aliasing->add_output_shape_index(index);
+ }
+ }
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
@@ -1766,7 +1781,9 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect) {
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1804,6 +1821,16 @@
}
}
AddCalledComputation(computation, &instr);
+ for (const auto& pair : output_operand_aliasing) {
+ auto aliasing = instr.add_custom_call_output_operand_aliasing();
+ aliasing->set_operand_index(pair.second.first);
+ for (int64 index : pair.second.second) {
+ aliasing->add_operand_shape_index(index);
+ }
+ for (int64 index : pair.first) {
+ aliasing->add_output_shape_index(index);
+ }
+ }
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -3861,31 +3888,39 @@
return builder->Call(computation, operands);
}
-XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const string& opaque, bool has_side_effect) {
+XlaOp CustomCall(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
- has_side_effect);
+ has_side_effect, output_operand_aliasing);
}
-XlaOp CustomCallWithComputation(XlaBuilder* builder,
- const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const XlaComputation& computation,
- const Shape& shape, const string& opaque,
- bool has_side_effect) {
- return builder->CustomCall(
- call_target_name, operands, computation, shape, opaque,
- /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect);
+XlaOp CustomCallWithComputation(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const XlaComputation& computation,
+ const Shape& shape, const string& opaque, bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
+ return builder->CustomCall(call_target_name, operands, computation, shape,
+ opaque,
+ /*operand_shapes_with_layout=*/absl::nullopt,
+ has_side_effect, output_operand_aliasing);
}
-XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- absl::Span<const Shape> operand_shapes_with_layout,
- const string& opaque, bool has_side_effect) {
+XlaOp CustomCallWithLayout(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
- operand_shapes_with_layout, has_side_effect);
+ operand_shapes_with_layout, has_side_effect,
+ output_operand_aliasing);
}
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cd9809c..55bcd86 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -593,7 +593,9 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect);
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@@ -602,14 +604,18 @@
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect);
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape_with_layout,
const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
- bool has_side_effect);
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@@ -1058,18 +1064,25 @@
const string& outfeed_config);
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
- friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const string& opaque, bool has_side_effect);
+ friend XlaOp CustomCall(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque, bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
- const Shape& shape, const string& opaque, bool has_side_effect);
+ const Shape& shape, const string& opaque, bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
- bool has_side_effect);
+ bool has_side_effect,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@@ -1805,30 +1818,39 @@
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
-// can encode arbitrarily large amounts of information.
-XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const string& opaque = "", bool has_side_effect = false);
+// can encode arbitrarily large amounts of information. |has_side_effect|
+// specifies whether the instruction can have side effects.
+// |output_operand_aliasing| specifies a list of output/operand buffer pairs
+// that alias each other, where the output buffer is represented as a
+// ShapeIndex, and the operand buffer is represented as the operand index and
+// the ShapeIndex.
+XlaOp CustomCall(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque = "", bool has_side_effect = false,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing = {});
// Overload which constructs a custom call that applies an Xla computation.
-XlaOp CustomCallWithComputation(XlaBuilder* builder,
- const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const XlaComputation& computation,
- const Shape& shape, const string& opaque = "",
- bool has_side_effect = false);
+XlaOp CustomCallWithComputation(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const XlaComputation& computation,
+ const Shape& shape, const string& opaque = "", bool has_side_effect = false,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing = {});
// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
-XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const Shape& shape_with_layout,
- absl::Span<const Shape> operand_shapes_with_layout,
- const string& opaque = "",
- bool has_side_effect = false);
+XlaOp CustomCallWithLayout(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque = "", bool has_side_effect = false,
+ absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_operand_aliasing = {});
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 170f774..b9aff29 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3492,6 +3492,8 @@
hdrs = ["memory_space_assignment_utils.h"],
deps = [
":heap_simulator",
+ ":hlo",
+ ":hlo_casting_utils",
],
)
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index c3a7b3a..ac94b2e 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -35,7 +35,7 @@
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 74
+// Next ID: 75
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -232,6 +232,11 @@
// kCustomCall.
bool custom_call_has_side_effect = 65;
+ // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
+ // buffers between output and operands for kCustomCall.
+ repeated xla.CustomCallOutputOperandAliasing
+ custom_call_output_operand_aliasing = 74;
+
// The delta value for kRngGetAndUpdateState.
int64 delta = 66;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 72899ff..bc1063f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -432,6 +432,23 @@
return changed;
}
+bool HloDataflowAnalysis::UpdateCustomCallValueSet(
+ HloInstruction* custom_call) {
+ CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
+ bool changed = false;
+ for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
+ ->output_to_operand_aliasing()) {
+ const HloValueSet& operand_value_set = GetValueSet(
+ custom_call->operand(aliasing.second.first), aliasing.second.second);
+ HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
bool changed = false;
@@ -757,6 +774,8 @@
return UpdateAddDependencyValueSet(instruction);
case HloOpcode::kBitcast:
return UpdateBitcastValueSet(instruction);
+ case HloOpcode::kCustomCall:
+ return UpdateCustomCallValueSet(instruction);
case HloOpcode::kSetDimensionSize:
return UpdateSetDimensionSizeValueSet(instruction);
case HloOpcode::kDomain:
@@ -1018,6 +1037,22 @@
define_value_at(/*index=*/{1});
define_value_at(/*index=*/{2});
break;
+ case HloOpcode::kCustomCall: {
+ absl::flat_hash_set<ShapeIndex> aliasing_indices;
+ for (const auto& aliasing :
+ Cast<HloCustomCallInstruction>(instruction)
+ ->output_to_operand_aliasing()) {
+ aliasing_indices.insert(aliasing.first);
+ }
+ ShapeUtil::ForEachSubshape(
+ instruction->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+ if (!aliasing_indices.contains(index)) {
+ define_value_at(index);
+ }
+ });
+ break;
+ }
default:
define_all_values();
break;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index ffa307d..c3aad04 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -216,6 +216,7 @@
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
+ bool UpdateCustomCallValueSet(HloInstruction* custom_call);
bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 251261a..41488dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -568,6 +568,19 @@
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
custom_call_instr->set_custom_call_has_side_effect(
proto.custom_call_has_side_effect());
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_to_operand_aliasing;
+ for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
+ output_to_operand_aliasing.emplace_back(
+ ShapeIndex(aliasing.output_shape_index().begin(),
+ aliasing.output_shape_index().end()),
+ std::pair<int64, ShapeIndex>{
+ aliasing.operand_index(),
+ ShapeIndex(aliasing.operand_shape_index().begin(),
+ aliasing.operand_shape_index().end())});
+ }
+ custom_call_instr->set_output_to_operand_aliasing(
+ std::move(output_to_operand_aliasing));
break;
}
case HloOpcode::kPad:
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index c4c31db..45b2d88 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2395,6 +2395,16 @@
}
}
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
+ for (const auto& pair : output_to_operand_aliasing_) {
+ auto aliasing = proto.add_custom_call_output_operand_aliasing();
+ aliasing->set_operand_index(pair.second.first);
+ for (int64 index : pair.first) {
+ aliasing->add_output_shape_index(index);
+ }
+ for (int64 index : pair.second.second) {
+ aliasing->add_operand_shape_index(index);
+ }
+ }
return proto;
}
@@ -2432,6 +2442,16 @@
if (custom_call_has_side_effect_) {
extra.push_back("custom_call_has_side_effect=true");
}
+ if (!output_to_operand_aliasing_.empty()) {
+ std::vector<string> pair_strings;
+ for (const auto& pair : output_to_operand_aliasing_) {
+ pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
+ pair.second.first, ", ",
+ pair.second.second.ToString(), ")"));
+ }
+ extra.push_back(StrCat("output_to_operand_aliasing={",
+ StrJoin(pair_strings, ", "), "}"));
+ }
return extra;
}
@@ -2475,6 +2495,10 @@
casted_other.custom_call_has_side_effect()) {
return false;
}
+ if (output_to_operand_aliasing_ !=
+ casted_other.output_to_operand_aliasing()) {
+ return false;
+ }
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
return custom_call_target_ == casted_other.custom_call_target_;
@@ -2499,6 +2523,7 @@
cloned->set_feature_group_count(feature_group_count_);
cloned->set_batch_group_count(batch_group_count_);
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
+ cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
return std::move(cloned);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 821849b..88e8743 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1430,6 +1430,20 @@
CHECK(layout_constrained());
return operand_shapes_with_layout_;
}
+ // Gets a list of output/operand buffer pairs that alias each other, where the
+ // output buffer is represented as a ShapeIndex, and the operand buffer is
+ // represented as the operand index and the ShapeIndex. By default this list
+ // is empty.
+ const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
+ output_to_operand_aliasing() const {
+ return output_to_operand_aliasing_;
+ }
+ // Sets the list of output/operand buffer pairs that alias each other.
+ void set_output_to_operand_aliasing(
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ aliasing) {
+ output_to_operand_aliasing_ = std::move(aliasing);
+ }
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -1458,6 +1472,10 @@
std::vector<Shape> operand_shapes_with_layout_;
// Whether this custom call has a side-effect.
bool custom_call_has_side_effect_;
+ // A list of output/operand buffer pairs that alias each other. See comment of
+ // output_to_operand_aliasing().
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ output_to_operand_aliasing_;
};
class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index e2bbda3..37bdeaa 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -212,6 +212,7 @@
kEnum,
kRandomAlgorithm,
kAliasing,
+ kInstructionAliasing,
};
struct AttrConfig {
@@ -346,6 +347,12 @@
// fails.
bool ParseAliasing(AliasingData* data);
+ // Parses the per-instruction aliasing information from string `s`, returns
+ // `false` if it fails.
+ bool ParseInstructionOutputOperandAliasing(
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
+ aliasing_output_operand_pairs);
+
bool ParseShapeIndex(ShapeIndex* out);
// Returns true if the current token is the beginning of a shape.
@@ -598,6 +605,58 @@
return true;
}
+// Parses the per-instruction output/operand aliasing attribute of the form
+// `{<output_shape_index>: (<operand_index>, <operand_shape_index>), ...}`.
+// Appends each parsed pair to `aliasing_output_operand_pairs`; returns false
+// (with a lexer error already reported) on any syntax error.
+bool HloParserImpl::ParseInstructionOutputOperandAliasing(
+    std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
+        aliasing_output_operand_pairs) {
+  if (!ParseToken(
+          TokKind::kLbrace,
+          "Expects '{' at the start of instruction aliasing description")) {
+    return false;
+  }
+
+  while (lexer_.GetKind() != TokKind::kRbrace) {
+    ShapeIndex out;
+    if (!ParseShapeIndex(&out)) {
+      return false;
+    }
+    std::string errmsg =
+        "Expected format: <output_shape_index>: (<operand_index>, "
+        "<operand_shape_index>)";
+    if (!ParseToken(TokKind::kColon, errmsg)) {
+      return false;
+    }
+
+    if (!ParseToken(TokKind::kLparen, errmsg)) {
+      return false;
+    }
+    int64 operand_index;
+    // Check the result: on failure operand_index would be uninitialized.
+    if (!ParseInt64(&operand_index)) return false;
+    if (!ParseToken(TokKind::kComma, errmsg)) {
+      return false;
+    }
+    ShapeIndex operand_shape_index;
+    if (!ParseShapeIndex(&operand_shape_index)) {
+      return false;
+    }
+
+    aliasing_output_operand_pairs->emplace_back(
+        out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
+    if (!ParseToken(TokKind::kRparen, errmsg)) {
+      return false;
+    }
+
+    // A missing comma ends the list; the closing brace is consumed below.
+    if (!EatIfPresent(TokKind::kComma)) {
+      break;
+    }
+  }
+  if (!ParseToken(
+          TokKind::kRbrace,
+          "Expects '}' at the end of instruction aliasing description")) {
+    return false;
+  }
+  return true;
+}
+
+
// ::= 'HloModule' name computations
bool HloParserImpl::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
@@ -1777,6 +1836,8 @@
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
optional<HloComputation*> to_apply;
+ optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
+ output_to_operand_aliasing;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@@ -1792,6 +1853,9 @@
&custom_call_has_side_effect};
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
&to_apply};
+ attrs["output_to_operand_aliasing"] = {/*required=*/false,
+ AttrTy::kInstructionAliasing,
+ &output_to_operand_aliasing};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1861,6 +1925,10 @@
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
}
+ if (output_to_operand_aliasing.has_value()) {
+ custom_call_instr->set_output_to_operand_aliasing(
+ std::move(*output_to_operand_aliasing));
+ }
break;
}
case HloOpcode::kDot: {
@@ -3223,6 +3291,19 @@
->emplace(aliasing_data);
return true;
}
+ case AttrTy::kInstructionAliasing: {
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
+ aliasing_output_operand_pairs;
+ if (!ParseInstructionOutputOperandAliasing(
+ &aliasing_output_operand_pairs)) {
+ return false;
+ }
+ static_cast<optional<
+ std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
+ attr_out_ptr)
+ ->emplace(std::move(aliasing_output_operand_pairs));
+ return true;
+ }
}
}();
if (!success) {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 620e67c..d220d73 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -993,6 +993,19 @@
)"
},
+// CustomCallWithAliasing
+{
+"CustomCallWithAliasing",
+R"(HloModule CustomCallWithAliasing
+
+ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
+ %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
+}
+
+)"
+},
// Parse c64 literal
{
"ParseC64Literal",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index b3603e4..4be0c52 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -801,6 +801,28 @@
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
}
}
+ for (const auto& pair : custom_call->output_to_operand_aliasing()) {
+ TF_RET_CHECK(pair.second.first < custom_call->operand_count())
+ << "Invalid aliasing operand index.";
+ TF_RET_CHECK(ShapeUtil::IndexIsValid(
+ custom_call->operand(pair.second.first)->shape(), pair.second.second))
+ << "Invalid aliasing operand shape index.";
+ TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
+ << "Invalid aliasing output shape index.";
+ const Shape& output_subshape =
+ ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
+ const Shape& operand_subshape = ShapeUtil::GetSubshape(
+ custom_call->operand(pair.second.first)->shape(), pair.second.second);
+ if (layout_sensitive_) {
+ TF_RET_CHECK(operand_subshape == output_subshape)
+ << "Different aliasing shapes: " << operand_subshape.ToString()
+ << " vs " << output_subshape.ToString();
+ } else {
+ TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
+ << "Different aliasing shapes: " << operand_subshape.ToString()
+ << " vs " << output_subshape.ToString();
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
index 0c44ae0..aad943aa 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc
@@ -15,6 +15,9 @@
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+
namespace xla {
bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
@@ -87,6 +90,17 @@
return false;
}
}
+ if (auto* custom_call =
+ DynCast<HloCustomCallInstruction>(position.instruction)) {
+ for (const auto& pair : custom_call->output_to_operand_aliasing()) {
+ if (position.index == pair.first) {
+ VLOG(4) << "Keeping value " << value->ToShortString()
+ << " in default mem because it is a custom-call output that "
+ "aliases an operand buffer.";
+ return false;
+ }
+ }
+ }
}
return true;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index d334f87..2d311dd 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -691,3 +691,11 @@
// unknown-trip-count.
KnownTripCount known_trip_count = 1;
}
+
+// Specifies a pair of output/operand buffers for kCustomCall that alias each
+// other.
+message CustomCallOutputOperandAliasing {
+ repeated int64 output_shape_index = 1;
+ int64 operand_index = 2;
+ repeated int64 operand_shape_index = 3;
+}