Add an option to the bridge pass to use the sharding util ops (XlaSplitND, XlaConcatND) in place of a tree of Split/Concat nodes.
Currently, the TPU bridge creates a tree of Concat ops (potentially with Slice ops) to merge shards. This can incur a temporary memory allocation at each level of the Concat tree. Fusing these operations into a single op removes those temporary allocations, reducing memory overhead.
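As an illustration (shapes here are hypothetical, not taken from this change): splitting a [5, 3] tensor with num_splits = [2, 2] pads it once to [6, 4] (paddings = [1, 1]) and produces four [3, 2] shards from a single XlaSplitND, and XlaConcatND performs the inverse merge in one op, instead of materializing an intermediate tensor at each level of a Split/Concat tree.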
PiperOrigin-RevId: 385003313
Change-Id: I9b6ba961bcdb27137189450d67bd0baf1173ed41
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index ab7d564..ee82f19 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -17,7 +17,9 @@
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
+#include <algorithm>
#include <queue>
+#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
@@ -895,6 +897,214 @@
return sharded_input_info;
}
+// Creates an XLA split node (XlaSplitND, or ReadVariableXlaSplitND for
+// resource inputs) to shard an input, and adds that new node to the graph.
+xla::StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
+ const bool is_resource, const NodeOut& input,
+ const PartialTensorShape& partial_tensor_shape,
+ const std::vector<Node*>& control_inputs,
+ const std::vector<Node*>& control_outputs,
+ const DataType dtype, const int num_shards,
+ const xla::OpSharding& sharding,
+ Graph* graph) {
+ const std::string& input_assigned_device = input.node->assigned_device_name();
+ NodeDef xla_split_def;
+ xla_split_def.set_name(graph->NewName(node_name));
+ xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
+ xla_split_def.set_device(input_assigned_device);
+ AddNodeAttr("T", dtype, &xla_split_def);
+ AddNodeAttr("N", num_shards, &xla_split_def);
+ const std::vector<int64> num_splits(
+ sharding.tile_assignment_dimensions().begin(),
+ sharding.replicate_on_last_tile_dim()
+ ? std::prev(sharding.tile_assignment_dimensions().end())
+ : sharding.tile_assignment_dimensions().end());
+ AddNodeAttr("num_splits", num_splits, &xla_split_def);
+ const int rank = sharding.replicate_on_last_tile_dim()
+ ? sharding.tile_assignment_dimensions_size() - 1
+ : sharding.tile_assignment_dimensions_size();
+ std::vector<int32> paddings;
+ paddings.reserve(rank);
+ for (int dim = 0; dim < rank; ++dim) {
+ paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
+ partial_tensor_shape));
+ }
+ AddNodeAttr("paddings", paddings, &xla_split_def);
+
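+ // For non-resource inputs, prevent constant folding of the split and keep
+ // it colocated with its input.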
+ if (!is_resource) {
+ AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
+ AddNodeAttr(kColocationAttrName,
+ std::vector<string>{
+ absl::StrCat(kColocationGroupPrefix, input.node->name())},
+ &xla_split_def);
+ }
+
+ Status s;
+ Node* xla_split = graph->AddNode(xla_split_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ if (is_resource) {
+ xla_split->set_requested_device(input.node->requested_device());
+ }
+ xla_split->set_assigned_device_name(input_assigned_device);
+ graph->AddEdge(input.node, input.index, xla_split, 0);
+ for (Node* control_input : control_inputs) {
+ graph->AddControlEdge(control_input, xla_split);
+ }
+ for (Node* control_output : control_outputs) {
+ graph->AddControlEdge(xla_split, control_output);
+ }
+ return xla_split;
+}
+
+// Creates a list of sharded inputs, one entry per device in the sharding's
+// tile assignment, by splitting the input with a single XLA split node.
+xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
+ absl::string_view node_name, const bool is_resource, const NodeOut& input,
+ const PartialTensorShape& partial_tensor_shape,
+ const std::vector<Node*>& control_inputs,
+ const std::vector<Node*>& control_outputs, const DataType dtype,
+ const xla::OpSharding& sharding, Graph* graph) {
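+ // When the last tile dimension is replicated, each shard feeds `repeat`
+ // cores, so the split op only needs num_shards distinct outputs.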
+ const int repeat = sharding.replicate_on_last_tile_dim()
+ ? *sharding.tile_assignment_dimensions().rbegin()
+ : 1;
+ const int num_shards = sharding.tile_assignment_devices_size() / repeat;
+
+ TF_ASSIGN_OR_RETURN(
+ Node * xla_split,
+ CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
+ control_inputs, control_outputs, dtype, num_shards,
+ sharding, graph));
+
+ std::vector<NodeOut> sharded_inputs_list(
+ sharding.tile_assignment_devices_size());
+
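+ // Fan out each split output to every core assigned to that shard: with a
+ // replicated last tile dimension, `repeat` consecutive entries of
+ // tile_assignment_devices() share the same split output.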
+ for (int i = 0; i < num_shards; ++i) {
+ for (int j = 0; j < repeat; ++j) {
+ const int index = i * repeat + j;
+ const int core = sharding.tile_assignment_devices(index);
+ sharded_inputs_list[core] = {xla_split, i};
+ }
+ }
+
+ return sharded_inputs_list;
+}
+
+// Creates an XlaSplitND op to shard a per-replica arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
+ const xla::OpSharding& sharding, const int replica_id,
+ const int orig_arg_num, DataType dtype,
+ const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+ const int orig_src_output, Graph* graph,
+ std::map<ShardedInputIndex, ShardedInputInfo>*
+ arg_index_to_sharded_input_map) {
+ ShardedInputIndex input_index{replica_id, orig_arg_num};
+ auto iter = arg_index_to_sharded_input_map->find(input_index);
+ if (iter != arg_index_to_sharded_input_map->end()) {
+ return iter->second;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<NodeOut> sharded_inputs_list,
+ ShardInputWithXlaSplitOp(
+ absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
+ /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
+ partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
+ dtype, sharding, graph));
+
+ ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+ (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
+ return sharded_input_info;
+}
+
+// Creates an XlaSplitND op to shard a distributed arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
+ const xla::OpSharding& sharding, const int num_replicas,
+ const int replica_id, const int orig_arg_num, DataType dtype,
+ const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+ const int orig_src_output, Graph* graph,
+ std::map<ShardedInputIndex, ShardedInputInfo>*
+ arg_index_to_sharded_input_map) {
+ ShardedInputIndex input_index{replica_id, orig_arg_num};
+ auto iter = arg_index_to_sharded_input_map->find(input_index);
+ if (iter != arg_index_to_sharded_input_map->end()) {
+ return iter->second;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<NodeOut> sharded_inputs_list,
+ ShardInputWithXlaSplitOp(
+ absl::StrCat(orig_src->name(), "/distributed_split"),
+ /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
+ partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
+ dtype, sharding, graph));
+
+ ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+ for (int replica = 0; replica < num_replicas; ++replica) {
+ (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
+ sharded_input_info;
+ }
+ return sharded_input_info;
+}
+
+// Creates a ReadVariableXlaSplitND op to shard a variable arg.
+xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
+ const xla::OpSharding& sharding, const int num_replicas,
+ const int replica_id, const int orig_arg_num, DataType dtype,
+ const PartialTensorShape& partial_tensor_shape, Node* orig_src,
+ const int orig_src_output, Graph* graph,
+ std::vector<Node*>* to_be_removed_nodes,
+ std::map<ShardedInputIndex, ShardedInputInfo>*
+ arg_index_to_sharded_input_map) {
+ ShardedInputIndex input_index{replica_id, orig_arg_num};
+ auto iter = arg_index_to_sharded_input_map->find(input_index);
+ if (iter != arg_index_to_sharded_input_map->end()) {
+ return iter->second;
+ }
+
+ DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
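+ // The ReadVariableOp is subsumed by the fused ReadVariableXlaSplitND:
+ // capture its control outputs so they can be transferred, and mark all of
+ // its out edges for removal.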
+ std::vector<Node*> control_outputs;
+ std::vector<const Edge*> edges_to_remove;
+ for (const Edge* edge : orig_src->out_edges()) {
+ if (edge->IsControlEdge()) {
+ control_outputs.push_back(edge->dst());
+ }
+ edges_to_remove.push_back(edge);
+ }
+
+ to_be_removed_nodes->push_back(orig_src);
+
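+ // The variable's resource handle, which fed the ReadVariableOp, becomes
+ // the input of the fused split op.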
+ const Edge* resource = nullptr;
+ TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));
+
+ std::vector<Node*> control_inputs;
+ for (const Edge* edge : orig_src->in_edges()) {
+ if (edge->IsControlEdge()) {
+ control_inputs.push_back(edge->src());
+ }
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<NodeOut> sharded_inputs_list,
+ ShardInputWithXlaSplitOp(
+ absl::StrCat(resource->src()->name(), "/read_variable_split"),
+ /*is_resource=*/true,
+ /*input=*/{resource->src(), resource->src_output()},
+ partial_tensor_shape, control_inputs, control_outputs, dtype,
+ sharding, graph));
+
+ for (const Edge* edge : edges_to_remove) {
+ graph->RemoveControlEdge(edge);
+ }
+
+ DCHECK(orig_src->out_edges().empty());
+
+ ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
+ for (int replica = 0; replica < num_replicas; ++replica) {
+ (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
+ sharded_input_info;
+ }
+ return sharded_input_info;
+}
+
// Creates a concat node to be used for aggregating sharded retvals across
// logical cores.
xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
@@ -1061,6 +1271,45 @@
return inputs_to_sharded_retval.at(0).node;
}
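+// Creates an XlaConcatND node to aggregate sharded retvals across logical
+// cores.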
+xla::StatusOr<Node*> CreateXlaConcatNode(
+ const xla::OpSharding& sharding, const int replica_id, DataType dtype,
+ const PartialTensorShape& partial_tensor_shape,
+ const std::vector<NodeOut>& orig_inputs, absl::string_view device,
+ Graph* graph) {
+ NodeDef xla_concat_def;
+ xla_concat_def.set_name(graph->NewName(
+ absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
+ xla_concat_def.set_op("XlaConcatND");
+ xla_concat_def.set_device(std::string(device));
+ AddNodeAttr("T", dtype, &xla_concat_def);
+ AddNodeAttr("N", static_cast<int64>(orig_inputs.size()), &xla_concat_def);
+ const std::vector<int64> num_concats(
+ sharding.tile_assignment_dimensions().begin(),
+ sharding.replicate_on_last_tile_dim()
+ ? std::prev(sharding.tile_assignment_dimensions().end())
+ : sharding.tile_assignment_dimensions().end());
+ AddNodeAttr("num_concats", num_concats, &xla_concat_def);
+ const int rank = sharding.replicate_on_last_tile_dim()
+ ? sharding.tile_assignment_dimensions_size() - 1
+ : sharding.tile_assignment_dimensions_size();
+ std::vector<int32> paddings;
+ paddings.reserve(rank);
+ for (int dim = 0; dim < rank; ++dim) {
+ paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
+ partial_tensor_shape));
+ }
+ AddNodeAttr("paddings", paddings, &xla_concat_def);
+
+ Status s;
+ Node* xla_concat = graph->AddNode(xla_concat_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
+ const NodeOut& input = orig_inputs[i];
+ graph->AddEdge(input.node, input.index, xla_concat, i);
+ }
+ return xla_concat;
+}
+
// Set the padding ops the same devices as the original inputs. If the original
// inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
// mode will be triggered, so we don't need to copy the data back to the host
@@ -3059,7 +3308,7 @@
replicate_output_edges[edge->src_output()] = edge;
if (num_partitioned_outputs > 1) {
return errors::InvalidArgument(
- "More than one TPUPartitionedOutput per replciated output.");
+ "More than one TPUPartitionedOutput per replicated output.");
}
}
@@ -3213,16 +3462,17 @@
int64_t orig_arg_num = core_arg_nums[core][i];
VLOG(2) << " replica " << replica << " core " << core << " i " << i
<< " orig_arg_num " << orig_arg_num;
- if (params_info.IsPerReplicaArg(orig_arg_num) ||
- params_info.IsDistributedArg(orig_arg_num)) {
+ const bool is_per_replica_arg =
+ params_info.IsPerReplicaArg(orig_arg_num);
+ if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
// Per-replica input and distributed input
- int64_t input_num = params_info.IsPerReplicaArg(orig_arg_num)
- ? replica * params_info.NumPerReplicaArgs() +
- core_arg_nums[core][i]
- : params_info.NumReplicas() *
- params_info.NumPerReplicaArgs() +
- core_arg_nums[core][i] -
- params_info.NumPerReplicaArgs();
+ const int64_t input_num =
+ is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
+ core_arg_nums[core][i]
+ : params_info.NumReplicas() *
+ params_info.NumPerReplicaArgs() +
+ core_arg_nums[core][i] -
+ params_info.NumPerReplicaArgs();
const Edge* edge = replicate_input_edges[input_num];
VLOG(2) << "replicate_input_edges[" << input_num << "]";
@@ -3265,14 +3515,32 @@
}
const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
- // Create or get the Split node.
- TF_ASSIGN_OR_RETURN(
- ShardedInputInfo sharded_input_info,
- CreateOrGetSplitNodesForInputSharding(
- sharding, orig_arg_num, dtype,
- arg_shapes[orig_arg_num].handle_shape, replica,
- edge->src_output(), edge->src(), control_predecessor,
- graph, &input_index_to_sharded_inputs));
+ ShardedInputInfo sharded_input_info;
+ if (use_nd_sharding_ops_ && is_per_replica_arg) {
+ TF_ASSIGN_OR_RETURN(
+ sharded_input_info,
+ CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
+ sharding, replica, orig_arg_num, dtype,
+ PartialTensorShape(), edge->src(), edge->src_output(),
+ graph, &input_index_to_sharded_inputs));
+ } else if (use_nd_sharding_ops_) {
+ TF_ASSIGN_OR_RETURN(
+ sharded_input_info,
+ CreateOrGetXlaSplitNodeForDistributedArg(
+ sharding, params_info.NumReplicas(), replica,
+ orig_arg_num, dtype, PartialTensorShape(), edge->src(),
+ edge->src_output(), graph,
+ &input_index_to_sharded_inputs));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ sharded_input_info,
+ CreateOrGetSplitNodesForInputSharding(
+ sharding, orig_arg_num, dtype, PartialTensorShape(),
+ replica, edge->src_output(), edge->src(),
+ control_predecessor, graph,
+ &input_index_to_sharded_inputs));
+ }
+
NodeOut split_node_and_index =
sharded_input_info.sharded_inputs.at(core);
// Connect with Split node output.
@@ -3348,16 +3616,28 @@
graph));
if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
- const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
- // Create or get the Split node.
- TF_ASSIGN_OR_RETURN(
- ShardedInputInfo sharded_input_info,
- CreateOrGetSplitNodesForInputSharding(
- sharding, orig_arg_num,
- arg_shapes[orig_arg_num].handle_type,
- arg_shapes[orig_arg_num].handle_shape, replica,
- var_data.index, var_data.node, control_predecessor, graph,
- &input_index_to_sharded_inputs));
+ ShardedInputInfo sharded_input_info;
+ if (use_nd_sharding_ops_) {
+ TF_ASSIGN_OR_RETURN(
+ sharded_input_info,
+ CreateOrGetXlaSplitNodeForVariableArg(
+ arg_shardings[orig_arg_num], params_info.NumReplicas(),
+ replica, orig_arg_num,
+ arg_shapes[orig_arg_num].handle_type,
+ arg_shapes[orig_arg_num].handle_shape, var_data.node,
+ var_data.index, graph, &to_be_removed_nodes,
+ &input_index_to_sharded_inputs));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ sharded_input_info,
+ CreateOrGetSplitNodesForInputSharding(
+ arg_shardings[orig_arg_num], orig_arg_num,
+ arg_shapes[orig_arg_num].handle_type,
+ arg_shapes[orig_arg_num].handle_shape, replica,
+ var_data.index, var_data.node, control_predecessor,
+ graph, &input_index_to_sharded_inputs));
+ }
+
NodeOut split_node_and_index =
sharded_input_info.sharded_inputs[core];
// Connect with Split node output.
@@ -3440,11 +3720,20 @@
core_retval_nums[core_id][core_retval_index])});
}
DataType dtype = e->src()->output_type(e->src_output());
- TF_ASSIGN_OR_RETURN(
- Node * concat_node,
- CreateConcatNodesForRetval(
- sharding, dtype, /*inferred_shape*/ PartialTensorShape(),
- replica, orig_inputs, graph, /*device=*/""));
+ Node* concat_node = nullptr;
+ if (use_nd_sharding_ops_) {
+ TF_ASSIGN_OR_RETURN(
+ concat_node, CreateXlaConcatNode(
+ sharding, replica, dtype,
+ /*partial_tensor_shape=*/PartialTensorShape(),
+ orig_inputs, /*device=*/"", graph));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ concat_node,
+ CreateConcatNodesForRetval(
+ sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
+ replica, orig_inputs, graph, /*device=*/""));
+ }
const Edge* edge = replicate_output_edges[output_num];
Node* dst = edge->dst();
@@ -3522,12 +3811,22 @@
// be collocated with the variable.
absl::string_view device =
variable_reads[core_variable_writes[i]]->assigned_device_name();
- TF_ASSIGN_OR_RETURN(
- Node * concat_node,
- CreateConcatNodesForRetval(
- sharding, arg_shapes[orig_arg_num].handle_type,
- arg_shapes[orig_arg_num].handle_shape, replica, orig_inputs,
- graph, device));
+ Node* concat_node = nullptr;
+ if (use_nd_sharding_ops_) {
+ TF_ASSIGN_OR_RETURN(
+ concat_node,
+ CreateXlaConcatNode(sharding, replica,
+ arg_shapes[orig_arg_num].handle_type,
+ arg_shapes[orig_arg_num].handle_shape,
+ orig_inputs, device, graph));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ concat_node,
+ CreateConcatNodesForRetval(
+ sharding, arg_shapes[orig_arg_num].handle_type,
+ arg_shapes[orig_arg_num].handle_shape, replica,
+ orig_inputs, graph, device));
+ }
// Populate VariableWrite.
VariableWrite& write = variable_writes->at(core_variable_writes[i]);
write.value = concat_node;
@@ -3558,7 +3857,7 @@
graph->RemoveNode(node);
}
return Status::OK();
-}
+} // NOLINT(readability/fn_size)
/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
int replica_index, const std::vector<Node*>& outside_compilation_nodes,
@@ -4375,7 +4674,7 @@
// by the user. They will be mapped to physical TPU cores below.
int num_replicas;
int num_cores_per_replica;
- int num_tasks; // Number of tasks.
+ int num_tasks;
std::vector<std::vector<string>> tf_device_assignment;
std::vector<int> devices_to_lock;
std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
@@ -4708,13 +5007,14 @@
bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
bool DistributedTPURewritePass::enable_multicore_locking_ = false;
+bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
bool distribute_vars, bool allow_xla_spmd_partition,
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
- bool enable_multicore_locking) {
+ bool enable_multicore_locking, bool use_nd_sharding_ops) {
distribute_vars_ = distribute_vars;
allow_xla_spmd_partition_ = allow_xla_spmd_partition;
replicate_inputs_outputs_by_default_for_xla_spmd_ =
@@ -4724,6 +5024,7 @@
enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
enable_xla_param_broadcast_ = enable_xla_param_broadcast;
enable_multicore_locking_ = enable_multicore_locking;
+ use_nd_sharding_ops_ = use_nd_sharding_ops;
}
} // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
index 470dd92..bdb1df3 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -111,6 +111,7 @@
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
@@ -133,7 +134,7 @@
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
- bool enable_multicore_locking);
+ bool enable_multicore_locking, bool use_nd_sharding_ops);
Status Run(const GraphOptimizationPassOptions& options) override;
@@ -598,6 +599,7 @@
static bool enable_automatic_model_parallelism_;
static bool enable_xla_param_broadcast_;
static bool enable_multicore_locking_;
+ static bool use_nd_sharding_ops_;
};
} // namespace tensorflow