blob: 81b51d3132319868139226e2c029eed252ce4cb7 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Compilation for distributed TPU (TPU_REPLICATED_CORE devices).
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/escaping.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
namespace {
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology.
constexpr int kTPUTopologyRank = 4;
// An upper bound on how many cores may be present in the topology.
static constexpr int kTPUMaxTopologySize = 4096;
// Attribute containing the serialized xla::OpSharding to be passed to the
// corresponding XLA HLO operation, which represents how a shape is distributed
// across logical cores, e.g., replication, single-device, or partitioning.
const char kShardingAttribute[] = "_XlaSharding";
const char kTPUPartitionedInput[] = "TPUPartitionedInput";
const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";
const char kVarHandleOp[] = "VarHandleOp";
static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";
using NodeAndId = std::pair<const Node*, int>;
struct NodeAndPort {
explicit NodeAndPort(Node* node, int port) : node(node), port(port) {}
Node* node;
// Port of the node, e.g. this can be the `src_output` index of an Edge.
int port;
};
class IntrusiveHeapLink {
public:
using size_type = size_t;
static constexpr size_type kNotMember = -1;
IntrusiveHeapLink() = default;
// Only IntrusiveHeap and LinkAccess objects should make these objects.
explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}
// Only IntrusiveHeap and LinkAccess should get the value.
size_type get() const { return pos_; }
private:
size_type pos_{kNotMember};
};
template <typename T, IntrusiveHeapLink T::*M>
struct IntrusiveHeapDataMemberLinkAccess {
IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
};
template <typename T>
struct DefaultIntrusiveHeapLinkAccess {
IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
};
template <typename T, typename PtrCompare,
typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
public:
typedef typename IntrusiveHeapLink::size_type size_type;
typedef T value_type;
typedef T* pointer;
typedef const T* const_pointer;
typedef PtrCompare pointer_compare_type;
typedef LinkAccess link_access_type;
typedef Alloc allocator_type;
explicit IntrusiveHeap(
const pointer_compare_type& comp = pointer_compare_type(),
const link_access_type& link_access = link_access_type(),
const allocator_type& alloc = allocator_type())
: rep_(comp, link_access, alloc) {}
size_type size() const { return heap().size(); }
bool empty() const { return heap().empty(); }
// Return the top element, but don't remove it.
pointer top() const {
DCHECK(!empty());
return heap()[0];
}
// Remove the top() pointer from the heap and return it.
pointer Pop() {
pointer t = top();
Remove(t);
return t;
}
// Insert 't' into the heap.
void Push(pointer t) {
SetPositionOf(t, heap().size());
heap().push_back(t);
FixHeapUp(t);
}
// Adjust the heap to accommodate changes in '*t'.
void Adjust(pointer t) {
DCHECK(Contains(t));
size_type h = GetPositionOf(t);
if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
FixHeapUp(t);
} else {
FixHeapDown(t);
}
}
// Remove the specified pointer from the heap.
void Remove(pointer t) {
DCHECK(Contains(t));
size_type h = GetPositionOf(t);
SetPositionOf(t, IntrusiveHeapLink::kNotMember);
if (h == heap().size() - 1) {
// Fast path for removing from back of heap.
heap().pop_back();
return;
}
// Move the element from the back of the heap to overwrite 't'.
pointer& elem = heap()[h];
elem = heap().back();
SetPositionOf(elem, h); // Element has moved, so update its link.
heap().pop_back();
Adjust(elem); // Restore the heap invariant.
}
void Clear() { heap().clear(); }
bool Contains(const_pointer t) const {
size_type h = GetPositionOf(t);
return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
heap()[h] == t;
}
void reserve(size_type n) { heap().reserve(n); }
size_type capacity() const { return heap().capacity(); }
allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }
private:
typedef std::vector<pointer, allocator_type> heap_type;
// Empty base class optimization for pointer_compare and link_access.
// The heap_ data member retains a copy of the allocator, so it is not
// stored explicitly.
struct Rep : pointer_compare_type, link_access_type {
explicit Rep(const pointer_compare_type& cmp,
const link_access_type& link_access,
const allocator_type& alloc)
: pointer_compare_type(cmp),
link_access_type(link_access),
heap_(alloc) {}
heap_type heap_; // NOLINT
};
const pointer_compare_type& compare() const { return rep_; }
const link_access_type& link_access() const { return rep_; }
const heap_type& heap() const { return rep_.heap_; }
heap_type& heap() { return rep_.heap_; }
size_type GetPositionOf(const_pointer t) const {
return link_access().Get(t).get();
}
void SetPositionOf(pointer t, size_type pos) const {
return link_access().Set(t, IntrusiveHeapLink(pos));
}
void FixHeapUp(pointer t) {
size_type h = GetPositionOf(t);
while (h != 0) {
size_type parent = (h - 1) >> 1;
if (compare()(heap()[parent], t)) {
break;
}
heap()[h] = heap()[parent];
SetPositionOf(heap()[h], h);
h = parent;
}
heap()[h] = t;
SetPositionOf(t, h);
}
void FixHeapDown(pointer t) {
size_type h = GetPositionOf(t);
for (;;) {
size_type kid = (h << 1) + 1;
if (kid >= heap().size()) {
break;
}
if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
++kid;
}
if (compare()(t, heap()[kid])) {
break;
}
heap()[h] = heap()[kid];
SetPositionOf(heap()[h], h);
h = kid;
}
heap()[h] = t;
SetPositionOf(t, h);
}
Rep rep_;
};
string CoreDeviceLabel(int core) {
return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}
// Creates a unique node name with a particular prefix.
string UniqueNodeName(const StringPiece prefix, Graph* graph) {
return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
}
Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
const string& target_device_type,
Node* node) {
TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
TF_RET_CHECK(device.has_id);
TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));
// Store the device instance as an attr on the Node.
TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));
// Place the execute Op on the TPU_SYSTEM device so it can access the cache of
// compiled protos in the resource manager.
device.type = target_device_type;
device.id = 0;
node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
return Status::OK();
}
// Iterate over the nodes in the original graph and find all the TPUReplicate
// nodes, and all the nodes that are part of outside_compilation clusters.
Status FindTaggedNodes(
Graph* graph, std::vector<Node*>* replicate_nodes,
std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
outside_compilation_nodes,
std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
for (Node* node : graph->op_nodes()) {
if (node->type_string() == "_TPUReplicate") {
replicate_nodes->push_back(node);
const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
if (cluster_attr == nullptr) {
return errors::Internal("TPUReplicate node ", node->name(), " has no ",
kTPUReplicateAttr, " attr.");
} else {
const string& cluster = cluster_attr->s();
if (cluster.empty()) {
return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
node->name(), " has no string value.");
}
if (outside_compilation_nodes->find(cluster) !=
outside_compilation_nodes->end()) {
return errors::Internal(
"TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
" attr value '", cluster,
"' which is a duplicate of another TPUReplicate node in the "
"graph.");
}
(*outside_compilation_nodes)[cluster] =
DistributedTPURewritePass::OutsideCompilationNodeMap();
(*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
}
}
}
for (Node* node : graph->op_nodes()) {
if (node->type_string() != "_TPUReplicate") {
const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
const AttrValue* outside_compilation_attr =
node->attrs().Find(kOutsideCompilationAttr);
if (cluster_attr == nullptr) {
if (outside_compilation_attr != nullptr) {
return errors::Internal("Node ", node->name(), " has ",
kOutsideCompilationAttr, " attr but no ",
kTPUReplicateAttr, " attr.");
}
} else {
const string& cluster = cluster_attr->s();
if (cluster.empty()) {
return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
node->name(), " has no string value.");
}
const auto iter = outside_compilation_nodes->find(cluster);
if (iter == outside_compilation_nodes->end()) {
return errors::Internal(
"Attr ", kTPUReplicateAttr, " on node ", node->name(),
" does not correspond to a TPUReplicate node.");
}
if (outside_compilation_attr == nullptr) {
return errors::Internal("Node ", node->name(), " has ",
kTPUReplicateAttr, " attr but no ",
kOutsideCompilationAttr, " attr.");
}
const string& oc_cluster = outside_compilation_attr->s();
if (oc_cluster.empty()) {
return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
node->name(), " has no string value.");
}
// Outside compilation cluster at head and tail of TPU computation has
// already been moved to host and is already replicated. As so, do not
// replicate outside compilation nodes with replica id attribute.
int replica_id;
if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
const AttrValue* head_attr =
node->attrs().Find("_xla_only_arg_or_oc_input");
const AttrValue* tail_attr =
node->attrs().Find("_xla_only_ret_or_oc_output");
if (((head_attr != nullptr) && (head_attr->b())) ||
((tail_attr != nullptr) && (tail_attr->b()))) {
// This is safe as this has the same keys as
// outside_compilation_nodes which we already know has this key.
(*head_tail_outside_compilation_nodes)[cluster].push_back(node);
}
continue;
}
iter->second[oc_cluster].push_back(node);
}
}
}
return Status::OK();
}
// Helper class to spread TPU computation arguments and return values
// across cores.
// If all shapes are fully defined, balance by their size.
// If some of them are not fully defined, the undefined shapes size will
// be estimated with the average size of the fully defined ones.
// If none are defined, fall back to round-robin.
class TensorDevicePlacer {
public:
// Creates a TensorDevicePlacer object to distribute arguments or
// return values to a set of num_devices devices, where the types and
// the inferred shapes of the inputs (arguments or return values) are
// passed in types and shapes.
TensorDevicePlacer(int64_t num_devices, const DataTypeVector& types,
const std::vector<InferredShape>& shapes)
: index_nodes_(num_devices), sizes_(types.size()) {
int64_t total_size = 0;
int64_t num_defined = 0;
for (int64_t i = 0; i < types.size(); ++i) {
sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
if (sizes_[i] >= 0) {
total_size += sizes_[i];
++num_defined;
}
}
// If a shape is undefined, select a size for it which is the average
// of the defined shapes. If no shapes are defined, assign 1 so that we
// get round-robin behavior.
int64_t undefined_shape_size =
(num_defined > 0) ? total_size / num_defined : 1;
for (int64_t i = 0; i < sizes_.size(); ++i) {
if (sizes_[i] < 0) {
sizes_[i] = undefined_shape_size;
}
}
for (int64_t i = 0; i < num_devices; ++i) {
heap_.Push(&index_nodes_[i]);
}
}
// Reports that the argument/return-value at index has been assigned
// by the user to a given device.
void ReportDeviceAssigned(int64_t device, int64_t index) {
if (device >= index_nodes_.size()) {
LOG(DFATAL) << "Sharding assignment is out of bounds. Check that the "
"number of nodes is properly set.";
}
DeviceNode* node = &index_nodes_.at(device);
node->size += sizes_.at(index);
heap_.Adjust(node);
}
// Retrieves the device at which the argument/return-value at index
// should be assigned to.
int64 RetrieveAssignment(int64_t index) {
DeviceNode* node = heap_.top();
int64_t device = node - index_nodes_.data();
node->size += sizes_.at(index);
heap_.Adjust(node);
return device;
}
private:
struct DeviceNode {
struct Compare {
// Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
// infrastructure.
bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
return lhs->size < rhs->size;
}
};
IntrusiveHeapLink heap;
int64 size = 0;
};
static int64 GetInferredShapeSize(const InferredShape& ishape,
DataType dtype) {
return ishape.shape.IsFullyDefined()
? ishape.shape.num_elements() * DataTypeSize(dtype)
: -1;
}
std::vector<DeviceNode> index_nodes_;
IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
std::vector<int64> sizes_;
};
Status ValidateCoreNumber(int64_t core, int64_t num_cores_per_replica) {
if (core < 0 || core >= num_cores_per_replica) {
return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
". The valid core IDs are [0..",
num_cores_per_replica, ")");
}
return Status::OK();
}
Status FindHostComputeKeyPlaceholderNodes(
const Graph* graph, const std::vector<Node*>& replicate_nodes,
std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
host_compute_key_placeholder_map->clear();
for (const auto node : replicate_nodes) {
(*host_compute_key_placeholder_map)[node->name()] = nullptr;
}
for (Node* node : graph->op_nodes()) {
if (node->type_string() == "Placeholder" &&
str_util::EndsWith(node->name(), "_key_placeholder")) {
const AttrValue* call_node_attr =
node->attrs().Find("_host_compute_call_node");
if (call_node_attr != nullptr) {
auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
if (iter == host_compute_key_placeholder_map->end()) {
return errors::InvalidArgument(
"Node ", node->name(), " has _host_compute_call_node attribute '",
call_node_attr->s(), "' that doesn't correspond to a call node");
}
if (iter->second != nullptr) {
return errors::InvalidArgument(
"Key placeholder node ", iter->second->name(), " for call node ",
call_node_attr->s(), " previously found as ",
iter->second->name());
}
iter->second = node;
}
}
}
return Status::OK();
}
Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
Node* old_node = *node;
// We want to replace the node with an identity node with the same name.
const string& node_name = old_node->name();
// Create identity node.
TF_ASSIGN_OR_RETURN(
Node * id_node,
BuildIdentityNode(graph, node_name, DT_STRING,
/*input=*/nullptr, /*requested_device=*/""));
// No incoming edges are copied as a new one will be added from compile node
// to id_node.
// Copy outgoing edges to the id node.
std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
old_node->out_edges().end());
for (const Edge* edge : out_edges) {
Node* dst = edge->dst();
int src_output = edge->src_output();
int dst_input = edge->dst_input();
if (src_output == Graph::kControlSlot) {
graph->AddControlEdge(id_node, dst);
} else {
graph->AddEdge(id_node, src_output, dst, dst_input);
}
graph->RemoveEdge(edge);
}
graph->RemoveNode(old_node);
*node = id_node;
return Status::OK();
}
Status GetStepMarkerLocation(const Node& replicate_node,
xla::DebugOptions::StepMarkerLocation* location) {
string step_marker_location_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
&step_marker_location_attr));
if (step_marker_location_attr.empty()) {
*location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
} else {
if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
location)) {
return errors::InvalidArgument("Malformed step_marker_location: ",
step_marker_location_attr);
}
}
return Status::OK();
}
// Extracts a map of dimension and number of splits for tiled input from xla
// sharding attribute.
Status GetDimensionIndicesAndNumSplitsFromSharding(
const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size();
if (sharding.replicate_on_last_tile_dim()) {
tensor_tile_rank--;
}
for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
if (sharding.tile_assignment_dimensions(dim_index) > 1) {
split_dimension_map->emplace(
dim_index, sharding.tile_assignment_dimensions(dim_index));
}
}
if (split_dimension_map->empty()) {
return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
sharding.DebugString());
}
return Status::OK();
}
// Updates contents of the function with `function_name` in function library
// definition `flib_def` to `new_graph`. This is required when graph
// transformation happens inside a function call body.
Status UpdateFunctionLibDefinition(const Graph& new_graph,
const std::string& function_name,
FunctionLibraryDefinition* flib_def) {
FunctionDef graph_fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
return Status::OK();
}
struct NodeOut {
Node* node;
int index;
};
struct ShardedInputIndex {
int replica_id;
int argument_index;
bool operator<(const ShardedInputIndex& rhs) const {
return std::tie(replica_id, argument_index) <
std::tie(rhs.replica_id, rhs.argument_index);
}
};
struct ShardedInputInfo {
// Split node that would be connected to tiled input Node.
Node* split_node;
// List of splits nodes and output index of the split node from which sharded
// input will be connected to the TPUExecute node. The inputs are ordered by
// logical core ids.
std::vector<NodeOut> sharded_inputs;
};
// Adds pad node after split node to graph for uneven sharding tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
const int split_dim, DataType dtype,
Node* control_predecessor, Node* split_node,
const int split_index, Graph* graph) {
// Add paddings node.
Status s;
NodeDef paddings_def;
paddings_def.set_name(
graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
paddings_def.set_op("Const");
AddNodeAttr("dtype", DT_INT32, &paddings_def);
paddings_def.set_device(split_node->assigned_device_name());
TensorProto sizes_tensor_proto;
sizes_tensor_proto.set_dtype(DT_INT32);
for (int i = 0; i < num_dims; ++i) {
sizes_tensor_proto.add_int_val(0);
if (i == split_dim) {
sizes_tensor_proto.add_int_val(padding);
} else {
sizes_tensor_proto.add_int_val(0);
}
}
TensorShape sizes_shape({num_dims, 2});
sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
Node* paddings_node = graph->AddNode(paddings_def, &s);
TF_RETURN_IF_ERROR(s);
// Add Pad node.
NodeDef pad_def;
pad_def.set_name(graph->NewName(
absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
pad_def.set_op("Pad");
pad_def.set_device(split_node->assigned_device_name());
AddNodeAttr("T", dtype, &pad_def);
AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
Node* pad_node = graph->AddNode(pad_def, &s);
pad_node->set_assigned_device_name(split_node->assigned_device_name());
TF_RETURN_IF_ERROR(s);
// Add edges for pad node.
graph->AddEdge(split_node, split_index, pad_node, 0);
graph->AddEdge(paddings_node, 0, pad_node, 1);
graph->AddControlEdge(control_predecessor, pad_node);
return pad_node;
}
// Adds split node and split dimension node to graph for sharding tiled inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
const int num_dims, const int64 padding,
const int orig_src_output, DataType dtype,
absl::string_view name_prefix,
Node* control_predecessor, Node* orig_src,
Graph* graph) {
const std::string input_assigned_device = orig_src->assigned_device_name();
Node* to_split_node = orig_src;
int to_split_index = orig_src_output;
if (padding > 0) {
TF_ASSIGN_OR_RETURN(
Node * pad_node,
CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
orig_src, orig_src_output, graph));
to_split_node = pad_node;
to_split_index = 0;
}
// Add a split dimension node.
NodeDef split_dim_def;
split_dim_def.set_name(
graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
split_dim_def.set_op("Const");
split_dim_def.set_device(input_assigned_device);
AddNodeAttr("dtype", DT_INT32, &split_dim_def);
TensorProto tensor_proto;
tensor_proto.set_dtype(DT_INT32);
tensor_proto.add_int_val(dim);
TensorShape shape({});
shape.AsProto(tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", tensor_proto, &split_dim_def);
Status s;
Node* split_dim_node = graph->AddNode(split_dim_def, &s);
TF_RETURN_IF_ERROR(s);
// Add a split node.
NodeDef split_def;
split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
split_def.set_op("Split");
split_def.set_device(input_assigned_device);
AddNodeAttr("num_split", num_splits, &split_def);
AddNodeAttr("T", dtype, &split_def);
split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
Node* split_node = graph->AddNode(split_def, &s);
TF_RETURN_IF_ERROR(s);
split_node->set_assigned_device_name(input_assigned_device);
// If colocate the newly created split op to source node of input to TPU
// computation.
split_node->AddAttr(kColocationAttrName,
std::vector<string>{absl::StrCat(kColocationGroupPrefix,
orig_src->name())});
graph->AddEdge(split_dim_node, 0, split_node, 0);
graph->AddEdge(to_split_node, to_split_index, split_node, 1);
// Add a control dependency from `control_predecessor` to newly created
// constant node. This ensures that newly added split/split dim
// nodes are placed inside correct while loop frames when TPUExecute
// node is inside a host training loop.
graph->AddControlEdge(control_predecessor, split_dim_node);
return split_node;
}
int64 GetPadding(const int split_dim, const int num_splits,
const PartialTensorShape& partial_tensor_shape) {
// If dim dimension is not defined, no uneven sharding support.
if (partial_tensor_shape.dim_size(split_dim) <= 0) {
return 0;
}
int64_t per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
partial_tensor_shape.dim_size(split_dim), num_splits);
int64_t total_padding =
per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
return total_padding;
}
// Creates a set of splits nodes that shards tiled input node in graph.
xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
const PartialTensorShape& partial_tensor_shape, int replica_id,
int orig_src_output, Node* orig_src, Node* control_predecessor,
Graph* graph,
std::map<ShardedInputIndex, ShardedInputInfo>*
arg_index_to_sharded_input_map) {
ShardedInputIndex input_index{replica_id, orig_arg_num};
auto iter = arg_index_to_sharded_input_map->find(input_index);
if (iter != arg_index_to_sharded_input_map->end()) {
return iter->second;
}
// Maps input dimension and number of splits with which the
// dimension sharded.
std::map<int, int> split_dimension_map;
TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
sharding, &split_dimension_map));
TF_RET_CHECK(!split_dimension_map.empty())
<< "Unnecessary sharding attribute found.";
// For v1 while loop, nodes inside the loop body must either
// 1) Have data edges from while loop input node.
// or
// 2) Have direct control dependency from while loop input control
// node.
//
// As so, if we are adding Split node inside, while loop body,
// we must manually add a control dependency to a node inside
// a while loop (i.e. `control_predecessor`) to constant nodes
// without data in-edges to make sure that added split nodes
// have correct frame name. Else, placer will complain when
// `BuildControlFlow()` is invoked.
auto sharding_it = split_dimension_map.begin();
std::queue<Node*> split_nodes_for_dimension;
absl::flat_hash_map<Node*, int> node_to_split_dim;
int split_dimension = sharding_it->first;
int num_split = sharding_it->second;
// Creates a tree of split nodes for sharding tiled inputs. Splits nodes
// are created such that input data is sharded in row major order.
// Split nodes at ith depth from the original input node represent nodes
// that split the input data at ith dimension.
TF_ASSIGN_OR_RETURN(
Node * root_split_node,
CreateSplitNode(
num_split, split_dimension, partial_tensor_shape.dims(),
GetPadding(split_dimension, num_split, partial_tensor_shape),
orig_src_output, dtype,
absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
split_dimension),
control_predecessor, orig_src, graph));
sharding_it++;
split_nodes_for_dimension.emplace(root_split_node);
node_to_split_dim[root_split_node] = split_dimension;
while (sharding_it != split_dimension_map.end()) {
split_dimension = sharding_it->first;
num_split = sharding_it->second;
int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
Node* input_split_node = split_nodes_for_dimension.front();
split_nodes_for_dimension.pop();
for (int src_output_index = 0;
src_output_index < input_split_node->num_outputs();
++src_output_index) {
TF_ASSIGN_OR_RETURN(
Node * split_node,
CreateSplitNode(
num_split, split_dimension, partial_tensor_shape.dims(),
GetPadding(split_dimension, num_split, partial_tensor_shape),
src_output_index, dtype,
absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
split_dimension),
control_predecessor, input_split_node, graph));
split_nodes_for_dimension.emplace(split_node);
node_to_split_dim[split_node] = split_dimension;
}
}
sharding_it++;
}
// `split_nodes_for_dimension` now includes final split nodes
// from which sharded data will be fed into TPUExcute nodes -- sorted by
// row major order.
std::vector<NodeOut> sharded_inputs_list(
sharding.tile_assignment_devices_size());
int64_t next_core_tile_index = 0;
while (!split_nodes_for_dimension.empty()) {
Node* split_node = split_nodes_for_dimension.front();
split_nodes_for_dimension.pop();
int num_splits;
TF_RETURN_IF_ERROR(
GetNodeAttr(split_node->def(), "num_split", &num_splits));
for (int out_index = 0; out_index < num_splits; ++out_index) {
int64_t repeat_count =
sharding.replicate_on_last_tile_dim()
? *sharding.tile_assignment_dimensions().rbegin()
: 1;
for (int64_t i = 0; i < repeat_count; ++i) {
int64_t next_core =
sharding.tile_assignment_devices(next_core_tile_index++);
sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
}
}
}
ShardedInputInfo sharded_input_info{root_split_node,
std::move(sharded_inputs_list)};
(*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
return sharded_input_info;
}
// Creates a xla split node to shard an input, and adds that new node to a
// Graph.
StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
const bool is_resource, const NodeOut& input,
const PartialTensorShape& partial_tensor_shape,
const std::vector<Node*>& control_inputs,
const std::vector<Node*>& control_outputs,
const DataType dtype, const int num_shards,
const xla::OpSharding& sharding,
Graph* graph) {
const std::string& input_assigned_device = input.node->assigned_device_name();
NodeDef xla_split_def;
xla_split_def.set_name(graph->NewName(node_name));
xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
xla_split_def.set_device(input_assigned_device);
AddNodeAttr("T", dtype, &xla_split_def);
AddNodeAttr("N", num_shards, &xla_split_def);
const std::vector<int64> num_splits(
sharding.tile_assignment_dimensions().begin(),
sharding.replicate_on_last_tile_dim()
? std::prev(sharding.tile_assignment_dimensions().end())
: sharding.tile_assignment_dimensions().end());
AddNodeAttr("num_splits", num_splits, &xla_split_def);
const int rank = sharding.replicate_on_last_tile_dim()
? sharding.tile_assignment_dimensions_size() - 1
: sharding.tile_assignment_dimensions_size();
std::vector<int32> paddings;
paddings.reserve(rank);
for (int dim = 0; dim < rank; ++dim) {
paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
partial_tensor_shape));
}
AddNodeAttr("paddings", paddings, &xla_split_def);
if (!is_resource) {
AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
AddNodeAttr(kColocationAttrName,
std::vector<string>{
absl::StrCat(kColocationGroupPrefix, input.node->name())},
&xla_split_def);
}
Status s;
Node* xla_split = graph->AddNode(xla_split_def, &s);
TF_RETURN_IF_ERROR(s);
if (is_resource) {
xla_split->set_requested_device(input.node->requested_device());
}
xla_split->set_assigned_device_name(input_assigned_device);
graph->AddEdge(input.node, input.index, xla_split, 0);
for (Node* control_input : control_inputs) {
graph->AddControlEdge(control_input, xla_split);
}
for (Node* control_output : control_outputs) {
graph->AddControlEdge(xla_split, control_output);
}
return xla_split;
}
// Creates a sharded tensor list for all input shards of an input with sharding.
xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
absl::string_view node_name, const bool is_resource, const NodeOut& input,
const PartialTensorShape& partial_tensor_shape,
const std::vector<Node*>& control_inputs,
const std::vector<Node*>& control_outputs, const DataType dtype,
const xla::OpSharding& sharding, Graph* graph) {
const int repeat = sharding.replicate_on_last_tile_dim()
? *sharding.tile_assignment_dimensions().rbegin()
: 1;
const int num_shards = sharding.tile_assignment_devices_size() / repeat;
TF_ASSIGN_OR_RETURN(
Node * xla_split,
CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
control_inputs, control_outputs, dtype, num_shards,
sharding, graph));
std::vector<NodeOut> sharded_inputs_list(
sharding.tile_assignment_devices_size());
for (int i = 0; i < num_shards; ++i) {
for (int j = 0; j < repeat; ++j) {
const int index = i * repeat + j;
const int core = sharding.tile_assignment_devices(index);
sharded_inputs_list[core] = {xla_split, i};
}
}
return sharded_inputs_list;
}
// Creates an XlaSplitND op to shard a per-replica arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
const xla::OpSharding& sharding, const int replica_id,
const int orig_arg_num, DataType dtype,
const PartialTensorShape& partial_tensor_shape, Node* orig_src,
const int orig_src_output, Graph* graph,
std::map<ShardedInputIndex, ShardedInputInfo>*
arg_index_to_sharded_input_map) {
ShardedInputIndex input_index{replica_id, orig_arg_num};
auto iter = arg_index_to_sharded_input_map->find(input_index);
if (iter != arg_index_to_sharded_input_map->end()) {
return iter->second;
}
TF_ASSIGN_OR_RETURN(
std::vector<NodeOut> sharded_inputs_list,
ShardInputWithXlaSplitOp(
absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
/*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
dtype, sharding, graph));
ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
(*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
return sharded_input_info;
}
// Creates an XlaSplitND op to shard a distributed arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
const xla::OpSharding& sharding, const int num_replicas,
const int replica_id, const int orig_arg_num, DataType dtype,
const PartialTensorShape& partial_tensor_shape, Node* orig_src,
const int orig_src_output, Graph* graph,
std::map<ShardedInputIndex, ShardedInputInfo>*
arg_index_to_sharded_input_map) {
ShardedInputIndex input_index{replica_id, orig_arg_num};
auto iter = arg_index_to_sharded_input_map->find(input_index);
if (iter != arg_index_to_sharded_input_map->end()) {
return iter->second;
}
TF_ASSIGN_OR_RETURN(
std::vector<NodeOut> sharded_inputs_list,
ShardInputWithXlaSplitOp(
absl::StrCat(orig_src->name(), "/distributed_split"),
/*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
dtype, sharding, graph));
ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
for (int replica = 0; replica < num_replicas; ++replica) {
(*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
sharded_input_info;
}
return sharded_input_info;
}
// Creates an ReadVariableXlaSplitND op to shard a variable arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
const xla::OpSharding& sharding, const int num_replicas,
const int replica_id, const int orig_arg_num, DataType dtype,
const PartialTensorShape& partial_tensor_shape, Node* orig_src,
const int orig_src_output, Graph* graph,
std::vector<Node*>* to_be_removed_nodes,
std::map<ShardedInputIndex, ShardedInputInfo>*
arg_index_to_sharded_input_map) {
ShardedInputIndex input_index{replica_id, orig_arg_num};
auto iter = arg_index_to_sharded_input_map->find(input_index);
if (iter != arg_index_to_sharded_input_map->end()) {
return iter->second;
}
DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
std::vector<Node*> control_outputs;
std::vector<const Edge*> edges_to_remove;
for (const Edge* edge : orig_src->out_edges()) {
if (edge->IsControlEdge()) {
control_outputs.push_back(edge->dst());
}
edges_to_remove.push_back(edge);
}
to_be_removed_nodes->push_back(orig_src);
const Edge* resource = nullptr;
TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));
std::vector<Node*> control_inputs;
for (const Edge* edge : orig_src->in_edges()) {
if (edge->IsControlEdge()) {
control_inputs.push_back(edge->src());
}
}
TF_ASSIGN_OR_RETURN(
std::vector<NodeOut> sharded_inputs_list,
ShardInputWithXlaSplitOp(
absl::StrCat(resource->src()->name(), "/read_variable_split"),
/*is_resource=*/true,
/*input=*/{resource->src(), resource->src_output()},
partial_tensor_shape, control_inputs, control_outputs, dtype,
sharding, graph));
for (const Edge* edge : edges_to_remove) {
graph->RemoveControlEdge(edge);
}
DCHECK(orig_src->out_edges().empty());
ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
for (int replica = 0; replica < num_replicas; ++replica) {
(*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
sharded_input_info;
}
return sharded_input_info;
}
// Creates a concat node to be used for aggregating sharded retvals across
// logical cores.
xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
absl::string_view name_prefix,
const std::vector<NodeOut>& inputs,
Graph* graph, absl::string_view device) {
// Add a Concat dim node.
NodeDef concat_dim_def;
concat_dim_def.set_name(
graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
concat_dim_def.set_op("Const");
AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
concat_dim_def.set_device(std::string(device));
TensorProto tensor_proto;
tensor_proto.set_dtype(DT_INT32);
tensor_proto.add_int_val(dim);
TensorShape shape({});
shape.AsProto(tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", tensor_proto, &concat_dim_def);
Status s;
Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
TF_RETURN_IF_ERROR(s);
// Add a Concat node.
NodeDef concat_def;
concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
concat_def.set_op("Concat");
AddNodeAttr("N", num_splits, &concat_def);
AddNodeAttr("T", dtype, &concat_def);
concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
concat_def.set_device(std::string(device));
for (const auto& i : inputs) {
concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
}
Node* concat_node = graph->AddNode(concat_def, &s);
TF_RETURN_IF_ERROR(s);
graph->AddEdge(concat_dim_node, 0, concat_node, 0);
// 0th input to concat node is a concat dim node. So we start from 1st input
// and add all input edges.
int dst_input = 1;
for (const auto& i : inputs) {
graph->AddEdge(i.node, i.index, concat_node, dst_input);
++dst_input;
}
return concat_node;
}
// Adds slice node after concat node to graph for uneven sharding tiled inputs.
xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
const PartialTensorShape& shape,
Node* concat_node,
const int concat_out_index, Graph* graph,
absl::string_view device) {
Status s;
// Add begin node for concat.
NodeDef begin_def;
begin_def.set_name(
graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
begin_def.set_op("Const");
AddNodeAttr("dtype", DT_INT32, &begin_def);
begin_def.set_device(std::string(device));
TensorProto begin_tensor_proto;
begin_tensor_proto.set_dtype(DT_INT32);
for (int i = 0; i < shape.dims(); ++i) {
begin_tensor_proto.add_int_val(0);
}
TensorShape begin_shape({shape.dims()});
begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", begin_tensor_proto, &begin_def);
Node* begin_node = graph->AddNode(begin_def, &s);
TF_RETURN_IF_ERROR(s);
// Add size node.
NodeDef size_def;
size_def.set_name(
graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
size_def.set_op("Const");
AddNodeAttr("dtype", DT_INT32, &size_def);
size_def.set_device(std::string(device));
TensorProto sizes_tensor_proto;
sizes_tensor_proto.set_dtype(DT_INT32);
for (int i = 0; i < shape.dims(); ++i) {
sizes_tensor_proto.add_int_val(shape.dim_size(i));
}
TensorShape sizes_shape({shape.dims()});
sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", sizes_tensor_proto, &size_def);
Node* size_node = graph->AddNode(size_def, &s);
TF_RETURN_IF_ERROR(s);
// Add Slice node.
NodeDef slice_def;
slice_def.set_name(
graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
slice_def.set_op("Slice");
slice_def.set_device(std::string(device));
AddNodeAttr("T", dtype, &slice_def);
AddNodeAttr("Index", DT_INT32, &slice_def);
slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
Node* slice_node = graph->AddNode(slice_def, &s);
TF_RETURN_IF_ERROR(s);
// Add edges for slice node.
graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
graph->AddEdge(begin_node, 0, slice_node, 1);
graph->AddEdge(size_node, 0, slice_node, 2);
return slice_node;
}
// Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
// nodes into a single output. Sharded outputs are concatenated along row major
// order. That is, tiled output along 0th dimension will be concatenated last.
xla::StatusOr<Node*> CreateConcatNodesForRetval(
const xla::OpSharding& sharding, DataType dtype,
const PartialTensorShape& inferred_shape, int replica_id,
const std::vector<NodeOut>& orig_inputs, Graph* graph,
absl::string_view device) {
std::map<int, int> split_dimension_map;
TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
sharding, &split_dimension_map));
std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
bool has_paddings = false;
for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
it++) {
auto dim = it->first;
auto num_splits = it->second;
int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
int input_index_to_concat_node = 0;
std::vector<NodeOut> new_concat_nodes;
for (int i = 0; i < num_concat_nodes; ++i) {
auto concat_input_it =
inputs_to_sharded_retval.begin() + input_index_to_concat_node;
std::vector<NodeOut> inputs(concat_input_it,
concat_input_it + num_splits);
input_index_to_concat_node += num_splits;
TF_ASSIGN_OR_RETURN(
Node * concat_node,
CreateConcatNode(
dim, num_splits, dtype,
absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
inputs, graph, device));
int64_t paddings = GetPadding(dim, num_splits, inferred_shape);
has_paddings |= paddings > 0;
new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
}
inputs_to_sharded_retval = new_concat_nodes;
}
TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
if (has_paddings) {
TF_ASSIGN_OR_RETURN(Node * slice_node,
CreateSliceNode(dtype, inferred_shape,
inputs_to_sharded_retval.at(0).node,
/*concat_out_index*/ 0, graph, device));
return slice_node;
}
return inputs_to_sharded_retval.at(0).node;
}
xla::StatusOr<Node*> CreateXlaConcatNode(
const xla::OpSharding& sharding, const int replica_id, DataType dtype,
const PartialTensorShape& partial_tensor_shape,
const std::vector<NodeOut>& orig_inputs, absl::string_view device,
Graph* graph) {
NodeDef xla_concat_def;
xla_concat_def.set_name(graph->NewName(
absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
xla_concat_def.set_op("XlaConcatND");
xla_concat_def.set_device(std::string(device));
AddNodeAttr("T", dtype, &xla_concat_def);
AddNodeAttr("N", static_cast<int64>(orig_inputs.size()), &xla_concat_def);
const std::vector<int64> num_concats(
sharding.tile_assignment_dimensions().begin(),
sharding.replicate_on_last_tile_dim()
? std::prev(sharding.tile_assignment_dimensions().end())
: sharding.tile_assignment_dimensions().end());
AddNodeAttr("num_concats", num_concats, &xla_concat_def);
const int rank = sharding.replicate_on_last_tile_dim()
? sharding.tile_assignment_dimensions_size() - 1
: sharding.tile_assignment_dimensions_size();
std::vector<int32> paddings;
paddings.reserve(rank);
for (int dim = 0; dim < rank; ++dim) {
paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
partial_tensor_shape));
}
AddNodeAttr("paddings", paddings, &xla_concat_def);
Status s;
Node* xla_concat = graph->AddNode(xla_concat_def, &s);
TF_RETURN_IF_ERROR(s);
for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
const NodeOut& input = orig_inputs[i];
graph->AddEdge(input.node, input.index, xla_concat, i);
}
return xla_concat;
}
// Set the padding ops the same devices as the original inputs. If the original
// inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
// mode will be triggered, so we don't need to copy the data back to the host
// to do the padding.
Status SetPaddingNodesDevices(Graph* graph) {
for (Node* n : graph->op_nodes()) {
bool tpu_padding_attr;
if (n->type_string() == "Pad" &&
GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
.ok()) {
Node* unpadded_input;
TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
const string& requested_device = unpadded_input->requested_device();
const string& assigned_device = unpadded_input->assigned_device_name();
if (!requested_device.empty() || !assigned_device.empty()) {
// The output nodes of the original unpadded inputs include the padded
// inputs and real shapes of inputs, we assign those to the same device
// as the original inputs.
for (Node* out : unpadded_input->out_nodes()) {
if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
&tpu_padding_attr)
.ok()) {
out->set_requested_device(requested_device);
out->set_assigned_device_name(assigned_device);
}
}
// There might be a tf.shape node added before TPUCompileOp, we need to
// set its device as well.
for (Node* out : n->out_nodes()) {
if (n->type_string() == "Shape") {
out->set_requested_device(requested_device);
out->set_assigned_device_name(assigned_device);
}
}
}
}
}
return Status::OK();
}
const string& AssignedOrRequestedDevice(const Node* node) {
if (!node->assigned_device_name().empty()) {
return node->assigned_device_name();
}
return node->requested_device();
}
bool IsTpuDevice(const string& device_string) {
DeviceNameUtils::ParsedName device;
return DeviceNameUtils::ParseFullName(device_string, &device) &&
device.type == DEVICE_TPU_NODE;
}
// Returns a set of device ops can be placed on TPU. There is no strict rule of
// thumb to decide which ops should be in the list, but empirically they are
// mostly dummy ops like Identity-like ops or control flow related ops. However
// people can add also add other ops like Pad to allow data stay on TPU.
const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
{"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
"NextIteration", "Shape", "_Retval"});
return *place_on_tpu_ops;
}
// If an op satisfies the following conditions, it will be placed on the same
// TPU device as its inputs:
// (1) The op can be placed on TPU (in the PlaceOnTPUOpList)
// (2) The op itself has no requested or assigned devices.
// (3) All the data inputs of this op are placed on the same device on TPUs.
// There are exceptions like the NextIterations input of Switch node can
// be placed on CPU as it is just a boolean.
//
// Returns true if the node device has been changed, otherwise returns false.
bool PlaceOpsOnTPU(Node* node) {
if (!AssignedOrRequestedDevice(node).empty() ||
!PlaceOnTPUOpList().contains(node->type_string())) {
return false;
}
string src_tpu_device = "";
Node* src_node;
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) {
continue;
}
Node* src = e->src();
const string& src_device = AssignedOrRequestedDevice(src);
// Make exceptions that we don't force the some inputs to place on TPUs.
if (node->IsSwitch() && src->IsLoopCond()) {
continue;
}
if (!IsTpuDevice(src_device) ||
(!src_tpu_device.empty() && src_device != src_tpu_device)) {
return false;
}
if (src_tpu_device.empty()) {
src_tpu_device = src_device;
src_node = src;
}
}
node->set_assigned_device_name(src_node->assigned_device_name());
node->set_requested_device(src_node->requested_device());
return true;
}
xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
xla::OpMetadata metadata;
metadata.set_op_type(node.type_string());
metadata.set_op_name(node.name());
return metadata;
}
// Helper struct holding node (nullable) and associated sharding.
struct NodeAndSharding {
explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
: node(node), sharding(sharding) {}
const Node* node;
xla::OpSharding sharding;
};
// Validate sharding configuration derived from XlaSharding attribute.
// Infer the core id from the OpSharding, if necessary.
Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
const int num_cores_per_replica,
int64* inferred_core_id,
absl::optional<NodeAndSharding>* result) {
if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
int64_t core_annotation =
node_and_sharding.sharding.tile_assignment_devices(0);
TF_RETURN_IF_ERROR(
ValidateCoreNumber(core_annotation, num_cores_per_replica));
if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
*inferred_core_id = core_annotation;
result->emplace(node_and_sharding);
}
} else {
if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
for (int64_t core :
node_and_sharding.sharding.tile_assignment_devices()) {
TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
}
}
if (!result->has_value()) {
*result = node_and_sharding;
} else {
std::string result_value_serialized;
xla::OpSharding result_value = result->value().sharding;
result_value.clear_metadata();
SerializeToStringDeterministic(result_value, &result_value_serialized);
std::string sharding_serialized;
xla::OpSharding sharding = node_and_sharding.sharding;
sharding.clear_metadata();
SerializeToStringDeterministic(sharding, &sharding_serialized);
// TODO(lyandy): Choose the more granular sharding instead of always
// assigning to core 0 (maximal).
if (result_value_serialized != sharding_serialized) {
// We see different shardings, assign to core 0.
auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
DCHECK_NE(node_and_sharding.node, nullptr);
*core_zero_sharding.add_metadata() =
CreateOpMetadataFromNode(*node_and_sharding.node);
result->emplace(
NodeAndSharding(node_and_sharding.node, core_zero_sharding));
}
}
}
return Status::OK();
}
// As XlaSharding node may be followed by Cast op or an Identity op,
// recursively walk the graph and aggregate nodes connectd to
// |input_node| or Cast/Identity op following the |input_node|.
void FindNodesMaybeContainingShardingInfo(const Node& input_node,
std::vector<const Node*>* nodes) {
if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
for (const Node* connected_node : input_node.out_nodes())
FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
}
nodes->emplace_back(&input_node);
}
// Parse sharding configuration from |node| or it's adjacent nodes.
// XlaSharding configuration may be derived from
// a) Connected Identity op node.
// b) Connected Cast op node.
xla::StatusOr<absl::optional<NodeAndSharding>>
ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
const Node& node) {
// If |node| has `device` attribute or is a XlaSharding op,
// return the parsed OpSharding.
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(node, num_cores_per_replica,
/*add_metadata=*/true));
if (sharding.has_value()) {
return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
}
// XlaShardingOp may be followed by an identity or followed by identity
// and a Cast op.
std::vector<const Node*> potential_nodes_with_input_sharding;
FindNodesMaybeContainingShardingInfo(node,
&potential_nodes_with_input_sharding);
for (const Node* maybe_node_with_sharding_info :
potential_nodes_with_input_sharding) {
if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> sharding_config,
ParseShardingFromDevice(*maybe_node_with_sharding_info,
num_cores_per_replica, /*add_metadata=*/true));
if (sharding_config.has_value()) {
return absl::optional<NodeAndSharding>(
NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
}
}
return absl::optional<NodeAndSharding>();
}
// Walk the graph from an argument node to find OpSharding configuration
// from its neighbor nodes. Sharding configuration may be inferred from
// 1) Parsing XlaSharding attribute from neighboring node.
// 2) If argument node is a resource, then by parsing adjacent nodes
// of the connected ReadVariable op.
Status ParseAndValidateShardingFromNeighbors(
const int num_cores_per_replica, const std::string& arg_node_name,
const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
absl::optional<NodeAndSharding>* result) {
if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
*is_fast_mem = true;
VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
<< arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
}
// XlaSharding information may be encoded on node directly connected to the
// argument node.
TF_ASSIGN_OR_RETURN(
absl::optional<NodeAndSharding> node_and_sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
if (node_and_sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(
*node_and_sharding, num_cores_per_replica, inferred_core_id, result));
return Status::OK();
}
// When we use variable in TPU computation, we always have a
// XlaSharding op followed by a ReadVariableOp. As so, correctly parse
// the users of ReadVariableOp for potential sharding configuration.
if (neighbor_node.type_string() == "ReadVariableOp") {
for (const Edge* e : neighbor_node.out_edges()) {
if (e->IsControlEdge()) continue;
if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
*is_fast_mem = true;
VLOG(2) << "place " << arg_node_name << " on fast memory because "
<< e->dst()->name() << TPU_FAST_MEM_ATTR << " attribute";
}
TF_ASSIGN_OR_RETURN(
absl::optional<NodeAndSharding> node_and_sharding,
ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
if (node_and_sharding.has_value()) {
TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
num_cores_per_replica,
inferred_core_id, result));
return Status::OK();
}
}
}
return Status::OK();
}
} // namespace
// Inputs:
// replication_spec_string: the device to which the TPUReplicate node was
// assigned.
// device_set: the set of TF devices.
// Outputs:
// tpu_compilation_device: the name of the TPU compilation device.
// num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
// have the same number of TPU devices.
// tpu_devices: the TPU devices, indexed by [task][device].
static Status GetTPUDeviceNames(
const string& replication_spec_string, const DeviceSet& device_set,
string* tpu_compilation_device, int* num_tpus_per_task,
std::vector<std::vector<Device*>>* tpu_devices) {
// TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
// the tpu_system device, which we replace by the cpu device. We do this
// replacement because we want to place the TPUCompileOp (and the compile
// assert op) explicitly on cpu devices on the same job as the tpu_system
// device.
DeviceNameUtils::ParsedName replication_spec;
Device* replication_device;
TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
replication_spec_string, device_set, &replication_spec,
&replication_device));
*tpu_compilation_device =
str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
DEVICE_CPU, /*replace_all=*/true);
// Finds the set of TPU devices attached to the tasks in the job.
TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
replication_spec, device_set, num_tpus_per_task, tpu_devices));
return Status::OK();
}
// Parses the topology attribute of TPUReplicate, and populates *topology with
// a physical mesh coordinate to (task, device) mapping.
static Status ParseTopologyAttr(const string& topology_attr,
const tpu::TpuTopologyExternal& tpu_topology,
int num_tasks, int num_tpus_per_task,
xla::Array4D<std::pair<int, int>>* topology) {
static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
tpu::TopologyProto proto;
proto.ParseFromString(topology_attr);
if (proto.mesh_shape_size() != kTPUTopologyRank) {
return errors::InvalidArgument("TPU topology must be rank ",
kTPUTopologyRank);
}
if (proto.num_tasks() != num_tasks) {
return errors::InvalidArgument("Mismatched number of TPU tasks");
}
if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
return errors::InvalidArgument("Mismatched number of TPUs per task (",
proto.num_tpu_devices_per_task(),
" != ", num_tpus_per_task, ").");
}
if (proto.device_coordinates_size() !=
num_tasks * num_tpus_per_task * kTPUTopologyRank) {
return errors::InvalidArgument(
"device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
kTPUTopologyRank, "; got ", proto.device_coordinates_size());
}
int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
*topology = xla::Array4D<std::pair<int, int>>(
tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
int pos = 0;
for (int task = 0; task < num_tasks; ++task) {
for (int device = 0; device < num_tpus_per_task; ++device) {
int32 x = proto.device_coordinates(pos++);
int32 y = proto.device_coordinates(pos++);
int32 z = proto.device_coordinates(pos++);
int32 core = proto.device_coordinates(pos++);
if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
core >= devices_per_chip) {
return errors::InvalidArgument(
"Mesh coordinates (", x, ",", y, ",", z, ",", core,
") are not valid for the current TPU topology");
}
if ((*topology)(x, y, z, core).first != -1) {
return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
",", z, ",", core, ") in TPU topology");
}
(*topology)(x, y, z, core) = {task, device};
}
}
return Status::OK();
}
// Parses the value of the device_assignment attribute to TPUReplicate.
// Populates *device_assignment; *device_assignment must be a 2D array with
// shape (num_replicas, num_cores_per_replica).
static Status ParseDeviceAssignmentAttr(
absl::Span<const int> device_assignment_attr,
const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
int num_cores_per_replica,
xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
const int64 device_assignment_attr_size =
num_replicas * num_cores_per_replica * kTPUTopologyRank;
if (device_assignment_attr.size() != device_assignment_attr_size) {
return errors::InvalidArgument(
"Length of device_assignment attribute must be equal to num_replicas (",
num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
}
for (int core : device_assignment_attr) {
if (core < 0 || core >= kTPUMaxTopologySize) {
return errors::InvalidArgument(
"Invalid core number in device assignment: ", core);
}
}
*device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
num_replicas, num_cores_per_replica);
int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
xla::Array4D<int> replica_assignment(
tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
tpu_topology.chip_bounds().z, devices_per_chip, -1);
int pos = 0;
for (int replica = 0; replica < num_replicas; ++replica) {
for (int logical_core = 0; logical_core < num_cores_per_replica;
++logical_core) {
int32 x = device_assignment_attr[pos++];
int32 y = device_assignment_attr[pos++];
int32 z = device_assignment_attr[pos++];
int32 core = device_assignment_attr[pos++];
if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
core >= devices_per_chip) {
return errors::InvalidArgument(
"Mesh coordinates (", x, ",", y, ",", core,
") are not valid for the current TPU topology");
}
tpu::TpuCoreLocationExternal core_location =
tpu_topology.Core(kTensorCore, x, y, z, core);
if (replica_assignment(x, y, z, core) != -1) {
return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
",", z, ",", core,
") in TPU device assignment");
}
replica_assignment(x, y, z, core) = replica;
(*device_assignment)(replica, logical_core) = core_location;
}
}
return Status::OK();
}
// Builds TensorFlow device assignments for the special case of a single core
// computation that is replicated to every core in the mesh.
// LINT.IfChange
static Status BuildFullMeshDeviceAssignment(
int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
int num_tasks, int num_tpus_per_task,
std::vector<std::vector<string>>* tf_device_assignment,
std::vector<int>* devices_to_lock) {
// Assign TensorFlow devices to replicas arbitrarily.
for (int i = 0; i < num_replicas; ++i) {
int task = i / num_tpus_per_task;
int device = i % num_tpus_per_task;
TF_RET_CHECK(task >= 0 && task < num_tasks);
TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
// We don't actually know which TF device corresponds to which physical
// device, but it doesn't matter—they're all identical.
(*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
devices_to_lock->push_back(i);
}
return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
// Builds TensorFlow device assignments for a replicated computation and convert
// device_assignment into xla_device_assignment.
static Status BuildGeneralDeviceAssignment(
int num_replicas, int num_cores_per_replica,
const std::vector<std::vector<Device*>>& tpu_devices,
const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
const xla::Array4D<std::pair<int, int>>& topology,
std::vector<std::vector<string>>* tf_device_assignment,
std::vector<int>* devices_to_lock,
std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
// Assign TensorFlow devices to each computation's replicas according to
// device_assignment and 'topology'.
*xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
num_replicas, num_cores_per_replica);
for (int replica = 0; replica < num_replicas; ++replica) {
for (int computation = 0; computation < num_cores_per_replica;
++computation) {
const tpu::TpuCoreLocationExternal& core_location =
device_assignment(replica, computation);
int task;
int device;
std::tie(task, device) =
topology(core_location.chip_coordinates().x,
core_location.chip_coordinates().y,
core_location.chip_coordinates().z, core_location.index());
CHECK_LT(computation, num_cores_per_replica);
(**xla_device_assignment)(replica, computation) = core_location.Id();
// The communication pattern between replicas will be determined later by
// BuildAllReduceRing.
TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
(*tf_device_assignment)[replica].push_back(
tpu_devices[task][device]->name());
devices_to_lock->push_back((task * tpu_devices[task].size()) + device);
}
}
return Status::OK();
}
/*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
int num_cores_per_replica, const string& topology_attr,
absl::Span<const int> device_assignment_attr,
std::vector<std::vector<string>>* tf_device_assignment,
std::vector<int>* devices_to_lock,
std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
const int num_tasks = tpu_devices.size();
const int num_tpu_devices = num_tasks * num_tpus_per_task;
VLOG(2) << "num_tasks=" << num_tasks
<< " num_tpus_per_task=" << num_tpus_per_task;
// Checks num_replicas is sane first to avoid integer overflow.
if (num_replicas > num_tpu_devices) {
#ifdef PLATFORM_CLOUD_TPU
return errors::InvalidArgument("Requested num_replicas=", num_replicas,
" but there are only ", num_tpu_devices,
" cores in the TPU topology.");
#else
return errors::InvalidArgument("Requested num_replicas=", num_replicas,
" but there are only ", num_tpu_devices,
" cores in the TPU topology.");
#endif
}
if (num_replicas * num_cores_per_replica > num_tpu_devices) {
return errors::InvalidArgument(
"Requested num_replicas=", num_replicas, " with ",
num_cores_per_replica, " cores per replica, but there are only ",
num_tpu_devices, " cores in the TPU topology");
}
tf_device_assignment->clear();
tf_device_assignment->resize(num_replicas);
devices_to_lock->clear();
devices_to_lock->reserve(num_replicas * num_cores_per_replica);
// Special case: we allow the user to omit the topology and device assignment
// information in two cases:
// * there is only one replica and one core per replica. In this case, we
// don't need to know topology information because we don't communicate with
// other cores.
// * the number of replicas is equal to the number of cores in the slice. In
// this case, all cores are running the same program so we don't need to
// know which is which.
if (topology_attr.empty()) {
// LINT.IfChange
if (num_replicas != 1 && num_replicas != num_tpu_devices) {
return errors::InvalidArgument(
"TPUReplicate asked to create ", num_replicas,
" replicas, but the number of cores in the TPU topology is ",
num_tpu_devices,
" and no TPU device assignment was supplied. "
"A TPU device assignment is required if the number of replicas is "
"not 1 or the number of cores in the topology (",
num_tpu_devices, ")");
}
if (num_cores_per_replica != 1) {
return errors::InvalidArgument(
"A TPU topology must be provided if num_cores_per_replica != 1");
}
if (!device_assignment_attr.empty()) {
return errors::InvalidArgument(
"A TPU topology must be provided if device_assignment_attr is "
"non-empty");
}
// If there is only one replica, assign the Tensorflow computation to task 0
// device 0, and leave the XLA device assignment empty. We don't know which
// core this is in the TPU topology, but it doesn't matter—we don't need to
// communicate with any other cores.
if (num_replicas == 1) {
(*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
devices_to_lock->push_back(0);
return Status::OK();
}
// Otherwise, num_replicas is equal to the number of cores, and we build a
// device assignment that covers the entire mesh. We do not need to know
// the topology to do so because all cores are identical.
return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
num_tpus_per_task,
tf_device_assignment, devices_to_lock);
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
}
// Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
xla::Array4D<std::pair<int, int>> topology;
TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
num_tpus_per_task, &topology));
// Array that maps logical (replica, core) pairs to physical mesh coordinates.
xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
&device_assignment));
return BuildGeneralDeviceAssignment(
num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
topology, tf_device_assignment, devices_to_lock, xla_device_assignment);
}
Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
const NameAttrList& function, FunctionLibraryRuntime* flr,
Graph* computation, DataTypeVector* arg_types,
DataTypeVector* retval_types) {
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CopyGraph(*fbody->graph, computation);
*arg_types = fbody->arg_types;
*retval_types = fbody->ret_types;
return Status::OK();
}
// Grab the InferredShape corresponding to an edge input.
static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
const InferredShape** info) {
auto it = shape_info.find(edge.src()->name());
if (it == shape_info.end()) {
return errors::InvalidArgument(
"Input to replicated TPU computation is missing InferredShape: ",
edge.src()->name());
}
TF_RET_CHECK(it->second.size() > edge.src_output());
*info = &it->second[edge.src_output()];
return Status::OK();
}
Status DistributedTPURewritePass::GetArgAndRetvalShapes(
const GraphShapeInfo& shape_info, const Node& node,
const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
std::vector<InferredShape>* retval_shapes) {
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
// If any replica's arg shape is unknown, we will mark the computation's arg
// shape as being unknown. If the shapes differ the TpuExecute Op will raise a
// runtime error.
std::vector<bool> any_replica_shape_unknown(
params_info.NumInputsToEachReplica());
arg_shapes->clear();
arg_shapes->resize(params_info.NumInputsToEachReplica());
TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
// Determines the shapes of the per-replica arguments and checks that all
// replicas have identical shapes.
int64_t edge_pos = 0;
auto check_shape = [&](int input_index) -> Status {
const InferredShape* info;
TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
++edge_pos;
if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
(info->handle_type != DT_INVALID &&
!info->handle_shape.IsFullyDefined())) {
any_replica_shape_unknown[input_index] = true;
}
xla::StatusOr<InferredShape> status =
MergeInferredShapes((*arg_shapes)[input_index], *info);
if (!status.ok()) {
return errors::InvalidArgument(
"Mismatched shapes for input ", input_index, ": ",
(*arg_shapes)[input_index].shape.DebugString(), " vs. ",
info->shape.DebugString());
}
(*arg_shapes)[input_index] = status.ValueOrDie();
return Status::OK();
};
for (int64_t i = 0; i < params_info.NumReplicas(); ++i) {
for (int64_t j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
TF_RETURN_IF_ERROR(check_shape(j));
}
}
for (int64_t i = 0; i < params_info.NumDistributedArgs(); ++i) {
TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
}
for (int64_t i = 0;
i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
++i) {
if (any_replica_shape_unknown[i]) {
(*arg_shapes)[i].shape = PartialTensorShape();
(*arg_shapes)[i].handle_shape = PartialTensorShape();
}
}
// Determines the shape of the broadcast arguments.
for (int64_t i = 0; i < params_info.NumBroadcastArgs(); ++i) {
TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
const InferredShape* info;
TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
(*arg_shapes)[i + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs()]
.shape = info->shape;
++edge_pos;
}
// Determines the handle shape and handle type of the resource variable
// arguments.
for (int64_t i = 0; i < params_info.NumVariables(); ++i) {
TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
const InferredShape* info;
TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
InferredShape& arg_shape =
(*arg_shapes)[i + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs()];
arg_shape.shape = TensorShape(); // Variables are always scalars.
arg_shape.handle_shape = info->handle_shape;
arg_shape.handle_type = info->handle_type;
TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
<< " input edge: " << input_edges[edge_pos]->DebugString();
++edge_pos;
}
// Determines the shape of the guaranteed constants.
// TODO(vinuraja): Can be removed because they are not required for any
// calculations. Leaving them here for symmetry with other structures like
// arg_types, arg_sharding, etc.
for (int64_t i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
const InferredShape* info;
TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
(*arg_shapes)[i + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs() + params_info.NumVariables()]
.shape = info->shape;
++edge_pos;
}
// Extract the return value shapes.
auto it = shape_info.find(node.name());
retval_shapes->clear();
if (it != shape_info.end()) {
TF_RET_CHECK(it->second.size() >= node.num_outputs());
retval_shapes->resize(node.num_outputs());
for (int i = 0; i < node.num_outputs(); ++i) {
(*retval_shapes)[i].shape = it->second[i].shape;
}
} else if (node.num_outputs() > 0) {
return errors::InvalidArgument(
"Replicated TPU computation is missing InferredShape: ",
FormatNodeForError(node));
}
return Status::OK();
}
// Verifies that all nodes have legal sharding.
static Status ValidateCoreNumbers(const Graph& graph,
int num_cores_per_replica) {
for (Node* n : graph.nodes()) {
TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
ParseShardingFromDevice(*n, num_cores_per_replica,
/*add_metadata=*/true));
}
return Status::OK();
}
static Status InferXlaShardingFromNeighbors(
const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
CachedFunctionHandles* cached_function_handles,
absl::optional<NodeAndSharding>* output_node_and_sharding,
bool* is_fast_mem) {
int64_t core = -1;
absl::optional<NodeAndSharding> result;
// We assume the variable has been allocated on fast memory if any consuming
// op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and
// compiler.
*is_fast_mem = false;
for (const Edge* edge : n.out_edges()) {
if (edge->IsControlEdge()) continue;
TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
&result));
if (!flr) continue;
// The nodes deciding this arg's device assignment might be in
// FunctionDef. Instantiate FunctionDefs associated with this node
// and check nodes using this arg.
std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
[&](const Edge* call_edge) {
auto associated_functions = GetAssociatedFunctions(
*call_edge->dst(), flr->GetFunctionLibraryDefinition());
for (auto& associated_function : associated_functions) {
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
associated_function.func_name(),
AttrSlice(&associated_function.attrs()), &handle));
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
for (Node* body_node : g->nodes()) {
if (!body_node->IsArg()) continue;
int index;
TF_RETURN_IF_ERROR(
GetNodeAttr(body_node->attrs(), "index", &index));
if (index != call_edge->dst_input()) continue;
for (const Edge* out_edge : body_node->out_edges()) {
if (out_edge->IsControlEdge()) continue;
TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
num_cores_per_replica, n.name(), *out_edge->dst(), &core,
is_fast_mem, &result));
TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
}
}
}
return Status::OK();
};
TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
}
*output_node_and_sharding = result;
return Status::OK();
}
bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
bool spmd_attr;
if (!replicate_node ||
!TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
&spmd_attr)) {
spmd_attr = false;
}
return spmd_attr;
}
std::string FormatNodeAndShardingMsg(
const absl::optional<NodeAndSharding>& node_and_sharding) {
DCHECK(node_and_sharding.has_value());
xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
sharding_no_metadata.clear_metadata();
std::string escaped_sharding_str =
absl::CEscape(sharding_no_metadata.SerializeAsString());
if (node_and_sharding->node == nullptr) {
return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
}
return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
" sharding '", escaped_sharding_str, "'");
}
Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
int num_cores_per_replica, const ParameterInfo& params_info,
const DataTypeVector& arg_types,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& retval_types,
const std::vector<InferredShape>& retval_shapes, const Graph& graph,
const Node* replicate_node, FunctionLibraryRuntime* flr,
bool allow_parameter_replication_for_spmd,
std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
std::vector<xla::OpSharding>* retval_sharding,
std::vector<std::string>* arg_names) {
// Builds vectors of the argument and return nodes.
std::vector<Node*> args(arg_types.size());
std::vector<Node*> retvals(retval_types.size());
absl::flat_hash_map<int, Node*> partitioned_output_nodes;
for (Node* node : graph.op_nodes()) {
if (node->IsArg()) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < args.size());
args[index] = node;
} else if (node->IsRetval()) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
TF_RET_CHECK(index >= 0 && index < retvals.size());
retvals[index] = node;
}
}
for (const Edge* edge : replicate_node->out_edges()) {
int num_partitioned_outputs = 0;
for (const Edge* out_edge : edge->dst()->out_edges()) {
if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
partitioned_output_nodes[edge->src_output()] = out_edge->dst();
num_partitioned_outputs++;
}
}
if (num_partitioned_outputs > 1) {
return errors::InvalidArgument(
"More than one TPUPartitionedOutput per replciated output.");
}
}
// Verifies there are no missing arguments/return values.
for (int i = 0; i < args.size(); ++i) {
if (args[i] == nullptr) {
return errors::Internal("Missing function argument: ", i);
}
}
for (int i = 0; i < retvals.size(); ++i) {
if (retvals[i] == nullptr) {
return errors::Internal("Missing function return value: ", i);
}
}
// Assigns a core to each _Arg. Chooses the lowest-numbered core that
// consumes the argument. We choose the lowest-numbered core so the
// assignment is deterministic.
TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
arg_shapes);
arg_sharding->resize(args.size());
arg_names->resize(args.size());
arg_fast_mem->resize(args.size());
CachedFunctionHandles cached_function_handles(flr);
const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
replicate_inputs_outputs_by_default_for_xla_spmd_) &&
allow_parameter_replication_for_spmd;
// Offset _TPUReplicate non per replica argument indices by
// (num_replicas - 1) * num_per_replica_args as _TPUReplicate nodes are
// constructed with all per replica args across all replicas while the
// encapsulated function only has 1 replica's per replica args. Per replica
// args are ordered by replica first, so the index here does not require an
// offset and the first replica's input nodes is sufficient for determining
// argument sharding.
const int index_offset =
(params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
for (int i = 0; i < args.size(); ++i) {
const Node* n = args[i];
absl::optional<int64> assigned_core;
absl::optional<NodeAndSharding> node_and_sharding;
bool is_fast_mem;
TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
*n, num_cores_per_replica, flr, &cached_function_handles,
&node_and_sharding, &is_fast_mem));
const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
Node* input_node;
TF_RETURN_IF_ERROR(replicate_node->input_node(
i + (is_per_replica_arg ? 0 : index_offset), &input_node));
if (input_node->type_string() == kTPUPartitionedInput) {
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
if (!parsed_sharding.has_value())
return errors::InvalidArgument("Missing _XlaSharding attr from: ",
input_node->DebugString());
node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
VLOG(1) << "Arg " << i << " parsed sharding information from "
<< input_node->DebugString() << " : "
<< parsed_sharding->DebugString();
}
}
if (params_info.IsVariableArg(i)) {
Node* input_node;
TF_RETURN_IF_ERROR(
replicate_node->input_node(i + index_offset, &input_node));
if (input_node->type_string() == kVarHandleOp) {
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
if (parsed_sharding.has_value()) {
node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
VLOG(1) << "Arg " << i << " parsed sharding information from "
<< input_node->DebugString() << " : "
<< parsed_sharding->DebugString();
}
}
}
if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
return tensorflow::errors::InvalidArgument(
"Specifying manual sharding is not allowed when automatic "
"model parallelism is enabled.",
node_and_sharding->sharding.DebugString());
}
if (!node_and_sharding.has_value()) {
if (use_spmd &&
(params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
((params_info.IsPerReplicaArg(i) ||
params_info.IsDistributedArg(i)) &&
arg_types[i] != DT_RESOURCE))) {
// Use replication for host variables or non-variable per-replica
// inputs.
node_and_sharding = NodeAndSharding(/*node=*/nullptr,
xla::sharding_builder::Replicate());
} else {
// TODO(dlibenzi): Distributing variables to cores other than 0 makes
// learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
// For now distribute only per replica arguments, unless
// tf_jf_distribute_vars is set, to allow debugging the issue.
if (((params_info.IsPerReplicaArg(i) ||
params_info.IsDistributedArg(i)) &&
arg_types[i] != DT_RESOURCE) ||
(distribute_vars_ && params_info.IsVariableArg(i))) {
assigned_core = args_device_selector.RetrieveAssignment(i);
} else {
assigned_core = 0;
}
node_and_sharding = NodeAndSharding(
/*node=*/nullptr,
xla::sharding_builder::AssignDevice(*assigned_core));
}
*node_and_sharding->sharding.add_metadata() =
CreateOpMetadataFromNode(*replicate_node);
} else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
} else if (node_and_sharding->sharding.type() !=
xla::OpSharding::REPLICATED &&
node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
return tensorflow::errors::InvalidArgument(
"Unsupported argument sharding (for arg ", n->DebugString(),
"): ", node_and_sharding->sharding.DebugString());
}
if (assigned_core.has_value()) {
args_device_selector.ReportDeviceAssigned(*assigned_core, i);
VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
<< ") to core " << *assigned_core
<< FormatNodeAndShardingMsg(node_and_sharding);
args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
} else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
for (int64_t core :
node_and_sharding->sharding.tile_assignment_devices()) {
TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
<< "core " << core << " should be between [0, "
<< num_cores_per_replica << "). sharding is "
<< node_and_sharding->sharding.DebugString();
args_device_selector.ReportDeviceAssigned(core, i);
}
VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
<< ") with tiled sharding to cores "
<< absl::StrJoin(
node_and_sharding->sharding.tile_assignment_devices(), ",")
<< " " << FormatNodeAndShardingMsg(node_and_sharding);
} else {
DCHECK_EQ(node_and_sharding->sharding.type(),
xla::OpSharding::REPLICATED);
for (int64_t core = 0; core < num_cores_per_replica; ++core) {
args_device_selector.ReportDeviceAssigned(core, i);
}
VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
<< ") to all cores"
<< FormatNodeAndShardingMsg(node_and_sharding);
}
(*arg_sharding)[i] = node_and_sharding->sharding;
(*arg_fast_mem)[i] = is_fast_mem;
(*arg_names)[i] = n->name();
if (is_fast_mem) {
VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
<< args[i]->name();
}
args[i]->AddAttr(kShardingAttribute,
node_and_sharding->sharding.SerializeAsString());
}
TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
// Assigns each _Retval node to the core that produces its value.
TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
retval_types, retval_shapes);
retval_sharding->resize(retvals.size());
for (int i = 0; i < retvals.size(); ++i) {
const Edge* edge;
TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> edge_sharding,
ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
/*add_metadata=*/true));
absl::optional<NodeAndSharding> node_and_sharding;
if (edge_sharding.has_value()) {
node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
}
if (partitioned_output_nodes.contains(i)) {
Node* output_node = partitioned_output_nodes[i];
TF_ASSIGN_OR_RETURN(
absl::optional<xla::OpSharding> parsed_sharding,
GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
if (parsed_sharding.has_value()) {
node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
VLOG(1) << "Retval " << i << " parsed sharding information from "
<< output_node->DebugString() << " : "
<< parsed_sharding->DebugString();
}
}
absl::optional<int64> assigned_core;
if (node_and_sharding.has_value()) {
if (enable_automatic_model_parallelism_) {
return tensorflow::errors::InvalidArgument(
"Specifying manual sharding is not allowed when automatic "
"model parallelism is enabled.",
node_and_sharding->sharding.DebugString());
}
if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
TF_RETURN_IF_ERROR(
ValidateCoreNumber(*assigned_core, num_cores_per_replica));
} else if (node_and_sharding->sharding.type() !=
xla::OpSharding::REPLICATED &&
node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
return tensorflow::errors::InvalidArgument(
"Unsupported argument sharding for retval ",
retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
node_and_sharding->sharding.DebugString());
}
} else {
if (use_spmd) {
node_and_sharding = NodeAndSharding(/*node=*/nullptr,
xla::sharding_builder::Replicate());
} else {
if (distribute_vars_) {
assigned_core = retvals_device_selector.RetrieveAssignment(i);
} else {
assigned_core = 0;
}
node_and_sharding = NodeAndSharding(
/*node=*/nullptr,
xla::sharding_builder::AssignDevice(*assigned_core));
}
*node_and_sharding->sharding.add_metadata() =
CreateOpMetadataFromNode(*replicate_node);
}
if (assigned_core.has_value()) {
retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
VLOG(3) << "Assigning return value " << i << " ("
<< retvals[i]->DebugString() << ") to core " << *assigned_core
<< FormatNodeAndShardingMsg(node_and_sharding);
} else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
for (int64_t core :
node_and_sharding->sharding.tile_assignment_devices()) {
TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
<< "core " << core << " should be between [0, "
<< num_cores_per_replica << "). sharding is "
<< node_and_sharding->sharding.DebugString();
retvals_device_selector.ReportDeviceAssigned(core, i);
}
VLOG(3) << "Assigning return value " << i << " ("
<< retvals[i]->DebugString() << ") with tiled sharding to cores "
<< absl::StrJoin(
node_and_sharding->sharding.tile_assignment_devices(), ",")
<< " " << FormatNodeAndShardingMsg(node_and_sharding);
} else {
DCHECK_EQ(node_and_sharding->sharding.type(),
xla::OpSharding::REPLICATED);
for (int64_t core = 0; core < num_cores_per_replica; ++core) {
retvals_device_selector.ReportDeviceAssigned(core, i);
}
VLOG(3) << "Assigning return value " << i << " ("
<< retvals[i]->DebugString() << ") to all cores"
<< FormatNodeAndShardingMsg(node_and_sharding);
}
retvals[i]->AddAttr(kShardingAttribute,
node_and_sharding->sharding.SerializeAsString());
(*retval_sharding)[i] = node_and_sharding->sharding;
}
if (use_spmd &&
(absl::c_any_of(*arg_sharding,
[](const xla::OpSharding& s) {
return s.type() == xla::OpSharding::MAXIMAL;
}) ||
absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
return s.type() == xla::OpSharding::MAXIMAL;
}))) {
LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
"exist on every partition (sharded or replicated). Fall "
"back to MPMD.";
return AssignArgsAndRetvalsToCores(
num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
retval_shapes, graph, replicate_node, flr,
/*allow_parameter_replication_for_spmd=*/false, arg_sharding,
arg_fast_mem, retval_sharding, arg_names);
}
return Status::OK();
}
// Builds Shape nodes that compute the shapes of arguments whose shapes are not
// statically known.
/* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
dynamic_shape_nodes->clear();
std::vector<const Edge*> replicate_input_edges;
TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
// The compiler determines the shape of each constant by inspecting the value
// of its corresponding host-memory tensor; this happens when a step is run.
// As a result, the shapes of constants are not needed at graph rewrite time.
const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs() +
params_info.NumVariables());
for (int i = 0; i < num_args; ++i) {
const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
? &arg_shapes[i].shape
: &arg_shapes[i].handle_shape;
if (!shape->IsFullyDefined()) {
NodeDef def;
Node* src;
int src_output;
std::vector<Node*> control_inputs;
if (params_info.IsVariableArg(i)) {
int64_t var_num = i - params_info.NumPerReplicaArgs() -
params_info.NumDistributedArgs() -
params_info.NumBroadcastArgs();
TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
Node* read = variable_reads[var_num];
DCHECK_EQ(read->type_string(), "ReadVariableOp");
for (const Edge* edge : read->in_edges()) {
if (edge->IsControlEdge()) {
control_inputs.push_back(edge->src());
}
}
const Edge* variable_input = nullptr;
TF_RETURN_IF_ERROR(read->input_edge(/*idx=*/0, &variable_input));
src = variable_input->src();
src_output = variable_input->src_output();
def.set_name(
graph->NewName(strings::StrCat(src->name(), "/variable_shape")));
def.set_op("VariableShape");
} else {
if (params_info.IsPerReplicaArg(i)) {
TF_RET_CHECK(i < replicate_input_edges.size());
// All replicas must have the same input shapes. Uses the shape of the
// inputs from the first replica.
src = replicate_input_edges[i]->src();
src_output = replicate_input_edges[i]->src_output();
} else {
DCHECK(params_info.IsDistributedArg(i) ||
params_info.IsBroadcastArg(i));
int64_t input_num =
params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
params_info.NumPerReplicaArgs();
TF_RET_CHECK(0 <= input_num &&
input_num < replicate_input_edges.size());
src = replicate_input_edges[input_num]->src();
src_output = replicate_input_edges[input_num]->src_output();
}
def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
def.set_op("Shape");
AddNodeAttr("T", src->output_type(src_output), &def);
}
def.set_device(src->assigned_device_name());
AddNodeAttr("out_type", DT_INT64, &def);
MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
Status status;
Node* shape_node = graph->AddNode(def, &status);
if (!status.ok()) return status;
dynamic_shape_nodes->push_back(shape_node);
shape_node->set_assigned_device_name(src->assigned_device_name());
graph->AddEdge(src, src_output, shape_node, 0);
for (Node* control_input : control_inputs) {
graph->AddControlEdge(control_input, shape_node);
}
}
}
return Status::OK();
}
namespace {
bool XlaBroadcastTypeSupported(const DataType dtype) {
return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
dtype == DT_BOOL);
}
bool XlaBroadcastKindSupported(
const DistributedTPURewritePass::ParameterInfo& params_info,
int param_num) {
// NOTE: This is intended to cover non-sharded data parallel variables, for
// training only. . Is it correct to just check if the arg_type is
// DT_RESOURCE?
return params_info.IsVariableArg(param_num) &&
!(params_info.IsPerReplicaArg(param_num) ||
params_info.IsDistributedArg(param_num) ||
params_info.IsBroadcastArg(param_num) ||
params_info.IsConstantArg(param_num));
}
bool EnableXlaParamBroadcast(
bool enable_xla_param_broadcast,
const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
DataType dtype, int num_cores_per_replica) {
// Conditions necessary to use XLA collectives for arg broadcast:
// 1. Globally enabled via enable_xla_param_broadcast.
// 2. DataType must be supported.
// 3. Parameter must be a variable, and not distributed or broadcasted.
// 4. Model parallelism is not currently supported.
return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
XlaBroadcastKindSupported(params_info, param_num) &&
(num_cores_per_replica == 1);
}
} // namespace
// Builds a TPUCompile node that compiles the bodies of the function call
// `nodes`.
Status DistributedTPURewritePass::BuildCompileNode(
const Node* replicate_node, const NameAttrList& function,
uint64 library_fingerprint, const ParameterInfo& params_info,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& arg_types,
const std::vector<Node*>& guaranteed_constant_nodes,
const string& session_handle,
const std::vector<xla::OpSharding>& arg_sharding,
const std::vector<bool>& arg_fast_mem,
const std::vector<std::string>& arg_names,
const std::vector<xla::OpSharding>& retval_sharding,
int num_cores_per_replica, const string& compile_device,
const xla::DeviceAssignment* xla_device_assignment,
const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
Node** compile_node, int64_t autotuner_thresh, int num_tasks) {
VLOG(1) << "BuildCompileNode";
tpu::TPUCompileMetadataProto proto;
proto.set_num_replicas(params_info.NumReplicas());
proto.set_num_cores_per_replica(num_cores_per_replica);
proto.set_function_library_fingerprint(library_fingerprint);
proto.set_enable_automatic_model_parallelism(
enable_cross_replica_sharding_mirrored_variables_);
const bool use_spmd =
UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_ &&
!absl::c_any_of(arg_sharding,
[](const xla::OpSharding& s) {
return s.type() == xla::OpSharding::MAXIMAL;
}) &&
!absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) {
return s.type() == xla::OpSharding::MAXIMAL;
});
proto.set_use_spmd_for_xla_partitioning(use_spmd);
// Get and fill padding map.
if (replicate_node != nullptr) {
xla::DebugOptions::StepMarkerLocation location;
TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
proto.set_step_marker_location(location);
}
if (xla_device_assignment != nullptr) {
TF_RETURN_IF_ERROR(
xla_device_assignment->Serialize(proto.mutable_device_assignment()));
}
const int num_args = arg_types.size();
const int num_guaranteed_constants = guaranteed_constant_nodes.size();
const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
TF_RET_CHECK(num_args == arg_shapes.size());
TF_RET_CHECK(num_args == arg_sharding.size())
<< num_args << " != " << arg_sharding.size();
for (int i = 0; i < num_args; ++i) {
tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
DataType type = arg_types[i];
const InferredShape& arg_shape = arg_shapes[i];
arg->set_name(arg_names[i]);
if (type == DT_RESOURCE) {
TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
arg->set_dtype(arg_shape.handle_type);
arg_shape.handle_shape.AsProto(arg->mutable_shape());
arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
arg->set_fast_mem(arg_fast_mem[i]);
} else {
arg->set_dtype(type);
arg_shape.shape.AsProto(arg->mutable_shape());
if (i >= guaranteed_const_start_index) {
const DataType edge_type =
guaranteed_constant_nodes[i - guaranteed_const_start_index]
->output_type(0);
TF_RET_CHECK(type == edge_type)
<< "Arg type: " << type << " but edge type: " << edge_type;
arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
} else {
arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
}
}
// Use XLA collective primitives to distribute variables to all replicas,
// for multi-host systems.
arg->set_requires_xla_broadcast(
num_tasks > 1 &&
EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
arg_shape.handle_type /*arg.dtype?*/,
num_cores_per_replica));
// As long as the argument is not a per-replica one, it should have the same
// value for all replicas. For clarity, we keep the (redundant) checks for
// variable, broadcast and constant types, to prevent bugs in case new types
// with different semantics are introduced in the future.
arg->set_is_same_data_across_replicas(
!params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
(params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
params_info.IsConstantArg(i)));
if (params_info.mirrored_variable_indices().count(i) > 0) {
CHECK_EQ(type, DT_RESOURCE);
arg->set_is_same_data_across_replicas(true);
// 64-bit type is not shardable by XLA:TPU yet.
bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
arg_shape.handle_type != DT_INT64 &&
arg_shape.handle_type != DT_UINT64 &&
arg_shape.handle_type != DT_DOUBLE);
arg->set_enable_xla_sharding(
sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
: tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
}
*arg->mutable_sharding() = arg_sharding[i];
}
const int num_retvals = retval_sharding.size();
for (int i = 0; i < num_retvals; ++i) {
*proto.add_retvals()->mutable_sharding() = retval_sharding[i];
}
proto.set_session_handle(session_handle);
DataTypeVector constant_arg_types;
constant_arg_types.reserve(num_guaranteed_constants);
for (int i = 0; i < num_guaranteed_constants; ++i) {
constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
}
proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
string metadata;
proto.SerializeToString(&metadata);
NodeDef def;
def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
def.set_op("TPUCompile");
def.set_device(compile_device);
if (replicate_node) {
MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
}
AddNodeAttr("function", function, &def);
AddNodeAttr("num_computations", num_cores_per_replica, &def);
AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
&def);
AddNodeAttr("metadata", metadata, &def);
AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
Status status;
*compile_node = graph->AddNode(def, &status);
TF_RETURN_IF_ERROR(status);
(*compile_node)->set_assigned_device_name(compile_device);
for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
}
for (int i = 0; i < num_guaranteed_constants; ++i) {
graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
dynamic_shape_nodes.size() + i);
}
VLOG(1) << "BuildCompileNode(): " << status;
return status;
}
Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
const Node& node, const NameRangeMap& input_range_map,
std::vector<Node*>* guaranteed_constants) {
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
std::pair<int, int> variables_limits =
input_range_map.at("guaranteed_constants");
for (int i = variables_limits.first; i < variables_limits.second; ++i) {
guaranteed_constants->push_back(input_edges[i]->src());
}
return Status::OK();
}
Status DistributedTPURewritePass::FindVariableInputs(
const Node& node, const NameRangeMap& input_range_map,
std::vector<VariableInput>* variables) {
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
std::pair<int, int> variables_limits = input_range_map.at("variables");
for (int i = variables_limits.first; i < variables_limits.second; ++i) {
Node* node = input_edges[i]->src();
// Find the type of the VarHandleOp that feeds this node, looking through
// any wrapping Enter or Switch nodes.
while (node->IsEnter() || node->IsSwitch()) {
TF_RETURN_IF_ERROR(node->input_node(0, &node));
}
// Fix the variable device assignment if it is requested with a full name.
if (!node->has_assigned_device_name() &&
!node->requested_device().empty()) {
DeviceNameUtils::ParsedName var_device;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
&var_device));
if (var_device.has_job && var_device.has_replica && var_device.has_task &&
var_device.has_type && var_device.has_id) {
node->set_assigned_device_name(node->requested_device());
if (node != input_edges[i]->src() &&
!input_edges[i]->src()->has_assigned_device_name()) {
input_edges[i]->src()->set_assigned_device_name(
node->requested_device());
}
}
}
if (node->type_string() == kVarHandleOp) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
variables->push_back(VariableInput{input_edges[i]->src(),
input_edges[i]->src_output(), dtype});
} else if (node->type_string() == "_Arg") {
std::vector<DataType> dtypes;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
if (dtypes.empty()) {
return errors::Internal(
"_Arg node with resource output must have non-empty _handle_dtypes "
"attribute: ",
node->DebugString());
}
variables->push_back(VariableInput{
input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
} else {
return errors::Internal(
"Cannot handle variable input with node type other than VarHandleOp "
"and _Arg: ",
node->DebugString());
}
}
return Status::OK();
}
// Builds a NoOp node, used for building control dependencies.
static Status BuildNoopNode(const Node& source, StringPiece name,
const string& device, Graph* graph, Node** node) {
NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
if (!device.empty()) {
builder.Device(device);
}
NodeDef def;
TF_RETURN_IF_ERROR(builder.Finalize(&def));
Status status;
*node = graph->AddNode(def, &status);
if (!device.empty()) {
(*node)->set_assigned_device_name(device);
}
return status;
}
Status DistributedTPURewritePass::ConnectHostComputeNodes(
Node* compile_node, Node* key_placeholder_node, Graph* graph) {
// First find all the downstream nodes of the key placeholder node, since we
// want to delete the connecting edges from key_placeholder_node which would
// invalidate the out_nodes iterator.
std::vector<Node*> host_transfer_nodes;
for (Node* node : key_placeholder_node->out_nodes()) {
host_transfer_nodes.push_back(node);
}
for (Node* node : host_transfer_nodes) {
int input_index = -1;
for (int i = 0; i < node->num_inputs(); i++) {
const Edge* e;
TF_RETURN_IF_ERROR(node->input_edge(i, &e));
if (e->src() == key_placeholder_node) {
if (input_index != -1) {
return errors::Internal(
"Node ", node->name(),
" has multiple input edges from key placeholder node");
}
input_index = e->dst_input();
}
}
if (input_index == -1) {
return errors::Internal("Node ", node->name(),
" has no input edge from key placeholder node");
}
const Edge* key_edge;
TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
graph->RemoveEdge(key_edge);
graph->AddEdge(compile_node, 1, node, input_index);
}
graph->RemoveNode(key_placeholder_node);
return Status::OK();
}
Status DistributedTPURewritePass::BuildVariableReads(
absl::Span<const VariableInput> variables, Node* control_predecessor,
Graph* graph, std::vector<Node*>* variable_reads) {
variable_reads->resize(variables.size());
for (int i = 0; i < variables.size(); ++i) {
string name =
graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
NodeDefBuilder builder(name, "ReadVariableOp",
NodeDebugInfo(*variables[i].node));
builder.Attr("dtype", variables[i].dtype);
builder.Device(variables[i].node->assigned_device_name());
builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
NodeDef def;
TF_RETURN_IF_ERROR(builder.Finalize(&def));
Status status;
Node* read_node;
(*variable_reads)[i] = read_node = graph->AddNode(def, &status);
if (!status.ok()) return status;
read_node->set_requested_device(variables[i].node->requested_device());
read_node->set_assigned_device_name(
variables[i].node->assigned_device_name());
graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
graph->AddControlEdge(control_predecessor, read_node);
}
return Status::OK();
}
bool DistributedTPURewritePass::ContainsResourceWriteOp(
const Graph& graph, const FunctionLibraryDefinition& fld) {
for (const Node* n : graph.nodes()) {
const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
VLOG(2) << "Found write resource op inside computation";
return true;
}
}
for (const string& func_name : fld.ListFunctionNames()) {
const FunctionDef* func_def = fld.Find(func_name);
for (const NodeDef& n : func_def->node_def()) {
const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
VLOG(2) << "Found write resource op inside " << func_name;
return true;
}
}
}
return false;
}
Status DistributedTPURewritePass::BuildVariableWrites(
absl::Span<const VariableInput> variables, Node* control_successor,
absl::Span<const VariableWrite> variable_writes, Graph* graph) {
CHECK_EQ(variables.size(), variable_writes.size());
for (int i = 0; i < variables.size(); ++i) {
const VariableWrite& write = variable_writes[i];
NodeDebugInfo debug_info(*variables[i].node);
auto name = [&](string suffix) {
return graph->NewName(
strings::StrCat(variables[i].node->name(), "/", suffix));
};
Node* write_node;
TF_RETURN_IF_ERROR(
IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
.AddAttr("dtype", variables[i].dtype)
.Device(variables[i].node->assigned_device_name())
.Build(graph, &write_node));
// Colocate the control flow with the variable.
CondBuilder cb(variables[i].node->name(),
variables[i].node->assigned_device_name(), debug_info,
graph);
// Inputs to conditional.
Node* switch_val;
TF_RETURN_IF_ERROR(
cb.AddInput("switch_val", variables[i].dtype,
/*device=*/write.value->assigned_device_name(), debug_info,
&switch_val));
Node* switch_var;
TF_RETURN_IF_ERROR(
cb.AddInput("switch_var", DT_RESOURCE,
/*device=*/variables[i].node->assigned_device_name(),
debug_info, &switch_var));
// Conditionally write the value back.
graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
// Add control edge from the write to value that will be merged. There is no
// output from the write so this control edge ensures the write completes.
graph->AddControlEdge(write_node, cb.switch_t());
graph->AddControlEdge(cb.control_successor(), control_successor);
graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
graph->AddEdge(write.value, write.value_output, switch_val, 0);
}
return Status::OK();
}
namespace {
// Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
xla::StatusOr<Node*> MaybeCreatePerHostDummyArgs(
const std::vector<InferredShape>& arg_shapes, const string& host_cpu_device,
const DistributedTPURewritePass::ParameterInfo& params_info, Node* var_read,
int var_num, int num_cores_per_replica, Graph* graph) {
Status status;
if (num_cores_per_replica > 1) {
LOG_FIRST_N(WARNING, 1) << "XLA parameter broadcast is not supported for "
"model-partitioned parameters. Falling back to "
"non-broadcast mode for all parameters.";
return var_read;
}
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(var_read->def(), "dtype", &dtype));
DeviceNameUtils::ParsedName parsed_device;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(host_cpu_device, &parsed_device));
TF_RET_CHECK(parsed_device.has_task);
// Task 0 behaves as the primary task, where variables are assigned. Use the
// variable reads as arguments to TPUExecute.
// For other tasks, create dummies if the graph meets preconditions.
int64_t orig_arg_num = var_num + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs();
if (parsed_device.task == 0 ||
!EnableXlaParamBroadcast(/*enable_xla_param_broadcast=*/true, params_info,
orig_arg_num, dtype, num_cores_per_replica)) {
return var_read;
}
auto raw_var_shape = arg_shapes[orig_arg_num];
TensorShape var_shape;
if (!raw_var_shape.handle_shape.AsTensorShape(&var_shape) &&
!raw_var_shape.shape.AsTensorShape(&var_shape)) {
return Status(error::FAILED_PRECONDITION, "Failed to read arg shape.");
}
// Const - shape_as_tensor
const std::string name_prefix = strings::StrCat(
var_read->name(), absl::StrFormat("/dummy_%d", parsed_device.task));
NodeDef shape_tensor_def;
shape_tensor_def.set_op("Const");
shape_tensor_def.set_name(graph->NewName(
strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
TensorProto tensorshape_proto;
tensorshape_proto.set_dtype(DT_INT32);
for (int i = 0; i < var_shape.dims(); ++i) {
tensorshape_proto.add_int_val(var_shape.dim_size(i));
}
TensorShape shape_shape({var_shape.dims()});
shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
TF_RETURN_IF_ERROR(status);
// Const - initializer value
NodeDef init_val_def;
init_val_def.set_op("Const");
init_val_def.set_name(graph->NewName(
strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
TensorProto tensor_proto;
tensor_proto.set_dtype(dtype);
if (dtype == DT_FLOAT) {
tensor_proto.add_float_val(0.0f);
} else if (dtype == DT_BFLOAT16) {
tensor_proto.add_half_val(0);
} else if (dtype == DT_INT32) {
tensor_proto.add_int_val(0);
} else if (dtype == DT_BOOL) {
tensor_proto.add_bool_val(false);
} else {
return errors::Internal(
"Unable to create zero-init dummy arg tensor for type ", dtype);
}
TensorShape scalar_shape({});
scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
AddNodeAttr("value", tensor_proto, &init_val_def);
AddNodeAttr("dtype", dtype, &init_val_def);
Node* init_val_node = graph->AddNode(init_val_def, &status);
TF_RETURN_IF_ERROR(status);
// Fill node
NodeDef fill_def;
fill_def.set_op("Fill");
fill_def.set_device(host_cpu_device);
fill_def.set_name(
graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
AddNodeAttr("T", dtype, &fill_def);
AddNodeAttr("index_type", DT_INT32, &fill_def);
Node* fill_node = graph->AddNode(fill_def, &status);
TF_RETURN_IF_ERROR(status);
graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
graph->AddEdge(init_val_node, 0, fill_node, 1);
return fill_node;
}
// Helper that creates an IdentityN node containing all of the variables
// values on CPU device 'device', except for those that will be split across
// cores. (For split variables, this may cause additional cross-host data
// transfers if more than 1 devices share the same variable partition on a
// remote host.)
//
// A previous iteration of this code built one Identity node per TPU core per
// variable, but this can rapidly become hundreds of thousands of nodes. This
// formulation creates a single IdentityN node containing all of the variables
// on each host. This may cause some unnecessary variable copies if only a
// subset of hosts consume a given variable, but has the virtue of being
// simple, and most models use pure replication where all cores want all the
// variables.
//
// If enable_xla_param_broadcast is set to true, then per-host dummy
// tensor args are created on all hosts except for the primary host. In this
// scheme, the dummy args feed the IdentityN node on their local host. All
// are zero-initialized.
//
// Returns the node and its output index to be consumed by TPUExecute for the
// requested variable index.
xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
const string& host_cpu_device, int64_t var_index,
const std::vector<Node*>& variable_reads,
const DistributedTPURewritePass::ParameterInfo& params_info,
const std::vector<xla::OpSharding>& arg_shardings,
const Node& replicate_node, const bool enable_xla_param_broadcast,
const int num_cores_per_replica,
const std::vector<InferredShape>& arg_shapes,
absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
Graph* graph) {
auto it = per_host_var_copies->find(host_cpu_device);
if (it != per_host_var_copies->end()) {
return it->second[var_index];
}
DataTypeVector dtypes;
// Per-variable data source for TPUExecute.
std::vector<NodeOut> index_mapping;
index_mapping.reserve(variable_reads.size());
dtypes.reserve(variable_reads.size());
for (int64_t i = 0; i < variable_reads.size(); ++i) {
Node* read = variable_reads[i];
int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() +
params_info.NumBroadcastArgs();
if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
// We haven't built the IdentityN node yet, so temporarily use nullptr.
index_mapping.push_back(
NodeOut{nullptr, static_cast<int>(dtypes.size())});
dtypes.push_back(read->output_type(0));
} else {
// Do not copy the full tensor of partitioned variables.
index_mapping.push_back(NodeOut{read, 0});
}
}
NodeDef ndef;
ndef.set_name(graph->NewName(
absl::StrCat(replicate_node.name(), "/", kTpuExecuteStagingNodeName)));
ndef.set_op(kTpuExecuteStagingOp);
ndef.set_device(host_cpu_device);
AddNodeAttr("T", dtypes, &ndef);
// TF meta-optimizer should skip this node for constant folding.
AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
Status s;
Node* id_node = graph->AddNode(ndef, &s);
TF_RETURN_IF_ERROR(s);
id_node->set_assigned_device_name(host_cpu_device);
for (int64_t i = 0; i < variable_reads.size(); ++i) {
if (index_mapping[i].node == nullptr) {
// Fill index_mapping with the actual IdentityN node.
index_mapping[i].node = id_node;
if (!enable_xla_param_broadcast) {
// Add the variable read edge to id_node.
graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
} else {
// XLA param broadcast mode is enabled. Create zero-valued dummy
// tensors to use as variable args in the TPUExecuteOp, instead of
// original variable reads.
TF_ASSIGN_OR_RETURN(
Node * var_read,
MaybeCreatePerHostDummyArgs(arg_shapes, host_cpu_device,
params_info, variable_reads[i], i,
num_cores_per_replica, graph));
graph->AddEdge(var_read, 0, id_node, index_mapping[i].index);
}
}
}
auto result = index_mapping[var_index];
(*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
return result;
}
} // namespace
Status DistributedTPURewritePass::BuildExecuteNodes(
const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
const Node& replicate_node, const std::vector<std::string>& arg_names,
const DataTypeVector& arg_types,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& retval_types,
const std::vector<xla::OpSharding>& arg_shardings,
const std::vector<xla::OpSharding>& retval_shardings,
const std::vector<std::vector<string>>& tpu_device_names,
Node* compile_node, const std::vector<Node*>& variable_reads,
Node* control_predecessor, Node* control_successor, Node* multilock_acquire,
std::vector<VariableWrite>* variable_writes, Graph* graph) {
VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
const int num_variables = variable_reads.size();
const int num_retvals_per_replica = retval_types.size();
variable_writes->resize(num_variables);
std::vector<const Edge*> replicate_input_edges;
TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
// Map from replicate input index to the fan_in node;
absl::flat_hash_map<int, std::vector<NodeAndPort>>
replicate_input_fan_in_nodes;
absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
absl::flat_hash_map<int, std::vector<int>>
replicate_output_fan_out_dst_inputs;
std::vector<Node*> to_be_removed_nodes;
for (const Edge* e : replicate_input_edges) {
if (e->src()->type_string() == kTPUPartitionedInput) {
int num_users = 0;
for (const auto& ue : e->src()->out_edges()) {
if (!ue->IsControlEdge()) ++num_users;
}
if (num_users != 1) {
return tensorflow::errors::InvalidArgument(
e->src()->name(), " must only have one user. Found ", num_users);
}
to_be_removed_nodes.push_back(e->src());
std::vector<NodeAndPort>& nodes =
replicate_input_fan_in_nodes[e->dst_input()];
nodes.resize(num_cores_per_replica, NodeAndPort(nullptr, 0));
VLOG(2) << "allocate " << num_cores_per_replica
<< " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
std::vector<const Edge*> fan_in_edges;
TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
for (const Edge* fe : fan_in_edges) {
nodes[fe->dst_input()].node = fe->src();
nodes[fe->dst_input()].port = fe->src_output();
VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
<< fe->dst_input() << "] = " << fe->src()->name();
}
}
}
// Replicate output edges are sorted by replica id and then by outputs for
// each replica. For example, if TPU Computation has outputs (output_1,
// output_2, and output_3) and number of replicas is 2, then
// replicate_output_edges order would be:
// output_1_replica_1, output_2_replica_1, output_3_replica_1,
// output_1_replica_2, output_2_replica_2, output_3_replica_2.
std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
nullptr);
for (const Edge* edge : replicate_node.out_edges()) {
if (edge->IsControlEdge()) continue;
int num_partitioned_outputs = 0;
for (const Edge* out_edge : edge->dst()->out_edges()) {
if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
num_partitioned_outputs++;
// Paths between replicate_node and replicate_output_fan_out_nodes:
// ReplicateNode->TpuOutIdenity->kTPUPartitionedOutput->fan-out-nodes
TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
to_be_removed_nodes.push_back(edge->dst());
to_be_removed_nodes.push_back(out_edge->dst());
// Get the right replicated id from the replicate_output_edge.
std::vector<Node*>& nodes =
replicate_output_fan_out_nodes[edge->src_output()];
std::vector<int>& dst_inputs =
replicate_output_fan_out_dst_inputs[edge->src_output()];
nodes.resize(num_cores_per_replica, nullptr);
dst_inputs.resize(num_cores_per_replica, 0);
TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
num_cores_per_replica);
for (const Edge* fe : out_edge->dst()->out_edges()) {
nodes[fe->src_output()] = fe->dst();
dst_inputs[fe->src_output()] = fe->dst_input();
VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
<< "][" << fe->src_output()
<< "] = " << fe->dst()->DebugString() << " with dst_input "
<< fe->dst_input();
}
}
}
replicate_output_edges[edge->src_output()] = edge;
if (num_partitioned_outputs > 1) {
return errors::InvalidArgument(
"More than one TPUPartitionedOutput per replicated output.");
}
}
const int num_execute_args =
arg_shardings.size() - params_info.NumGuaranteedConstants();
// Inverts the arg_shardings and retval_shardings mappings to
// form core -> {argument number} maps.
std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
for (int i = 0; i < num_execute_args; ++i) {
const auto& sharding = arg_shardings[i];
if (sharding.type() == xla::OpSharding::MAXIMAL) {
int core = sharding.tile_assignment_devices(0);
TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
core_arg_nums[core].push_back(i);
} else if (sharding.type() == xla::OpSharding::OTHER) {
for (int64_t core : sharding.tile_assignment_devices()) {
core_arg_nums[core].push_back(i);
}
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
for (int core = 0; core < num_cores_per_replica; ++core) {
core_arg_nums[core].push_back(i);
}
} else {
return tensorflow::errors::InvalidArgument(
"Unsupported argument sharding for arg=", arg_names[i],
" shape=", arg_shapes[i].shape.DebugString(), ": ",
sharding.DebugString());
}
}
std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
for (int i = 0; i < retval_shardings.size(); ++i) {
const auto& sharding = retval_shardings[i];
if (sharding.type() == xla::OpSharding::MAXIMAL) {
int core = sharding.tile_assignment_devices(0);
TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
core_retval_nums[core].push_back(i);
} else if (sharding.type() == xla::OpSharding::REPLICATED) {
for (int core = 0; core < num_cores_per_replica; ++core) {
core_retval_nums[core].push_back(i);
}
} else if (sharding.type() == xla::OpSharding::OTHER) {
for (int64_t core : sharding.tile_assignment_devices()) {
core_retval_nums[core].push_back(i);
}
} else {
return tensorflow::errors::InvalidArgument(
"Unsupported argument sharding: ", sharding.DebugString());
}
}
// Maps host device name to a list of per-variable pairs (variable_copy_node,
// output_index_of_copy_node).
absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
Node* execute_successor = control_successor;
int num_total_cores = params_info.NumReplicas() * num_cores_per_replica;
if (enable_multicore_locking_ && num_total_cores > 1) {
// Add a node to release exclusive access once all the cores have finished
// execution.
NodeDef lock_def;
lock_def.set_name(graph->NewName(
strings::StrCat(compile_node->name(), "/", "tpu_release_multilock")));
lock_def.set_op("ConsumeTpuMultilock");
MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &lock_def);
Status status;
Node* multilock_release = graph->AddNode(lock_def, &status);
TF_RETURN_IF_ERROR(status);
multilock_release->set_assigned_device_name(
compile_node->assigned_device_name());
TF_RET_CHECK(multilock_acquire != nullptr);
graph->AddEdge(multilock_acquire, 0, multilock_release, 0);
graph->AddControlEdge(multilock_release, control_successor);
// Make sure all execute Ops happen before the multilock_release.
execute_successor = multilock_release;
}
// Mapping from original resource arg number to a second level map. Second
// level map is from core id to output index of updated variable value.
absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
orig_arg_num_to_output_index_mapping;
// Mapping from retval index to a second level map. Second level map is from
// core id to output index of sharded output value.
std::unordered_map<int, std::unordered_map<int, int>>
retval_index_to_output_index_mapping;
// Represents mapping of argument index of sharded input to each
// TPUExecute node to its corresponding Split node and its output index
// from which sharded input will be fed into TPUExecute node.
std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
// Builds one TPUExecute node per core per replica.
std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
for (int core = 0; core < num_cores_per_replica; ++core) {
DataTypeVector core_retval_types;
for (int output : core_retval_nums[core]) {
core_retval_types.push_back(retval_types[output]);
}
DataTypeVector core_arg_types;
std::vector<int> core_variable_writes;
for (int input : core_arg_nums[core]) {
// Resource variables can be passed either by reference (as a DT_RESOURCE)
// tensor or by value (as the variable's current value). Per-replica or
// distributed resource arguments are always passed by reference and
// broadcast variables are always passed by value.
if (arg_types[input] == DT_RESOURCE &&
!params_info.IsPerReplicaArg(input) &&
!params_info.IsDistributedArg(input)) {
DataType handle_type = arg_shapes[input].handle_type;
TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
core_arg_types.push_back(handle_type);
int base = input - params_info.NumPerReplicaArgs() -
params_info.NumDistributedArgs() -
params_info.NumBroadcastArgs();
// Variables passed by value will have a corresponding additional output
// containing an updated value for the variable.
core_variable_writes.push_back(base);
core_retval_types.push_back(handle_type);
} else {
core_arg_types.push_back(arg_types[input]);
}
}
NodeDef def;
def.set_op("TPUExecute");
MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
AddNodeAttr("Targs", core_arg_types, &def);
AddNodeAttr("Tresults", core_retval_types, &def);
for (int64_t replica = 0; replica < params_info.NumReplicas(); ++replica) {
def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
"_", core));
Status status;
Node* node = graph->AddNode(def, &status);
if (!status.ok()) return status;
execute_nodes[replica].push_back(node);
node->set_assigned_device_name(tpu_device_names[replica][core]);
// Add control edges to ensure that execution happens after
// `control_predecessor`, happens before `execute_successor`, and is
// triggered by evaluating any operator that depends on the original
// TPUReplicate operator. See the comment at the top of the header file
// for more details.
graph->AddControlEdge(control_predecessor, node);
graph->AddControlEdge(node, execute_successor);
// Add data input edges.
for (int64_t i = 0; i < core_arg_nums[core].size(); ++i) {
int64_t orig_arg_num = core_arg_nums[core][i];
VLOG(2) << " replica " << replica << " core " << core << " i " << i
<< " orig_arg_num " << orig_arg_num;
const bool is_per_replica_arg =
params_info.IsPerReplicaArg(orig_arg_num);
if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
// Per-replica input and distributed input
const int64_t input_num =
is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
core_arg_nums[core][i]
: params_info.NumReplicas() *
params_info.NumPerReplicaArgs() +
core_arg_nums[core][i] -
params_info.NumPerReplicaArgs();
const Edge* edge = replicate_input_edges[input_num];
VLOG(2) << "replicate_input_edges[" << input_num << "]";
DataType dtype = edge->src()->output_type(edge->src_output());
if (dtype == DT_RESOURCE) {
DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
handle_dtype) == kTpuAllTypes.end()) {
return errors::InvalidArgument(
"Unsupported resource variable data type for TPU: ",
DataTypeString(handle_dtype), ", caused by output ",
edge->src()->name(), ":", edge->src_output());
}
} else {
if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
kTpuAllTypes.end()) {
return errors::InvalidArgument(
"Unsupported data type for TPU: ", DataTypeString(dtype),
", caused by output ", edge->src()->name(), ":",
edge->src_output());
}
}
if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
// Don't automatically add a split node when input node is
// kTPUPartitionedInput
if (edge->src()->type_string() == kTPUPartitionedInput) {
VLOG(2)
<< "Connect "
<< replicate_input_fan_in_nodes[input_num][core].node->name()
<< " to " << node->name() << " at " << i;
graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
replicate_input_fan_in_nodes[input_num][core].port,
node, i);
} else {
if (dtype == DT_RESOURCE) {
return errors::InvalidArgument(
"Tiled sharding for per-replica DT_RESOURCE input must",
"be TPUPartitionedInput. Here got ",
edge->src()->type_string());
}
const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
ShardedInputInfo sharded_input_info;
if (use_nd_sharding_ops_ && is_per_replica_arg) {
TF_ASSIGN_OR_RETURN(
sharded_input_info,
CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
sharding, replica, orig_arg_num, dtype,
PartialTensorShape(), edge->src(), edge->src_output(),
graph, &input_index_to_sharded_inputs));
} else if (use_nd_sharding_ops_) {
TF_ASSIGN_OR_RETURN(
sharded_input_info,
CreateOrGetXlaSplitNodeForDistributedArg(
sharding, params_info.NumReplicas(), replica,
orig_arg_num, dtype, PartialTensorShape(), edge->src(),
edge->src_output(), graph,
&input_index_to_sharded_inputs));
} else {
TF_ASSIGN_OR_RETURN(
sharded_input_info,
CreateOrGetSplitNodesForInputSharding(
sharding, orig_arg_num, dtype, PartialTensorShape(),
replica, edge->src_output(), edge->src(),
control_predecessor, graph,
&input_index_to_sharded_inputs));
}
NodeOut split_node_and_index =
sharded_input_info.sharded_inputs.at(core);
// Connect with Split node output.
graph->AddEdge(split_node_and_index.node,
split_node_and_index.index, node, i);
}
} else if (edge->src()->type_string() == kTPUPartitionedInput &&
arg_shardings[orig_arg_num].type() ==
xla::OpSharding::REPLICATED) {
graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
replicate_input_fan_in_nodes[input_num][core].port,
node, i);
} else {
graph->AddEdge(edge->src(), edge->src_output(), node, i);
}
} else if (params_info.IsBroadcastArg(orig_arg_num)) {
// Broadcast input.
int64_t input_num = params_info.FirstBroadcastArgFromHost() +
core_arg_nums[core][i] -
params_info.NumPerReplicaArgs() -
params_info.NumDistributedArgs();
const Edge* edge = replicate_input_edges[input_num];
DataType dtype = edge->src()->output_type(edge->src_output());
if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
kTpuAllTypes.end()) {
return errors::InvalidArgument(
"Unsupported data type for TPU: ", DataTypeString(dtype),
", caused by output ", edge->src()->name(), ":",
edge->src_output());
}
graph->AddEdge(edge->src(), edge->src_output(), node, i);
} else {
// Variable input.
int64_t variable_num =
orig_arg_num - params_info.NumPerReplicaArgs() -
params_info.NumDistributedArgs() - params_info.NumBroadcastArgs();
TF_RET_CHECK(variable_num < num_variables);
Node* variable_read = variable_reads[variable_num];
DataType dtype = variable_read->output_type(0);
if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
kTpuAllTypes.end()) {
return errors::InvalidArgument(
"Unsupported resource variable data type for TPU: ",
DataTypeString(dtype), ", caused by ReadVariableOp ",
variable_read->DebugString());
}
DeviceNameUtils::ParsedName requested_device;
string requested = variable_read->requested_device();
TF_RET_CHECK(
DeviceNameUtils::ParseFullName(requested, &requested_device));
if (requested_device.type != "TPU") {
// Stage the value via the CPU device on the remote host. The graph
// partitioner will introduce an intermediate copy rather than
// copying the same tensor multiple times across the network, and we
// would prefer that intermediate copy to be in host memory to avoid
// running out of memory if the TPUExecute op on the staging device
// starts running before the _Send ops to the other TPU devices on
// the same host complete. We don't do this if the variables are
// already placed on TPU, otherwise it will cause an unnecessary
// round trip copy.
// TODO(b/79580121): give each replica its own on-device variable
// replica and then delete this code.
string device;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
tpu_device_names[replica][core], &device));
TF_ASSIGN_OR_RETURN(
auto var_data,
CreateOrGetPerHostVariableCopy(
device, variable_num, variable_reads, params_info,
arg_shardings, replicate_node, enable_xla_param_broadcast_,
num_cores_per_replica, arg_shapes, &per_host_var_copies,
graph));
if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
ShardedInputInfo sharded_input_info;
if (use_nd_sharding_ops_) {
TF_ASSIGN_OR_RETURN(
sharded_input_info,
CreateOrGetXlaSplitNodeForVariableArg(
arg_shardings[orig_arg_num], params_info.NumReplicas(),
replica, orig_arg_num,
arg_shapes[orig_arg_num].handle_type,
arg_shapes[orig_arg_num].handle_shape, var_data.node,
var_data.index, graph, &to_be_removed_nodes,
&input_index_to_sharded_inputs));
} else {
TF_ASSIGN_OR_RETURN(
sharded_input_info,
CreateOrGetSplitNodesForInputSharding(
arg_shardings[orig_arg_num], orig_arg_num,
arg_shapes[orig_arg_num].handle_type,
arg_shapes[orig_arg_num].handle_shape, replica,
var_data.index, var_data.node, control_predecessor,
graph, &input_index_to_sharded_inputs));
}
NodeOut split_node_and_index =
sharded_input_info.sharded_inputs[core];
// Connect with Split node output.
graph->AddEdge(split_node_and_index.node,
split_node_and_index.index, node, i);
} else {
graph->AddEdge(var_data.node, var_data.index, node, i);
}
} else {
graph->AddEdge(variable_reads[variable_num], 0, node, i);
}
}
}
// Adds a program input edge from the compiler.
graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
// Add data output edges.
int num_outputs = core_retval_nums[core].size();
for (int i = 0; i < num_outputs; ++i) {
int output_num =
replica * num_retvals_per_replica + core_retval_nums[core][i];
const auto& sharding = retval_shardings[core_retval_nums[core][i]];
if (sharding.type() == xla::OpSharding::OTHER) {
int retval_index = core_retval_nums[core][i];
retval_index_to_output_index_mapping[retval_index][core] = i;
bool is_last_core =
core ==
*std::max_element(sharding.tile_assignment_devices().begin(),
sharding.tile_assignment_devices().end());
bool isPartitionOutNode = false;
const Edge* e = replicate_output_edges[output_num];
const Edge* e_out;
for (const Edge* out_edge : e->dst()->out_edges()) {
if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
isPartitionOutNode = true;
e_out = out_edge;
}
}
if (isPartitionOutNode) {
graph->AddEdge(
node, i, replicate_output_fan_out_nodes[output_num][core],
replicate_output_fan_out_dst_inputs[output_num][core]);
VLOG(2) << "Connect " << node->name() << " at " << i << " to "
<< replicate_output_fan_out_nodes[output_num][core]->name()
<< " at "
<< replicate_output_fan_out_dst_inputs[output_num][core];
if (is_last_core) {
graph->RemoveEdge(e);
graph->RemoveEdge(e_out);
}
continue;
}
// Do this in the iteration of last core in tile assignment, so all
// TPUExecute nodes have been created.
if (!is_last_core) {
continue;
}
// Add a Concat node.
std::vector<NodeOut> orig_inputs;
for (int64_t tile_index = 0;
tile_index < sharding.tile_assignment_devices_size();
++tile_index) {
int64_t last_tile_dim_size =
*sharding.tile_assignment_dimensions().rbegin();
if (sharding.replicate_on_last_tile_dim() &&
tile_index % last_tile_dim_size != 0) {
continue;
}
int64_t core_id = sharding.tile_assignment_devices(tile_index);
int core_retval_index =
retval_index_to_output_index_mapping[retval_index][core_id];
orig_inputs.push_back(
NodeOut{execute_nodes[replica][core_id],
static_cast<int>(
core_retval_nums[core_id][core_retval_index])});
}
DataType dtype = e->src()->output_type(e->src_output());
Node* concat_node = nullptr;
if (use_nd_sharding_ops_) {
TF_ASSIGN_OR_RETURN(
concat_node, CreateXlaConcatNode(
sharding, replica, dtype,
/*partial_tensor_shape=*/PartialTensorShape(),
orig_inputs, /*device=*/"", graph));
} else {
TF_ASSIGN_OR_RETURN(
concat_node,
CreateConcatNodesForRetval(
sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
replica, orig_inputs, graph, /*device=*/""));
}
const Edge* edge = replicate_output_edges[output_num];
Node* dst = edge->dst();
int dst_input = edge->dst_input();
graph->RemoveEdge(edge);
graph->AddEdge(concat_node, 0, dst, dst_input);
continue;
}
// If this is a replicated output, outputs on all cores will be the
// same, and we only take the output from core 0.
if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
continue;
}
// If output has maximal sharding, make sure we only use output from
// TPUExecute node with logical core id equal to core id defined by the
// xla sharding.
if (sharding.type() == xla::OpSharding::MAXIMAL &&
core != sharding.tile_assignment_devices(0)) {
continue;
}
const Edge* replicate_edge_to_replace =
replicate_output_edges[output_num];
Node* dst = replicate_edge_to_replace->dst();
int dst_input = replicate_edge_to_replace->dst_input();
graph->RemoveEdge(replicate_edge_to_replace);
graph->AddEdge(node, i, dst, dst_input);
}
// Feed the updated variable values from the first replica to the
// variable write nodes.
if (replica == 0) {
for (int i = 0; i < core_variable_writes.size(); ++i) {
int orig_arg_num =
core_variable_writes[i] + params_info.NumPerReplicaArgs() +
params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
const auto& sharding = arg_shardings[orig_arg_num];
// If this is a tiling sharded variable, concat variable updates from
// all cores.
if (sharding.type() == xla::OpSharding::OTHER) {
orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
// Do this in the iteration of last core in tile assignment, so all
// TPUExecute nodes have been created.
if (core !=
*std::max_element(sharding.tile_assignment_devices().begin(),
sharding.tile_assignment_devices().end())) {
continue;
}
// Add a Concat node.
std::vector<NodeOut> orig_inputs;
for (int64_t tile_index = 0;
tile_index < sharding.tile_assignment_devices_size();
++tile_index) {
int64_t last_tile_dim_size =
*sharding.tile_assignment_dimensions().rbegin();
if (sharding.replicate_on_last_tile_dim() &&
tile_index % last_tile_dim_size != 0) {
continue;
}
int64_t core_id = sharding.tile_assignment_devices(tile_index);
int core_retval_num =
orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
orig_inputs.push_back(
NodeOut{execute_nodes[0][core_id],
static_cast<int>(core_retval_nums[core_id].size() +
core_retval_num)});
}
// Use the variable read's device for the concat. They should both
// be collocated with the variable.
absl::string_view device =
variable_reads[core_variable_writes[i]]->assigned_device_name();
Node* concat_node = nullptr;
if (use_nd_sharding_ops_) {
TF_ASSIGN_OR_RETURN(
concat_node,
CreateXlaConcatNode(sharding, replica,
arg_shapes[orig_arg_num].handle_type,
arg_shapes[orig_arg_num].handle_shape,
orig_inputs, device, graph));
} else {
TF_ASSIGN_OR_RETURN(
concat_node,
CreateConcatNodesForRetval(
sharding, arg_shapes[orig_arg_num].handle_type,
arg_shapes[orig_arg_num].handle_shape, replica,
orig_inputs, graph, device));
}
// Populate VariableWrite.
VariableWrite& write = variable_writes->at(core_variable_writes[i]);
write.value = concat_node;
write.value_output = 0;
write.predicate = compile_node;
write.predicate_output = num_cores_per_replica + core + 1;
continue;
}
// If this is a replicated variable, outputs on all cores will be the
// same, and we only take the output from core 0 for the variable
// update.
if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
continue;
}
VariableWrite& write = variable_writes->at(core_variable_writes[i]);
write.value = node;
write.value_output = num_outputs + i;
write.predicate = compile_node;
write.predicate_output = num_cores_per_replica + core + 1;
}
}
}
}
for (Node* node : to_be_removed_nodes) {
graph->RemoveNode(node);
}
return Status::OK();
} // NOLINT(readability/fn_size)
/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
int replica_index, const std::vector<Node*>& outside_compilation_nodes,
const DeviceNameUtils::ParsedName& tpu_device,
const DeviceNameUtils::ParsedName& partial_device,
NodeToNodeReplicasMap* node_images, Graph* graph) {
for (Node* node : outside_compilation_nodes) {
NodeDef image_def = node->def();
MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
const string suffix = strings::StrCat("/R", replica_index);
// In addition to node name, make the frame name unique to avoid multiple
// LoopCond nodes in one frame.
TF_RETURN_IF_ERROR(
AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
Status status;
Node* image = graph->AddNode(image_def, &status);
image->AddAttr(kXlaReplicaIdAttrName, replica_index);
TF_RETURN_IF_ERROR(status);
if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
TF_RETURN_IF_ERROR(
SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
} else {
const string& original_device_string =
node->assigned_device_name().empty() ? node->requested_device()
: node->assigned_device_name();
DeviceNameUtils::ParsedName device;
TF_RET_CHECK(
DeviceNameUtils::ParseFullName(original_device_string, &device));
// If the requested device can be merged with the replica's host device,
// then do so. For example, if the requested device is "/CPU:0" or
// "/GPU:0" then it will be placed on the CPU/GPU of the host where this
// replica is running. But if the requested device is
// "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
if (DeviceNameUtils::IsSpecification(device, partial_device)) {
TF_RETURN_IF_ERROR(
DeviceNameUtils::MergeDevNames(&device, partial_device));
}
image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
}
std::vector<Node*>& node_image_vector = (*node_images)[node];
node_image_vector.resize(replica_index + 1);
node_image_vector[replica_index] = image;
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
const std::vector<std::vector<string>>& tf_device_assignment,
const HostComputeCoreMap& host_compute_core,
const OutsideCompilationNodeMap& outside_compilation_nodes,
NodeToNodeReplicasMap* node_images, Graph* graph) {
// Iterate over replicas.
for (int i = 0; i < tf_device_assignment.size(); ++i) {
const auto& core_devices = tf_device_assignment[i];
for (const auto& oc_cluster_iter : outside_compilation_nodes) {
const string& oc_cluster_name = oc_cluster_iter.first;
const auto& oc_cluster_nodes = oc_cluster_iter.second;
// We previously validated that host_compute_core contains an entry for
// each cluster.
int core = host_compute_core.at(oc_cluster_name);
TF_RET_CHECK(core >= 0 && core < core_devices.size());
// tpu_device is the device the HostCompute XLA Op for this cluster runs
// on.
DeviceNameUtils::ParsedName tpu_device;
TF_RET_CHECK(
DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
// partial_device contains the replica and task but not the type.
DeviceNameUtils::ParsedName partial_device = tpu_device;
partial_device.has_type = false;
partial_device.has_id = false;
if (tf_device_assignment.size() == 1) {
// With a single replica don't copy any nodes just put the original
// nodes into the image map. We leave the device placement alone, except
// that we have to fill in the correct core for the host send and
// receive nodes.
for (Node* node : oc_cluster_nodes) {
(*node_images)[node] = {node};
node->AddAttr(kXlaReplicaIdAttrName, 0);
if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
TF_RETURN_IF_ERROR(
SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
}
}
} else {
// Iterate over outside_compilation clusters in this computation, adding
// all the nodes with appropriate device assignments.
TF_RETURN_IF_ERROR(
CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
partial_device, node_images, graph));
}
}
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
const std::vector<Node*>& outside_compilation_nodes,
const NodeToNodeReplicasMap& node_images,
const std::unordered_map<string, Node*> outside_compilation_inputs,
Graph* graph) {
for (Node* node : outside_compilation_nodes) {
const auto& images = node_images.at(node);
// Make a copy of all edges and iterate on "in_edges", because we might
// remove edges when iteratating through them.
std::vector<const Edge*> in_edges(node->in_edges().begin(),
node->in_edges().end());
for (const Edge* edge : in_edges) {
Node* src = edge->src();
const auto iter = node_images.find(src);
if (iter == node_images.end()) {
if (images.size() > 1) {
// The source node is a 'normal' node not part of any
// rewrite. Broadcast the value to all replicas. (If images.size() ==
// 1 the cluster is not replicated and we can leave the original edge
// in place.)
for (Node* dst : images) {
graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
}
}
continue;
}
// The source node is a replicated outside_compilation node.
const auto& src_images = iter->second;
if (src_images.size() != images.size()) {
return errors::InvalidArgument(
"Graph contains an edge from node ", src->name(),
" in an outside_compilation block replicated ", src_images.size(),
" ways to node ", node->name(),
" in an outside_compilation block replicated ", images.size(),
" ways. Replication factors must match. Leave a comment on "
"tracking bug b/76419636 if you need this to be supported.");
}
bool is_lifted_arg;
string outside_compilation_cluster;
if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
.ok() &&
GetNodeAttr(src->def(), kOutsideCompilationAttr,
&outside_compilation_cluster)
.ok()) {
const auto input_iter =
outside_compilation_inputs.find(outside_compilation_cluster);
TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
int dst_input = edge->dst_input();
if (src_images.size() == 1) {
graph->RemoveEdge(edge);
}
for (int i = 0; i < src_images.size(); ++i) {
graph->AddEdge(input_iter->second, i, images[i], dst_input);
}
continue;
}
bool is_placeholder_for_arg;
string outside_compilation_input_attr;
if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
&is_placeholder_for_arg)
.ok() &&
GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
&outside_compilation_input_attr)
.ok()) {
const auto input_iter =
outside_compilation_inputs.find(outside_compilation_input_attr);
TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
int dst_input = edge->dst_input();
if (src_images.size() == 1) {
graph->RemoveEdge(edge);
}
for (int i = 0; i < src_images.size(); ++i) {
graph->AddEdge(input_iter->second, i, images[i], dst_input);
}
continue;
}
if (images.size() > 1) {
// If images.size() == 1 neither cluster is replicated and we can
// leave the original edges in place.
for (int i = 0; i < src_images.size(); ++i) {
graph->AddEdge(src_images[i], edge->src_output(), images[i],
edge->dst_input());
}
}
}
for (const Edge* edge : node->out_edges()) {
Node* dst = edge->dst();
const auto iter = node_images.find(dst);
if (iter == node_images.end()) {
// The source node is a 'normal' node not part of any rewrite.
if (edge->IsControlEdge()) {
// Make the dst node have a control dependency on every replica.
if (images.size() > 1) {
for (int i = 0; i < images.size(); ++i) {
graph->AddControlEdge(images[i], dst);
}
}
// else the cluster is not replicated so we can leave the original
// edge in place.
} else {
// The edge
// is only valid if the outside_compilation block is not replicated.
if (images.size() > 1) {
return errors::InvalidArgument(
"Graph contains an edge from node ", node->name(),
" in an outside_compilation block replicated ", images.size(),
" ways to node ", dst->name(),
" that is not part of an outside_compilation block. Edges from "
"outside_compilation to regular graph nodes are only supported "
"for replication factors of 1. Leave a comment on tracking bug "
"b/76419636 if you need this to be supported.");
}
// else the cluster is not replicated so we can leave the original
// edge in place.
}
}
// The case where src and dst are both in node_images is covered elsewhere
// when iterating over in_edges of dst.
}
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
const OutsideCompilationNodeMap& outside_compilation_nodes,
const NodeToNodeReplicasMap& node_images,
const std::unordered_map<string, Node*> outside_compilation_inputs,
Graph* graph) {
for (const auto& oc_cluster_iter : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(
CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
outside_compilation_inputs, graph));
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
const NodeToNodeReplicasMap& node_images, Graph* graph) {
for (const auto& iter : node_images) {
if (iter.second.size() > 1) {
// The cluster was replicated so remove the original node.
Node* node = iter.first;
graph->RemoveNode(node);
}
}
return Status::OK();
}
/* static */ Status
DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
Graph* g, FunctionLibraryDefinition& flib_def,
const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
bool modified = false;
do {
std::vector<Node*> nodes_to_lower;
for (Node* n : g->op_nodes()) {
if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
continue;
}
if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
// Only lower functional ops with DT_RESOURCE input, because otherwise
// placer will complain. For normal cases, lowering will cause slowdown
// when related functions are huge (b/139037679).
bool has_resource_input = false;
for (const Edge* e : n->in_edges()) {
if (!e->IsControlEdge() &&
e->src()->output_type(e->src_output()) == DT_RESOURCE) {
has_resource_input = true;
break;
}
}
if (has_resource_input) {
nodes_to_lower.push_back(n);
}
}
}
modified = !nodes_to_lower.empty();
auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
// Clear device assignment. Otherwise all lowered nodes will have
// device assignment, which is not what we want.
n->set_requested_device("");
int replica_id;
TF_RETURN_IF_ERROR(
GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
string outside_compilation_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
&outside_compilation_attr));
// There are two different kinds of functional outside compilation nodes:
// 1. Nodes that are in outside compilation blocks already. They are
// generated by FunctionalizeControlFlowForXlaPass, and only have
// attribute kOutsideCompilationAttr.
// 2. Mirrored control flow built for outside compilation in functional
// nodes. They are generated by ExtractOutsideCompilationPass, and have
// both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
// When lowering them, they need to be treated differently.
// For 1), their body functions are always V1 functions written by users,
// and their "control outputs" are control inputs of _Retval nodes. They
// should be lowered as V1 functions.
// For 2), we always add necessary "control outputs"
// (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
// FunctionDef's. They should be lowered as V2 functions.
bool is_host_side_mirrored_control_flow =
HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
int num_node_ids = g->num_node_ids();
bool is_call_node = IsFunctionCall(flib_def, *n);
if (n->IsWhileNode()) {
TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, &flib_def,
/*keep_node_fetchable=*/false));
} else if (n->IsIfNode()) {
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
} else {
TF_RET_CHECK(is_call_node);
// See comments for "is_host_side_mirrored_control_flow" above.
// If this is a node that's in outside compilation block, lower it as
// V1 function. This is controlled by removing
// kLowerAsMultiDeviceFunctionAttr from the node.
if (!is_host_side_mirrored_control_flow) {
n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
} else {
n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
true);
}
TF_RETURN_IF_ERROR(
RewriteFunctionCallNode(n, g, flib_def,
/*keep_caller_fetchable=*/false));
}
for (int i = num_node_ids; i < g->num_node_ids(); i++) {
Node* node = g->FindNodeId(i);
if (!node) {
continue;
}
if (!is_call_node && is_host_side_mirrored_control_flow &&
IsFunctionCall(flib_def, *node)) {
// For If/While nodes, if they are host side mirrored control flow,
// mark their body function calls with kXlaHasHostTransferAttrName
// attribute to make sure we lower them as V2 function.
node->AddAttr(kXlaHasHostTransferAttrName, true);
}
if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
node->IsIfNode()) {
// Set kOutsideCompilationAttr attribute so we lower these
// nested function call nodes later.
node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
// Set kXlaReplicaIdAttrName attribute so we know replica id when we
// lower this function call node.
node->AddAttr(kXlaReplicaIdAttrName, replica_id);
} else if (node->type_string() == "_XlaRecvAtHost" ||
node->type_string() == "_XlaSendFromHost") {
// For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
// have kXlaReplicaIdAttrName attribute so later we know which host
// device to assign.
node->AddAttr(kXlaReplicaIdAttrName, replica_id);
}
}
return Status::OK();
};
for (Node* n : nodes_to_lower) {
TF_RETURN_IF_ERROR(lower_functional_node(n));
}
} while (modified);
// Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
for (Node* n : g->op_nodes()) {
if (n->type_string() != "_XlaRecvAtHost" &&
n->type_string() != "_XlaSendFromHost") {
continue;
}
string replicate;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
auto iter = tpu_replicate_device_names_mapping.find(replicate);
TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
const auto& tpu_device_names = iter->second;
int replica_id;
TF_RETURN_IF_ERROR(
GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
TF_RET_CHECK(replica_id < tpu_device_names.size());
const string& tpu_device_name = tpu_device_names[replica_id][0];
string host_device_name;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
tpu_device_name, &host_device_name));
n->set_assigned_device_name(host_device_name);
// We may run TPU rewrite passes again on the subgraphs of the resulting
// graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
// "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
// sure that TPU rewrite passes take no effect on host-side subgraphs for
// outside compilation.
n->ClearAttr(kTPUReplicateAttr);
n->ClearAttr(kOutsideCompilationAttr);
}
// Remove IdentityN nodes generated for outside compilation. IdentityN is
// exempt from resource edge colocation, but here we do need input and output
// for these IdentityN nodes to be colocated.
std::vector<Node*> identityn_nodes;
for (Node* n : g->op_nodes()) {
if (n->type_string() == "IdentityN" &&
HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
identityn_nodes.push_back(n);
}
}
for (Node* n : identityn_nodes) {
std::vector<const Edge*> out_edges(n->out_edges().begin(),
n->out_edges().end());
for (const Edge* e : out_edges) {
if (e->IsControlEdge()) {
continue;
}
int src_output = e->src_output();
const Edge* input_edge;
TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
Node* dst = e->dst();
int dst_input = e->dst_input();
g->RemoveEdge(e);
g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
}
g->RemoveNode(n);
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
const Node& replicate_node,
const OutsideCompilationNodeMap& outside_compilation_nodes,
HostComputeCoreMap* host_compute_core) {
std::vector<string> hc_core_string;
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
&hc_core_string));
TF_RETURN_IF_ERROR(
ParseHostComputeCoreList(hc_core_string, host_compute_core));
for (const auto& iter : outside_compilation_nodes) {
const string& oc_cluster_name = iter.first;
if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
// By default put host compute Ops on replicated core 0.
(*host_compute_core)[oc_cluster_name] = 0;
}
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::GetDeviceTopology(
const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
int* num_cores_per_replica, int* num_tasks,
std::vector<std::vector<string>>* tf_device_assignment,
std::vector<int>* devices_to_lock,
std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
string* tpu_compilation_device) {
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
if (*num_replicas < 1) {
return errors::InvalidArgument("num_replicas must be >= 1, got ",
*num_replicas);
}
// Find the set of TPU devices in the TF job.
// Indexed by [task number][tpu device number].
std::vector<std::vector<Device*>> tpu_devices;
int num_tpus_per_task;
TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
device_set, tpu_compilation_device,
&num_tpus_per_task, &tpu_devices));
*num_tasks = tpu_devices.size();
string topology;
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node.attrs(), "topology", &topology));
TF_RETURN_IF_ERROR(GetNodeAttr(
replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
std::vector<int> device_assignment;
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
&device_assignment));
// TODO(cwhipkey): since we can control multiple pods of different shapes
// from a single worker, it may be desirable to propagate the remote device
// information around (e.g., in DeviceAttributes). This can lead to the mesh
// topology proto being leaked to cloud TPU users (e.g. through GetStatus
// calls); this may be okay, but to be conservative, just assume that the
// master session has the proper flags set.
// We do not initialize platform right now, but we can still retrieve the
// TPU topology even with an uninitialized platform.
auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
/*initialize_platform=*/false);
TF_RET_CHECK(tpu_platform);
tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
TF_RET_CHECK(num_tpus_per_task ==
tpu_topology.LogicalDevicesPerHost(kTensorCore));
TF_RETURN_IF_ERROR(BuildDeviceAssignment(
tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
*num_cores_per_replica, topology, device_assignment, tf_device_assignment,
devices_to_lock, xla_device_assignment));
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::GetIOTypes(
int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
DataTypeVector* retval_types, ParameterInfo* params_info) {
DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
&broadcast_input_types));
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
"Tguaranteed_constants",
&guaranteed_constant_types));
int num_distributed_vars;
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
"num_distributed_variables",
&num_distributed_vars));
const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
if (num_per_replica_inputs % num_replicas != 0) {
return errors::InvalidArgument(
"Number of inputs to TPUReplicate (", num_per_replica_inputs,
") is not divisible by the number of replicas (", num_replicas, ").");
}
int num_variables;
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
NameRangeMap output_name_map;
TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
input_name_map, &output_name_map));
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node.attrs(), "computation", function));
*computation = absl::make_unique<Graph>(graph->op_registry());
TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
**function, flr, computation->get(), arg_types, retval_types));
*params_info = ParameterInfo(
num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
broadcast_input_types.size(), num_variables,
guaranteed_constant_types.size(), retval_types->size());
if (arg_types->size() != params_info->NumInputsToEachReplica()) {
return errors::InvalidArgument(
"Computation argument to TPUReplicate has wrong number of "
"arguments. Expected ",
params_info->NumInputsToEachReplica(), " inputs, got ",
arg_types->size());
}
if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
return errors::InvalidArgument(
"Wrong number of outputs from TPUReplicate. Expected ",
params_info->NumOutputsToHost(), " outputs, got ",
replicate_node.num_outputs());
}
if (enable_cross_replica_sharding_mirrored_variables_) {
std::vector<int> mirrored_variable_indices;
TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
&mirrored_variable_indices));
for (int index : mirrored_variable_indices) {
TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
params_info->IsDistributedArg(index))
<< "Mirrored variables not categorized as per-replica arguments, "
"index: "
<< index;
params_info->mutable_mirrored_variable_indices()->insert(index);
}
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
const string& tpu_compilation_device, const Node& replicate_node,
Graph* graph, Node** host_transfer_sequencer, Node** control_before,
Node** control_after) {
*host_transfer_sequencer = nullptr;
TF_RETURN_IF_ERROR(
BuildNoopNode(replicate_node,
graph->NewName(strings::StrCat(replicate_node.name(), "/",
"control_before")),
/*device=*/"", graph, control_before));
for (const Edge* e : replicate_node.in_edges()) {
if (!e->IsControlEdge()) {
continue;
}
Node* predecessor = e->src();
if (predecessor->IsSource()) continue;
if (predecessor->type_string() == "NoOp" &&
predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
// The node is the sequencer for host transfer operations. Its control
// dependency needs to be placed after the execute node, not before.
if (*host_transfer_sequencer != nullptr) {
return errors::Internal("Replicate node ", replicate_node.name(),
" has two transfer sequencer nodes: ",
(*host_transfer_sequencer)->name(), " and ",
predecessor->name());
}
// Set the correct device to match the other sequencing nodes.
predecessor->set_assigned_device_name(tpu_compilation_device);
*host_transfer_sequencer = predecessor;
} else {
graph->AddControlEdge(predecessor, *control_before);
}
}
TF_RETURN_IF_ERROR(
BuildNoopNode(replicate_node,
graph->NewName(strings::StrCat(replicate_node.name(), "/",
"control_after")),
/*device=*/tpu_compilation_device, graph, control_after));
for (Node* successor : replicate_node.out_nodes()) {
if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
graph->AddControlEdge(successor, *control_after);
} else {
graph->AddControlEdge(*control_after, successor);
}
}
return Status::OK();
}
/* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
const Node& replicate_node, const NameRangeMap& input_name_map,
Graph* graph, Node* host_transfer_sequencer, Node* control_before,
Node* control_after, absl::Span<const VariableInput> variable_nodes,
std::vector<Node*>* guaranteed_constant_nodes,
std::vector<Node*>* variable_reads) {
TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
replicate_node, input_name_map, guaranteed_constant_nodes));
TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
variable_reads));
// Add the control dependency from host transfer nodes.
if (host_transfer_sequencer != nullptr) {
graph->AddControlEdge(host_transfer_sequencer, control_after);
}
return Status::OK();
}
/* static */ Status
DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
Node* replicate_node, Node* compile_node,
absl::Span<const int> devices_to_lock, Node** control_after_compilation,
Node** multilock_acquire, Graph* graph) {
const Edge* compilation_edge = nullptr;
for (const auto* e : replicate_node->out_edges()) {
if (e->IsControlEdge() &&
e->dst()->type_string() == "TPUCompilationResult") {
TF_RET_CHECK(compilation_edge == nullptr)
<< "Multiple compilation result nodes attached to the same replicate "
"cluster.";
compilation_edge = e;
}
}
// TODO(jpienaar): This should be checked by default, current tests not using
// this are ones that use the "abort upon successful compile flag" which will
// be removed. Leaving this in until then.
if (compilation_edge != nullptr) {
Node* compilation_status = compilation_edge->dst();
const AttrValue* compile_status_cluster_attr =
compilation_status->attrs().Find(kTPUCompilationResultAttr);
TF_RET_CHECK(compile_status_cluster_attr != nullptr);
const string& compile_status_cluster = compile_status_cluster_attr->s();
TF_RET_CHECK(!compile_status_cluster.empty());
const AttrValue* replicate_cluster_attr =
replicate_node->attrs().Find(kTPUReplicateAttr);
TF_RET_CHECK(replicate_cluster_attr != nullptr);
const string& replicate_cluster = replicate_cluster_attr->s();
TF_RET_CHECK(!replicate_cluster.empty());
TF_RET_CHECK(compile_status_cluster == replicate_cluster);
TF_RETURN_IF_ERROR(
ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
graph->AddEdge(compile_node, 0, compilation_status, 0);
}
NodeDef def;
def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
// Create an op to assert that compilation succeeded. The alternative would
// have been to have each execute op check and return an error.
def.set_op("TPUCompileSucceededAssert");
MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
Status status;
Node* compile_succeeded = graph->AddNode(def, &status);
compile_succeeded->set_assigned_device_name(
compile_node->assigned_device_name());
TF_RETURN_IF_ERROR(status);
graph->AddEdge(compile_node, 0, compile_succeeded, 0);
Node* last_node_before_sequencer = compile_succeeded;
if (enable_multicore_locking_ && devices_to_lock.size() > 1) {
// Add a lock node to acquire exclusive access to all the cores that will
// execute this program. The lock is required to prevent deadlock or
// incorrect results when running concurrent multi-core programs in the
// same distributed runtime when there is no direct graph dependency
// between the programs (either because they are run from different sessions
// or because they are in the same graph, but have no control or data
// dependencies to sequence them). Consider the case of two multi-core
// computations A and B whose cores overlap and include cores X and Y. With
// no locking and no graph dependencies it is possible that A's program
// gets enqueued before B's on core X, while B's program gets enqueued
// before A's on core Y. This will lead either to deadlock or to
// incorrect results, since the runtime has no mechanism to re-sequence
// the programs on the cores. By adding a multi-lock acquisition for all the
// before any TPUExecute ops are run, and releasing it after they complete,
// we ensure that the programs are enqueued on the cores in a consistent
// order.
//
// There is a risk when computations are in the same graph, and include a
// data dependency, that the lock acquisition could provoke deadlock.
// Suppose that A must happen before B because B's input depends on A's
// output. Then it is obviously necessary that A's lock acquisition must
// happen before B's lock acquisition, and so we must ensure that there is
// a graph dependency causing B's lock acquisition to be sequenced after A's
// lock acquisition. Right now that dependency is satisfied because the
// shape inference code cannot determine the shape of A's outputs, and so
// B's compilation, which precedes B's lock acquisition, is always sequenced
// after A's execution. If the shape inference is improved it will be
// necessary to add an explicit control edge between dependent lock
// acquisition ops.
NodeDef lock_def;
lock_def.set_name(graph->NewName(
strings::StrCat(compile_node->name(), "/", "tpu_acquire_multilock")));
lock_def.set_op("TpuMultilock");
AddNodeAttr("lock_list", devices_to_lock, &lock_def);
MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &lock_def);
Status status;
*multilock_acquire = graph->AddNode(lock_def, &status);
TF_RETURN_IF_ERROR(status);
(*multilock_acquire)
->set_assigned_device_name(compile_node->assigned_device_name());
graph->AddControlEdge(compile_succeeded, *multilock_acquire);
last_node_before_sequencer = *multilock_acquire;
} else {
*multilock_acquire = nullptr;
}
// Build a sequencing node for when compilation has completed.
TF_RETURN_IF_ERROR(
BuildNoopNode(*replicate_node,
graph->NewName(strings::StrCat(compile_node->name(), "/",
"after_compilation")),
/*device=*/"", graph, control_after_compilation));
graph->AddControlEdge(last_node_before_sequencer, *control_after_compilation);
return Status::OK();
}
// Updates the head and tail outside compiled nodes so that nodes have the
// correct device and removes the replication and outside compilation attributes
// so that these nodes do not trigger further graph optimization passes.
/* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
const std::vector<std::vector<string>>& tf_device_assignment,
const std::vector<Node*>& head_tail_outside_compilation_nodes) {
for (Node* node : head_tail_outside_compilation_nodes) {
int replica_id;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
// Since we set the device, this will now run on a task other than 0. We
// clear the two following attributes so that we don't trigger encapsulation
// again on the remote host (which will fail due to a missing
// _TPUReplicateMetadata node for the cluster).
for (const Edge* e : node->in_edges()) {
// Resource consuming ops should colocate with its resource input.
if (e->src()->IsArg() &&
e->src()->output_type(e->src_output()) == DT_RESOURCE) {
node->set_requested_device(tf_device_assignment[replica_id][0]);
}
}
if (node->requested_device().empty()) {
string cpu_device;
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
tf_device_assignment[replica_id][0], &cpu_device));
node->set_requested_device(cpu_device);
}
node->ClearAttr(kTPUReplicateAttr);
node->ClearAttr(kOutsideCompilationAttr);
}
return Status::OK();
}
// Performs the rewrite on a single TPUReplicate node.
/* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
const string& session_handle, const DeviceSet& device_set,
Node* replicate_node, FunctionLibraryDefinition* flib_def,
FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
const OutsideCompilationNodeMap& outside_compilation_nodes,
const std::vector<Node*>& head_tail_outside_compilation_nodes,
NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
const GraphShapeInfo& shape_info,
TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
int64_t autotuner_thresh) {
VLOG(2) << "Rewriting node " << replicate_node->name();
// num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
// of the computation) and cores (virtual cores within computations) specified
// by the user. They will be mapped to physical TPU cores below.
int num_replicas;
int num_cores_per_replica;
int num_tasks;
std::vector<std::vector<string>> tf_device_assignment;
std::vector<int> devices_to_lock;
std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
string tpu_compilation_device;
TF_RETURN_IF_ERROR(GetDeviceTopology(
device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
&num_tasks, &tf_device_assignment, &devices_to_lock,
&xla_device_assignment, &tpu_compilation_device));
TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
tf_device_assignment, head_tail_outside_compilation_nodes));
string replicate;
TF_RETURN_IF_ERROR(
GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
NameRangeMap input_name_map;
const NameAttrList* function;
std::unique_ptr<Graph> computation;
DataTypeVector arg_types, retval_types;
ParameterInfo params_info;
TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
&input_name_map, &function, &computation,
&arg_types, &retval_types, &params_info));
std::vector<InferredShape> arg_shapes, retval_shapes;
TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
std::vector<xla::OpSharding> arg_sharding;
std::vector<bool> arg_fast_mem;
std::vector<std::string> arg_names;
std::vector<xla::OpSharding> retval_sharding;
TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
retval_shapes, *computation, replicate_node, flr,
allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
&arg_names));
VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
flib_def);
GraphDef graph_def;
graph->ToGraphDef(&graph_def);
FunctionLibraryDefinition reachable_functions =
flib_def->ReachableDefinitions(graph_def);
uint64 library_fingerprint;
TF_RETURN_IF_ERROR(
FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
VLOG(1) << "Fingerprint functions: "
<< absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
VLOG(1) << "library_fingerprint: " << library_fingerprint;
// Builds trigger nodes that put barriers around the expansion of
// TPUReplicate. In particular, we must guarantee:
// a) variable reads happen after all predecessors of the original
// TPUReplicate.
// b) variable writes happen before all successors of the original
// TPUReplicate.
// c) all replicas execute, even if output tensors are only requested from
// a subset of replicas. This is necessary both to ensure that variable
// updates happen, but also Send/Recv will deadlock if only one half of
// the communicating pair runs.
Node* host_transfer_sequencer;
Node* control_before;
Node* control_after;
TF_RETURN_IF_ERROR(BuildSequencingNodes(
tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
&control_before, &control_after));
// Build a vector of variable nodes that are inputs.
std::vector<VariableInput> variable_inputs;
TF_RETURN_IF_ERROR(
FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
std::vector<Node*> guaranteed_constant_nodes;
std::vector<Node*> variable_reads;
TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
*replicate_node, input_name_map, graph, host_transfer_sequencer,
control_before, control_after, variable_inputs,
&guaranteed_constant_nodes, &variable_reads));
// Builds Shape nodes that compute the dynamic shapes of arguments whose
// shapes are not statically known.
std::vector<Node*> dynamic_shape_nodes;
TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
params_info, variable_reads, graph,
&dynamic_shape_nodes));
// Builds a TPUCompile node that compiles `clusters` on `compile_device`.
Node* compile_node;
TF_RETURN_IF_ERROR(BuildCompileNode(
replicate_node, *function, library_fingerprint, params_info, arg_shapes,
arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
/*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
dynamic_shape_nodes, graph, &compile_node, autotuner_thresh, num_tasks));
// Compilation must be sequenced after the control node if the TPU computation
// in a control-flow construct, such as a loop.
graph->AddControlEdge(control_before, compile_node);
Node* control_after_compilation;
Node* multilock_acquire;
TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
replicate_node, compile_node, devices_to_lock, &control_after_compilation,
&multilock_acquire, graph));
std::vector<VariableWrite> variable_writes;
TF_RETURN_IF_ERROR(BuildExecuteNodes(
params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
tf_device_assignment, compile_node, variable_reads,
control_after_compilation, control_after, multilock_acquire,
&variable_writes, graph));
bool contains_resource_write_op =
ContainsResourceWriteOp(*graph, reachable_functions);
VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
// Skip conditional write if there is no resource writing op inside TPU
// computation.
if (contains_resource_write_op) {
TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
variable_writes, graph));
}
if (host_compute_key_placeholder_node != nullptr) {
TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
compile_node, host_compute_key_placeholder_node, graph));
}
HostComputeCoreMap host_compute_core;
TF_RETURN_IF_ERROR(ParseHostComputeCores(
*replicate_node, outside_compilation_nodes, &host_compute_core));
TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
tf_device_assignment, host_compute_core, outside_compilation_nodes,
outside_compilation_node_images, graph));
graph->RemoveNode(replicate_node);
return Status::OK();
}
// Adds sharded weight update optimization for each host training loop.
//
// For any host training loop found in the graph, TPUVariableReshard ops
// are inserted to match the best layout chosen by the XLA.
/* static */ Status
DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
Graph* graph, FunctionLibraryDefinition* flib_def,
FunctionLibraryRuntime* flr) {
std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
Status s = tpu::DetectHostTrainingLoop(
/*current_function_name=*/nullptr,
/*current_function_attr=*/nullptr, flib_def, graph, flr,
&host_training_loops_info);
if (!s.ok()) {
VLOG(2) << "No valid host training loop found. Skipping sharded weight "
<< "update optimization.";
return Status::OK();
}
for (const auto& host_loop : host_training_loops_info) {
const auto& function_name = host_loop.encapsulating_function_name;
// `function_name` has value when host training loop is inside a
// function call node. When host training loop is found inside a function
// call node, then, in addition to adding TPUVariableReshard ops, function
// library definition needs to be updated as well.
if (function_name.has_value()) {
const auto& function_attr = host_loop.encapsulating_function_attrs;
TF_RET_CHECK(function_attr.has_value())
<< "Unable to find function attribute for function: "
<< *function_name;
const FunctionDef* function_def = flib_def->Find(*function_name);
TF_RET_CHECK(function_def)
<< "Unable to find function : " << *function_name;
std::unique_ptr<FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
*function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
Graph* function_graph = fbody->graph;
TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
*function_name, flib_def));
} else {
TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
}
}
return Status::OK();
}
Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
Graph* graph) {
ReverseDFS(*graph, {}, PlaceOpsOnTPU);
return Status::OK();
}
Status DistributedTPURewritePass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "DistributedTPURewritePass::Run";
Graph* graph = options.graph->get();
VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
options.flib_def);
const auto* config = &options.session_options->config;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
nullptr, options.session_options->env, config,
graph->versions().producer(), options.flib_def,
config ? config->graph_options().optimizer_options()
: OptimizerOptions()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
// This pass can only run in the session master, which should fill
// in the device_set field to the options.
TF_RET_CHECK(options.device_set != nullptr);
// Find all the replicate nodes before mutating the graph.
std::vector<Node*> replicate_nodes;
// Map from compiled subgraph cluster name to the outside_compilation nodes in
// that cluster.
std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
&outside_compilation_nodes,
&head_tail_outside_compilation_nodes));
if (replicate_nodes.empty()) {
// Remove unused TPUPartitionedInput nodes.
for (Node* n : graph->nodes()) {
if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
}
VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
options.flib_def);
VLOG(1) << "Replicate nodes are empty. DistributedTPURewritePass::Run() "
"finished";
return Status::OK();
}
std::unordered_map<string, Node*> host_compute_key_placeholder_map;
TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
graph, replicate_nodes, &host_compute_key_placeholder_map));
// This shape inference pass does not compute the shapes of outputs of
// TPU computations. The concurrent multi-core locking implementation
// *relies* on this behavior because it ensures that, if TPU computation B's
// inputs depend on TPU computation A's outputs, then computation B's
// compilation will be sequenced after A's execution, and this ensures that
// locks are acquired in the correct order. If the shape inference is improved
// to compute shapes of TPU computation outputs, it will be necessary to add
// an explicit control edge between lock acquisitions for dependent
// computations in order to avoid deadlock.
GraphShapeInfo shape_info;
TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
flr->GetFunctionLibraryDefinition(),
&shape_info));
int64_t autotuner_thresh = options.session_options->config.experimental()
.xla_fusion_autotuner_thresh();
NodeToNodeReplicasMap outside_compilation_node_images;
TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
for (Node* node : replicate_nodes) {
TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
options.session_handle, *options.device_set, node, options.flib_def,
flr, host_compute_key_placeholder_map[node->name()],
outside_compilation_nodes[node->name()],
head_tail_outside_compilation_nodes[node->name()],
&outside_compilation_node_images, graph, shape_info,
&tpu_replicate_device_names_mapping, autotuner_thresh));
}
// Place the padding nodes generated by dynamic padder on the correct devices.
// TODO(rxsang): Place padding ops on TPUs in
// PlaceUnassignedDeviceNodesOnTPUIfPossible function.
TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
std::unordered_map<string, Node*> outside_compilation_inputs;
for (Node* n : graph->op_nodes()) {
string lifted_arg_inputs_attr;
if (n->type_string() == "IdentityN" &&
GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
&lifted_arg_inputs_attr)
.ok()) {
outside_compilation_inputs[lifted_arg_inputs_attr] = n;
}
}
for (const auto& iter : outside_compilation_nodes) {
TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
iter.second, outside_compilation_node_images,
outside_compilation_inputs, graph));
}
TF_RETURN_IF_ERROR(
RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
graph, *options.flib_def, tpu_replicate_device_names_mapping));
TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
options.flib_def);
VLOG(1) << "DistributedTPURewritePass::Run() finished";
if (enable_cross_replica_sharding_mirrored_variables_) {
VLOG(1) << "Starting host training loop optimization.";
VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
options.flib_def);
TF_RETURN_IF_ERROR(
PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
options.flib_def);
VLOG(1) << "Host training loop optimization finished.";
}
return Status::OK();
}
bool DistributedTPURewritePass::distribute_vars_ = false;
bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
bool DistributedTPURewritePass::
replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
bool DistributedTPURewritePass::
enable_cross_replica_sharding_mirrored_variables_ = true;
bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
bool DistributedTPURewritePass::enable_multicore_locking_ = false;
bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
bool distribute_vars, bool allow_xla_spmd_partition,
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
bool enable_multicore_locking, bool use_nd_sharding_ops) {
distribute_vars_ = distribute_vars;
allow_xla_spmd_partition_ = allow_xla_spmd_partition;
replicate_inputs_outputs_by_default_for_xla_spmd_ =
replicate_inputs_outputs_by_default_for_xla_spmd;
enable_cross_replica_sharding_mirrored_variables_ =
enable_cross_replica_sharding_mirrored_variables;
enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
enable_xla_param_broadcast_ = enable_xla_param_broadcast;
enable_multicore_locking_ = enable_multicore_locking;
use_nd_sharding_ops_ = use_nd_sharding_ops;
}
} // namespace tensorflow