Add option to bridge pass to use sharding util ops (XlaSplitND, XlaConcatND) in place of tree of Split/Concat nodes.

Currently in the TPU bridge, a tree of Concat ops (potentially with Slice ops) is created to merge the sharded outputs of a computation, and a tree of Split ops is created to shard its inputs. The Concat tree can create temporary memory allocations at each level. Fusing these operations into a single op removes those temporary allocations, reducing memory overhead.
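
The new behavior is off by default and is enabled through the pass options. As a rough sketch (the call site is hypothetical, and every flag value other than use_nd_sharding_ops is a placeholder for whatever the caller already passes):

  // Hypothetical call site; use_nd_sharding_ops is the new flag, the rest are
  // placeholders for the existing option values.
  DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
      /*distribute_vars=*/true,
      /*allow_xla_spmd_partition=*/true,
      /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
      /*enable_cross_replica_sharding_mirrored_variables=*/true,
      /*enable_automatic_model_parallelism=*/false,
      /*enable_xla_param_broadcast=*/false,
      /*enable_multicore_locking=*/false,
      /*use_nd_sharding_ops=*/true);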

PiperOrigin-RevId: 385003313
Change-Id: I9b6ba961bcdb27137189450d67bd0baf1173ed41
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index ab7d564..ee82f19 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -17,7 +17,9 @@
 
 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
 
+#include <algorithm>
 #include <queue>
+#include <utility>
 #include <vector>
 
 #include "absl/algorithm/container.h"
@@ -895,6 +897,214 @@
   return sharded_input_info;
 }
 
+// Creates an XlaSplitND (or ReadVariableXlaSplitND) node to shard an input,
+// and adds the new node to the graph.
+StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
+                                 const bool is_resource, const NodeOut& input,
+                                 const PartialTensorShape& partial_tensor_shape,
+                                 const std::vector<Node*>& control_inputs,
+                                 const std::vector<Node*>& control_outputs,
+                                 const DataType dtype, const int num_shards,
+                                 const xla::OpSharding& sharding,
+                                 Graph* graph) {
+  const std::string& input_assigned_device = input.node->assigned_device_name();
+  NodeDef xla_split_def;
+  xla_split_def.set_name(graph->NewName(node_name));
+  xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
+  xla_split_def.set_device(input_assigned_device);
+  AddNodeAttr("T", dtype, &xla_split_def);
+  AddNodeAttr("N", num_shards, &xla_split_def);
+  const std::vector<int64> num_splits(
+      sharding.tile_assignment_dimensions().begin(),
+      sharding.replicate_on_last_tile_dim()
+          ? std::prev(sharding.tile_assignment_dimensions().end())
+          : sharding.tile_assignment_dimensions().end());
+  AddNodeAttr("num_splits", num_splits, &xla_split_def);
+  const int rank = sharding.replicate_on_last_tile_dim()
+                       ? sharding.tile_assignment_dimensions_size() - 1
+                       : sharding.tile_assignment_dimensions_size();
+  std::vector<int32> paddings;
+  paddings.reserve(rank);
+  for (int dim = 0; dim < rank; ++dim) {
+    paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
+                                  partial_tensor_shape));
+  }
+  AddNodeAttr("paddings", paddings, &xla_split_def);
+
+  if (!is_resource) {
+    AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
+    AddNodeAttr(kColocationAttrName,
+                std::vector<string>{
+                    absl::StrCat(kColocationGroupPrefix, input.node->name())},
+                &xla_split_def);
+  }
+
+  Status s;
+  Node* xla_split = graph->AddNode(xla_split_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  if (is_resource) {
+    xla_split->set_requested_device(input.node->requested_device());
+  }
+  xla_split->set_assigned_device_name(input_assigned_device);
+  graph->AddEdge(input.node, input.index, xla_split, 0);
+  for (Node* control_input : control_inputs) {
+    graph->AddControlEdge(control_input, xla_split);
+  }
+  for (Node* control_output : control_outputs) {
+    graph->AddControlEdge(xla_split, control_output);
+  }
+  return xla_split;
+}
+
+// Creates the per-core list of input shards for an input with sharding.
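+// For example, with tile_assignment_dimensions = [2, 1, 2],
+// replicate_on_last_tile_dim = true, and tile_assignment_devices = [0, 1, 2, 3],
+// a single XlaSplitND op with num_splits = [2, 1] produces two shards; shard 0
+// feeds cores 0 and 1, and shard 1 feeds cores 2 and 3.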
+xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
+    absl::string_view node_name, const bool is_resource, const NodeOut& input,
+    const PartialTensorShape& partial_tensor_shape,
+    const std::vector<Node*>& control_inputs,
+    const std::vector<Node*>& control_outputs, const DataType dtype,
+    const xla::OpSharding& sharding, Graph* graph) {
+  const int repeat = sharding.replicate_on_last_tile_dim()
+                         ? *sharding.tile_assignment_dimensions().rbegin()
+                         : 1;
+  const int num_shards = sharding.tile_assignment_devices_size() / repeat;
+
+  TF_ASSIGN_OR_RETURN(
+      Node * xla_split,
+      CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
+                       control_inputs, control_outputs, dtype, num_shards,
+                       sharding, graph));
+
+  std::vector<NodeOut> sharded_inputs_list(
+      sharding.tile_assignment_devices_size());
+
+  for (int i = 0; i < num_shards; ++i) {
+    for (int j = 0; j < repeat; ++j) {
+      const int index = i * repeat + j;
+      const int core = sharding.tile_assignment_devices(index);
+      sharded_inputs_list[core] = {xla_split, i};
+    }
+  }
+
+  return sharded_inputs_list;
+}
+
+// Creates an XlaSplitND op to shard a per-replica arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
+    const xla::OpSharding& sharding, const int replica_id,
+    const int orig_arg_num, DataType dtype,
+    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+    const int orig_src_output, Graph* graph,
+    std::map<ShardedInputIndex, ShardedInputInfo>*
+        arg_index_to_sharded_input_map) {
+  ShardedInputIndex input_index{replica_id, orig_arg_num};
+  auto iter = arg_index_to_sharded_input_map->find(input_index);
+  if (iter != arg_index_to_sharded_input_map->end()) {
+    return iter->second;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<NodeOut> sharded_inputs_list,
+      ShardInputWithXlaSplitOp(
+          absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
+          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
+          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
+          dtype, sharding, graph));
+
+  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
+  return sharded_input_info;
+}
+
+// Creates an XlaSplitND op to shard a distributed arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
+    const xla::OpSharding& sharding, const int num_replicas,
+    const int replica_id, const int orig_arg_num, DataType dtype,
+    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+    const int orig_src_output, Graph* graph,
+    std::map<ShardedInputIndex, ShardedInputInfo>*
+        arg_index_to_sharded_input_map) {
+  ShardedInputIndex input_index{replica_id, orig_arg_num};
+  auto iter = arg_index_to_sharded_input_map->find(input_index);
+  if (iter != arg_index_to_sharded_input_map->end()) {
+    return iter->second;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<NodeOut> sharded_inputs_list,
+      ShardInputWithXlaSplitOp(
+          absl::StrCat(orig_src->name(), "/distributed_split"),
+          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
+          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
+          dtype, sharding, graph));
+
+  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+  for (int replica = 0; replica < num_replicas; ++replica) {
+    (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
+        sharded_input_info;
+  }
+  return sharded_input_info;
+}
+
+// Creates a ReadVariableXlaSplitND op to shard a variable arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
+    const xla::OpSharding& sharding, const int num_replicas,
+    const int replica_id, const int orig_arg_num, DataType dtype,
+    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+    const int orig_src_output, Graph* graph,
+    std::vector<Node*>* to_be_removed_nodes,
+    std::map<ShardedInputIndex, ShardedInputInfo>*
+        arg_index_to_sharded_input_map) {
+  ShardedInputIndex input_index{replica_id, orig_arg_num};
+  auto iter = arg_index_to_sharded_input_map->find(input_index);
+  if (iter != arg_index_to_sharded_input_map->end()) {
+    return iter->second;
+  }
+
+  DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
+  std::vector<Node*> control_outputs;
+  std::vector<const Edge*> edges_to_remove;
+  for (const Edge* edge : orig_src->out_edges()) {
+    if (edge->IsControlEdge()) {
+      control_outputs.push_back(edge->dst());
+    }
+    edges_to_remove.push_back(edge);
+  }
+
+  to_be_removed_nodes->push_back(orig_src);
+
+  const Edge* resource = nullptr;
+  TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));
+
+  std::vector<Node*> control_inputs;
+  for (const Edge* edge : orig_src->in_edges()) {
+    if (edge->IsControlEdge()) {
+      control_inputs.push_back(edge->src());
+    }
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<NodeOut> sharded_inputs_list,
+      ShardInputWithXlaSplitOp(
+          absl::StrCat(resource->src()->name(), "/read_variable_split"),
+          /*is_resource=*/true,
+          /*input=*/{resource->src(), resource->src_output()},
+          partial_tensor_shape, control_inputs, control_outputs, dtype,
+          sharding, graph));
+
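+  // Remove all output edges of the original ReadVariableOp; the node itself
+  // is removed later via to_be_removed_nodes.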
+  for (const Edge* edge : edges_to_remove) {
+    graph->RemoveControlEdge(edge);
+  }
+
+  DCHECK(orig_src->out_edges().empty());
+
+  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+  for (int replica = 0; replica < num_replicas; ++replica) {
+    (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
+        sharded_input_info;
+  }
+  return sharded_input_info;
+}
+
 // Creates a concat node to be used for aggregating sharded retvals across
 // logical cores.
 xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
@@ -1061,6 +1271,45 @@
   return inputs_to_sharded_retval.at(0).node;
 }
 
+xla::StatusOr<Node*> CreateXlaConcatNode(
+    const xla::OpSharding& sharding, const int replica_id, DataType dtype,
+    const PartialTensorShape& partial_tensor_shape,
+    const std::vector<NodeOut>& orig_inputs, absl::string_view device,
+    Graph* graph) {
+  NodeDef xla_concat_def;
+  xla_concat_def.set_name(graph->NewName(
+      absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
+  xla_concat_def.set_op("XlaConcatND");
+  xla_concat_def.set_device(std::string(device));
+  AddNodeAttr("T", dtype, &xla_concat_def);
+  AddNodeAttr("N", static_cast<int64>(orig_inputs.size()), &xla_concat_def);
+  const std::vector<int64> num_concats(
+      sharding.tile_assignment_dimensions().begin(),
+      sharding.replicate_on_last_tile_dim()
+          ? std::prev(sharding.tile_assignment_dimensions().end())
+          : sharding.tile_assignment_dimensions().end());
+  AddNodeAttr("num_concats", num_concats, &xla_concat_def);
+  const int rank = sharding.replicate_on_last_tile_dim()
+                       ? sharding.tile_assignment_dimensions_size() - 1
+                       : sharding.tile_assignment_dimensions_size();
+  std::vector<int32> paddings;
+  paddings.reserve(rank);
+  for (int dim = 0; dim < rank; ++dim) {
+    paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
+                                  partial_tensor_shape));
+  }
+  AddNodeAttr("paddings", paddings, &xla_concat_def);
+
+  Status s;
+  Node* xla_concat = graph->AddNode(xla_concat_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
+    const NodeOut& input = orig_inputs[i];
+    graph->AddEdge(input.node, input.index, xla_concat, i);
+  }
+  return xla_concat;
+}
+
 // Set the padding ops the same devices as the original inputs. If the original
 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
 // mode will be triggered, so we don't need to copy the data back to the host
@@ -3059,7 +3308,7 @@
     replicate_output_edges[edge->src_output()] = edge;
     if (num_partitioned_outputs > 1) {
       return errors::InvalidArgument(
-          "More than one TPUPartitionedOutput per replciated output.");
+          "More than one TPUPartitionedOutput per replicated output.");
     }
   }
 
@@ -3213,16 +3462,17 @@
         int64_t orig_arg_num = core_arg_nums[core][i];
         VLOG(2) << " replica " << replica << " core " << core << " i " << i
                 << " orig_arg_num " << orig_arg_num;
-        if (params_info.IsPerReplicaArg(orig_arg_num) ||
-            params_info.IsDistributedArg(orig_arg_num)) {
+        const bool is_per_replica_arg =
+            params_info.IsPerReplicaArg(orig_arg_num);
+        if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
           // Per-replica input and distributed input
-          int64_t input_num = params_info.IsPerReplicaArg(orig_arg_num)
-                                  ? replica * params_info.NumPerReplicaArgs() +
-                                        core_arg_nums[core][i]
-                                  : params_info.NumReplicas() *
-                                            params_info.NumPerReplicaArgs() +
-                                        core_arg_nums[core][i] -
-                                        params_info.NumPerReplicaArgs();
+          const int64_t input_num =
+              is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
+                                       core_arg_nums[core][i]
+                                 : params_info.NumReplicas() *
+                                           params_info.NumPerReplicaArgs() +
+                                       core_arg_nums[core][i] -
+                                       params_info.NumPerReplicaArgs();
 
           const Edge* edge = replicate_input_edges[input_num];
           VLOG(2) << "replicate_input_edges[" << input_num << "]";
@@ -3265,14 +3515,32 @@
               }
               const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
 
-              // Create or get the Split node.
-              TF_ASSIGN_OR_RETURN(
-                  ShardedInputInfo sharded_input_info,
-                  CreateOrGetSplitNodesForInputSharding(
-                      sharding, orig_arg_num, dtype,
-                      arg_shapes[orig_arg_num].handle_shape, replica,
-                      edge->src_output(), edge->src(), control_predecessor,
-                      graph, &input_index_to_sharded_inputs));
+              ShardedInputInfo sharded_input_info;
+              if (use_nd_sharding_ops_ && is_per_replica_arg) {
+                TF_ASSIGN_OR_RETURN(
+                    sharded_input_info,
+                    CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
+                        sharding, replica, orig_arg_num, dtype,
+                        PartialTensorShape(), edge->src(), edge->src_output(),
+                        graph, &input_index_to_sharded_inputs));
+              } else if (use_nd_sharding_ops_) {
+                TF_ASSIGN_OR_RETURN(
+                    sharded_input_info,
+                    CreateOrGetXlaSplitNodeForDistributedArg(
+                        sharding, params_info.NumReplicas(), replica,
+                        orig_arg_num, dtype, PartialTensorShape(), edge->src(),
+                        edge->src_output(), graph,
+                        &input_index_to_sharded_inputs));
+              } else {
+                TF_ASSIGN_OR_RETURN(
+                    sharded_input_info,
+                    CreateOrGetSplitNodesForInputSharding(
+                        sharding, orig_arg_num, dtype, PartialTensorShape(),
+                        replica, edge->src_output(), edge->src(),
+                        control_predecessor, graph,
+                        &input_index_to_sharded_inputs));
+              }
+
               NodeOut split_node_and_index =
                   sharded_input_info.sharded_inputs.at(core);
               // Connect with Split node output.
@@ -3348,16 +3616,28 @@
                     graph));
 
             if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
-              const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
-              // Create or get the Split node.
-              TF_ASSIGN_OR_RETURN(
-                  ShardedInputInfo sharded_input_info,
-                  CreateOrGetSplitNodesForInputSharding(
-                      sharding, orig_arg_num,
-                      arg_shapes[orig_arg_num].handle_type,
-                      arg_shapes[orig_arg_num].handle_shape, replica,
-                      var_data.index, var_data.node, control_predecessor, graph,
-                      &input_index_to_sharded_inputs));
+              ShardedInputInfo sharded_input_info;
+              if (use_nd_sharding_ops_) {
+                TF_ASSIGN_OR_RETURN(
+                    sharded_input_info,
+                    CreateOrGetXlaSplitNodeForVariableArg(
+                        arg_shardings[orig_arg_num], params_info.NumReplicas(),
+                        replica, orig_arg_num,
+                        arg_shapes[orig_arg_num].handle_type,
+                        arg_shapes[orig_arg_num].handle_shape, var_data.node,
+                        var_data.index, graph, &to_be_removed_nodes,
+                        &input_index_to_sharded_inputs));
+              } else {
+                TF_ASSIGN_OR_RETURN(
+                    sharded_input_info,
+                    CreateOrGetSplitNodesForInputSharding(
+                        arg_shardings[orig_arg_num], orig_arg_num,
+                        arg_shapes[orig_arg_num].handle_type,
+                        arg_shapes[orig_arg_num].handle_shape, replica,
+                        var_data.index, var_data.node, control_predecessor,
+                        graph, &input_index_to_sharded_inputs));
+              }
+
               NodeOut split_node_and_index =
                   sharded_input_info.sharded_inputs[core];
               // Connect with Split node output.
@@ -3440,11 +3720,20 @@
                             core_retval_nums[core_id][core_retval_index])});
           }
           DataType dtype = e->src()->output_type(e->src_output());
-          TF_ASSIGN_OR_RETURN(
-              Node * concat_node,
-              CreateConcatNodesForRetval(
-                  sharding, dtype, /*inferred_shape*/ PartialTensorShape(),
-                  replica, orig_inputs, graph, /*device=*/""));
+          Node* concat_node = nullptr;
+          if (use_nd_sharding_ops_) {
+            TF_ASSIGN_OR_RETURN(
+                concat_node, CreateXlaConcatNode(
+                                 sharding, replica, dtype,
+                                 /*partial_tensor_shape=*/PartialTensorShape(),
+                                 orig_inputs, /*device=*/"", graph));
+          } else {
+            TF_ASSIGN_OR_RETURN(
+                concat_node,
+                CreateConcatNodesForRetval(
+                    sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
+                    replica, orig_inputs, graph, /*device=*/""));
+          }
 
           const Edge* edge = replicate_output_edges[output_num];
           Node* dst = edge->dst();
@@ -3522,12 +3811,22 @@
             // be collocated with the variable.
             absl::string_view device =
                 variable_reads[core_variable_writes[i]]->assigned_device_name();
-            TF_ASSIGN_OR_RETURN(
-                Node * concat_node,
-                CreateConcatNodesForRetval(
-                    sharding, arg_shapes[orig_arg_num].handle_type,
-                    arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs,
-                    graph, device));
+            Node* concat_node = nullptr;
+            if (use_nd_sharding_ops_) {
+              TF_ASSIGN_OR_RETURN(
+                  concat_node,
+                  CreateXlaConcatNode(sharding, replica,
+                                      arg_shapes[orig_arg_num].handle_type,
+                                      arg_shapes[orig_arg_num].handle_shape,
+                                      orig_inputs, device, graph));
+            } else {
+              TF_ASSIGN_OR_RETURN(
+                  concat_node,
+                  CreateConcatNodesForRetval(
+                      sharding, arg_shapes[orig_arg_num].handle_type,
+                      arg_shapes[orig_arg_num].handle_shape, replica,
+                      orig_inputs, graph, device));
+            }
             // Populate VariableWrite.
             VariableWrite& write = variable_writes->at(core_variable_writes[i]);
             write.value = concat_node;
@@ -3558,7 +3857,7 @@
     graph->RemoveNode(node);
   }
   return Status::OK();
-}
+}  // NOLINT(readability/fn_size)
 
 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
     int replica_index, const std::vector<Node*>& outside_compilation_nodes,
@@ -4375,7 +4674,7 @@
   // by the user. They will be mapped to physical TPU cores below.
   int num_replicas;
   int num_cores_per_replica;
-  int num_tasks;  // Number of tasks.
+  int num_tasks;
   std::vector<std::vector<string>> tf_device_assignment;
   std::vector<int> devices_to_lock;
   std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
@@ -4708,13 +5007,14 @@
 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
 bool DistributedTPURewritePass::enable_multicore_locking_ = false;
+bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
 
 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
     bool distribute_vars, bool allow_xla_spmd_partition,
     bool replicate_inputs_outputs_by_default_for_xla_spmd,
     bool enable_cross_replica_sharding_mirrored_variables,
     bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
-    bool enable_multicore_locking) {
+    bool enable_multicore_locking, bool use_nd_sharding_ops) {
   distribute_vars_ = distribute_vars;
   allow_xla_spmd_partition_ = allow_xla_spmd_partition;
   replicate_inputs_outputs_by_default_for_xla_spmd_ =
@@ -4724,6 +5024,7 @@
   enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
   enable_xla_param_broadcast_ = enable_xla_param_broadcast;
   enable_multicore_locking_ = enable_multicore_locking;
+  use_nd_sharding_ops_ = use_nd_sharding_ops;
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index 470dd92..bdb1df3 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -111,6 +111,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/container/node_hash_map.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
@@ -133,7 +134,7 @@
       bool replicate_inputs_outputs_by_default_for_xla_spmd,
       bool enable_cross_replica_sharding_mirrored_variables,
       bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
-      bool enable_multicore_locking);
+      bool enable_multicore_locking, bool use_nd_sharding_ops);
 
   Status Run(const GraphOptimizationPassOptions& options) override;
 
@@ -598,6 +599,7 @@
   static bool enable_automatic_model_parallelism_;
   static bool enable_xla_param_broadcast_;
   static bool enable_multicore_locking_;
+  static bool use_nd_sharding_ops_;
 };
 
 }  // namespace tensorflow