Refactor xla_sharding to be more useful.

PiperOrigin-RevId: 295838039
Change-Id: Ia138c41a9e2739379ecf3e2222686a195b0fe56d
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 4d5bf08..366e8d4 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -26,22 +26,6 @@
 }  // namespace
 
 namespace {
-xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
-    const NodeDef& node_def) {
-  if (!HasNodeAttr(node_def, kShardingAttribute)) {
-    return absl::optional<xla::OpSharding>();
-  }
-  string value;
-  xla::OpSharding sharding;
-  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
-  if (!sharding.ParseFromString(value)) {
-    return xla::InvalidArgument(
-        "Experimental _XlaSharding attribute was not a valid encoded "
-        "xla::OpSharding proto.");
-  }
-  return absl::optional<xla::OpSharding>(sharding);
-}
-
 Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
   return errors::InvalidArgument(
       "Invalid replicated core id: ", core,
@@ -107,4 +91,19 @@
   }
 }
 
+xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
+    const NodeDef& node_def) {
+  if (!HasNodeAttr(node_def, kShardingAttribute)) {
+    return absl::optional<xla::OpSharding>();
+  }
+  string value;
+  xla::OpSharding sharding;
+  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
+  if (!sharding.ParseFromString(value)) {
+    return xla::InvalidArgument(
+        "Experimental _XlaSharding attribute was not a valid encoded "
+        "xla::OpSharding proto.");
+  }
+  return absl::optional<xla::OpSharding>(sharding);
+}
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h
index ab67d4f..1964348 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.h
+++ b/tensorflow/compiler/tf2xla/sharding_util.h
@@ -45,6 +45,10 @@
 
 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);
 
+// Get sharding inforamtion from node.
+xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
+    const NodeDef& node_def);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_