Addressed sanjoy's comments
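
This change renames AddFrontendAttribute to SetInstructionFrontendAttribute, switches GetFrontendAttributesFromNodeDef to take an AttrSlice, and makes XlaScopedFrontendAttributesAssignment merge scoped attributes into the builder's current set (restoring the previous set on destruction) instead of clearing everything. A minimal usage sketch of the scoped form; the builder setup and attribute key are illustrative, and the constructor arguments are assumed from the class definition below:

    xla::XlaBuilder builder("example");
    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())["_xla_some_hint"] = "true";
    {
      xla::XlaScopedFrontendAttributesAssignment scope(&builder, attrs);
      // Instructions built in this scope carry the merged attributes.
      xla::XlaOp x = xla::ConstantR0<float>(&builder, 1.0f);
    }
    // On scope exit the builder's previous attributes are restored.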
diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc
index 96e6187..1dd0d3a 100644
--- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc
+++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc
@@ -24,15 +24,14 @@
} // namespace
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
-GetFrontendAttributesFromNodeDef(const NodeDef& node_def) {
- if (!HasNodeAttr(node_def, kFrontendAttributesAttribute)) {
- return absl::optional<xla::FrontendAttributes>();
+GetFrontendAttributesFromNodeDef(const AttrSlice& attrs) {
+ auto attr = attrs.Find(kFrontendAttributesAttribute);
+ if (attr == nullptr) {
+ return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
+ absl::nullopt);
}
- string value;
xla::FrontendAttributes attributes;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node_def, kFrontendAttributesAttribute, &value));
- if (!attributes.ParseFromString(value)) {
+ if (!attributes.ParseFromString(attr->s())) {
return errors::InvalidArgument(
"Experimental _XlaFrontendAttributes attribute was not a valid encoded "
"xla::FrontendAttributes proto.");
diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h
index fc9df12..2beaa2f 100644
--- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h
+++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h
@@ -20,12 +20,12 @@
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
namespace tensorflow {
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
-GetFrontendAttributesFromNodeDef(const NodeDef& node_def);
+GetFrontendAttributesFromNodeDef(const AttrSlice& attrs);
} // namespace tensorflow
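
For reference, a sketch of how a caller consumes the new signature; TF_ASSIGN_OR_RETURN comes from xla/status_macros.h and node_def stands for any NodeDef. AttrSlice is implicitly constructible from a NodeDef, so existing NodeDef call sites keep compiling:

    TF_ASSIGN_OR_RETURN(absl::optional<xla::FrontendAttributes> attributes,
                        GetFrontendAttributesFromNodeDef(AttrSlice(node_def)));
    if (attributes.has_value()) {
      // The node carried a valid _XlaFrontendAttributes proto.
    }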
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 86e3f99..35a2e63 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -100,7 +100,7 @@
sharding_parse_result.ValueOrDie();
auto frontend_attributes_result =
- GetFrontendAttributesFromNodeDef(op_kernel->def());
+ GetFrontendAttributesFromNodeDef(AttrSlice(op_kernel->def()));
OP_REQUIRES_OK(context, frontend_attributes_result.status());
absl::optional<xla::FrontendAttributes> frontend_attributes =
frontend_attributes_result.ValueOrDie();
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5e33984..b2d375b 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -289,11 +289,12 @@
return Status::OK();
}
-Status XlaBuilder::AddFrontendAttribute(const XlaOp& op, std::string attribute,
- std::string value) {
+Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp& op,
+ std::string attribute,
+ std::string value) {
TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
- (*frontend_attributes->mutable_map())[attribute] = value;
+ (*frontend_attributes->mutable_map())[attribute] = std::move(value);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cdb31c6..8c013da 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -158,14 +158,31 @@
// Sets an OpSharding that will be attached to all instructions until cleared.
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
+ // Sets the FrontendAttributes that will be added to all instructions until
+ // cleared.
+ //
+ // FrontendAttributes are often applied to a series of XLA HLO instructions.
+ // As a result, they are set on the computation builder and all the
+ // instructions generated via the builder will have the same frontend
+ // attributes attached to them.
void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
frontend_attributes_ = frontend_attributes;
}
+ // Merges the passed FrontendAttributes with the ones already set.
+ //
+ // In case of duplicates, the new attributes take precedence.
+ void MergeFrontendAttributes(const FrontendAttributes& frontend_attributes) {
+ for (const auto& attribute : frontend_attributes.map()) {
+ (*frontend_attributes_.mutable_map())[attribute.first] = attribute.second;
+ }
+ }
+
+ // Returns the FrontendAttributes that will be attached to all instructions.
const FrontendAttributes& frontend_attributes() const {
return frontend_attributes_;
}
+ // Clears all the frontend attributes.
void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
// Clears the sharding. Ops will be sharded according to the default placement
@@ -326,7 +343,13 @@
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
// "value".
- Status AddFrontendAttribute(const XlaOp& op, string attribute, string value);
+ //
+ // If the attribute already exists, its value is updated.
+ //
+ // Note: the attribute is only added to the HloInstruction, not to the
+ // builder.
+ Status SetInstructionFrontendAttribute(const XlaOp& op, string attribute,
+ string value);
private:
// Build helper which takes the id of the root operation..
@@ -610,8 +633,8 @@
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
int64 handle) const;
- StatusOr<HloInstructionProto*> LookUpMutableInstruction(const XlaOp& op);
- StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
+ StatusOr<HloInstructionProto*> LookUpMutableInstruction(const XlaOp& op);
+ StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1056,8 +1079,9 @@
absl::optional<OpSharding> prev_sharding_;
};
-// RAII-style object: sets the current frontend attributes in builder on
-// construction, and clears it on destruction.
+// RAII-style object: on construction, saves the builder's current frontend
+// attributes and merges the new ones into them; on destruction, restores the
+// original attributes.
class XlaScopedFrontendAttributesAssignment {
public:
XlaScopedFrontendAttributesAssignment(
@@ -1079,13 +1103,18 @@
void SetFrontendAttributes(
const absl::optional<FrontendAttributes>& attributes) {
if (attributes.has_value()) {
- builder_->SetFrontendAttributes(attributes.value());
+ // Save the existing attributes.
+ saved_ = builder_->frontend_attributes();
+ // Merge the existing attributes with the new ones.
+ builder_->MergeFrontendAttributes(attributes.value());
} else {
- builder_->ClearFrontendAttributes();
+ builder_->SetFrontendAttributes(saved_);
+ saved_.Clear();
}
}
xla::XlaBuilder* const builder_;
+ FrontendAttributes saved_;
};
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
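
A small sketch of the merge contract documented above; keys, values, builder, and op are illustrative (op is any previously built XlaOp). Duplicates resolve in favor of the newly passed attributes, while SetInstructionFrontendAttribute touches only the one instruction:

    xla::FrontendAttributes base, extra;
    (*base.mutable_map())["k1"] = "old";
    (*base.mutable_map())["k2"] = "kept";
    (*extra.mutable_map())["k1"] = "new";
    builder.SetFrontendAttributes(base);
    builder.MergeFrontendAttributes(extra);
    // builder.frontend_attributes().map() now holds {k1: "new", k2: "kept"}.
    TF_RETURN_IF_ERROR(
        builder.SetInstructionFrontendAttribute(op, "k3", "per_instr"));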
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 236ac143..9cd4116 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1196,6 +1196,7 @@
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
+ broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
// Do explicit broadcast for degenerate broadcast.
@@ -1221,6 +1222,7 @@
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
+ reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
// Broadcast 'reshape' up to the larger size.
auto broadcast = HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions);
@@ -1228,6 +1230,7 @@
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
+ broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
@@ -1298,6 +1301,7 @@
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
+ derived_instruction->set_frontend_attributes(frontend_attributes_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -2483,6 +2487,12 @@
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
+ if (!frontend_attributes_.map().empty()) {
+ extra.push_back(
+ absl::StrFormat("frontend_attributes={%s}",
+ absl::StrJoin(frontend_attributes_.map(), ",",
+ absl::PairFormatter("="))));
+ }
if (!outer_dimension_partitions_.empty()) {
extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
StrJoin(outer_dimension_partitions_, ",")));
@@ -2542,6 +2552,9 @@
proto.mutable_outer_dimension_partitions()->Add(idx);
}
}
+ if (!frontend_attributes_.map().empty()) {
+ proto.mutable_frontend_attributes()->CopyFrom(frontend_attributes_);
+ }
return proto;
}
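
With the ToString() change above, an instruction whose attribute map holds, say, {"attr_a": "1", "attr_b": "2"} prints an extra of the form below; proto map ordering is not guaranteed, so the rendering is illustrative:

    %add = f32[] add(f32[] %x, f32[] %y), frontend_attributes={attr_a=1,attr_b=2}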
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 467dd29..cf17502 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1888,6 +1888,14 @@
// Attributes passed from the frontend to give hints to the backend about
// how to compile this HLO.
+ // HLO -> HLO transforms are expected to preserve these attributes on a
+ // "best effort" basis only.
+ // For example:
+ //    x = const(10), frontend_attributes={x}
+ //    y = const(10), frontend_attributes={y}
+ //    z = add(x, y), frontend_attributes={z}
+ // Could be simplified to:
+ // z' = const(20), frontend_attributes={?}
FrontendAttributes frontend_attributes_;
// This field is assigned to true when backend_config_ is assigned to
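
One practical note for pass authors, following the best-effort contract above: instructions derived via SetupDerivedInstruction (patched in hlo_instruction.cc) pick up frontend_attributes automatically, but replacements built by other means need an explicit copy, e.g. (sketch; old_instr and new_instr are hypothetical HloInstruction pointers):

    new_instr->set_frontend_attributes(old_instr->frontend_attributes());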