Refactor xla_sharding to be more useful.
PiperOrigin-RevId: 295838039
Change-Id: Ia138c41a9e2739379ecf3e2222686a195b0fe56d
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 4d5bf08..366e8d4 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -26,22 +26,6 @@
} // namespace
namespace {
-xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
- const NodeDef& node_def) {
- if (!HasNodeAttr(node_def, kShardingAttribute)) {
- return absl::optional<xla::OpSharding>();
- }
- string value;
- xla::OpSharding sharding;
- TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
- if (!sharding.ParseFromString(value)) {
- return xla::InvalidArgument(
- "Experimental _XlaSharding attribute was not a valid encoded "
- "xla::OpSharding proto.");
- }
- return absl::optional<xla::OpSharding>(sharding);
-}
-
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
@@ -107,4 +91,19 @@
}
}
+xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
+ const NodeDef& node_def) {
+ if (!HasNodeAttr(node_def, kShardingAttribute)) {
+ return absl::optional<xla::OpSharding>();
+ }
+ string value;
+ xla::OpSharding sharding;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
+ if (!sharding.ParseFromString(value)) {
+ return xla::InvalidArgument(
+ "Experimental _XlaSharding attribute was not a valid encoded "
+ "xla::OpSharding proto.");
+ }
+ return absl::optional<xla::OpSharding>(sharding);
+}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
index ab67d4f..1964348 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.h
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -45,6 +45,10 @@
void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
+// Get sharding inforamtion from node.
+xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
+ const NodeDef& node_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_