| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/tpu/kernels/tpu_functional_ops.h" |
| |
| #include <memory> |
| |
| #include "tensorflow/core/framework/op_kernel.h" |
| |
| #define EIGEN_USE_THREADS |
| |
| #include "absl/base/call_once.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/synchronization/mutex.h" |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/compiler/tf2xla/sharding_util.h" |
| #include "tensorflow/compiler/tf2xla/side_effect_util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/common_runtime/function_body.h" |
| #include "tensorflow/core/common_runtime/graph_constructor.h" |
| #include "tensorflow/core/common_runtime/placer.h" |
| #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
| #include "tensorflow/core/framework/graph_to_functiondef.h" |
| #include "tensorflow/core/framework/metrics.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/resource_mgr.h" |
| #include "tensorflow/core/framework/resource_var.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/graph/graph_partition.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/platform/blocking_counter.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/fingerprint.h" |
| #include "tensorflow/core/platform/refcount.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/protobuf/tpu/topology.pb.h" |
| #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" |
| #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h" |
| #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" |
| #include "tensorflow/core/tpu/kernels/tpu_op_util.h" |
| #include "tensorflow/core/tpu/kernels/tpu_util.h" |
| #include "tensorflow/core/tpu/tpu_configuration.h" |
| #include "tensorflow/core/tpu/tpu_defs.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| constexpr char kTpuReplicateAttr[] = "_tpu_replicate"; |
| constexpr int kLastDimOfTpuInputFastPath = 128; |
| constexpr int kOtherDimOfTpuInputFastPath = 8; |
| |
| constexpr char kXLAShardingAttrName[] = "sharding"; |
| constexpr char kXLAShardingAttrAltName[] = "_XlaSharding"; |
| |
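| // Appends the (x, y, z, core) coordinates of every core in the topology to |
| // `natural_order`, iterating over y, x, z, and core in that nesting order. |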
| Status GenerateDeviceNaturalOrder(int x_num_cores, int y_num_cores, |
| int z_num_cores, int num_cores_per_chip, |
| std::vector<int>* natural_order) { |
| for (int y = 0; y < y_num_cores; ++y) { |
| for (int x = 0; x < x_num_cores; ++x) { |
| for (int z = 0; z < z_num_cores; ++z) { |
| for (int c = 0; c < num_cores_per_chip; ++c) { |
| natural_order->push_back(x); |
| natural_order->push_back(y); |
| natural_order->push_back(z); |
| natural_order->push_back(c); |
| } |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| struct TPUVariableInfo { |
| TPUVariableInfo(int device_ordinal_id, bool use_fast_mem) |
| : device_ordinal(device_ordinal_id), fast_mem(use_fast_mem) {} |
| // The TPU core which the variable will be placed on. |
| int device_ordinal; |
| // If true, try to place the variable in fast memory space if the hardware |
| // supports it. For example, PuffyLite has CMEM space. |
| bool fast_mem; |
| }; |
| |
| // Checks the descendants to parse the placement information for the input |
| // node. num_cores_per_replica describes how many cores a single replica of |
| // the model uses. |
| Status ParseTPUVariableInfor(const Node* node, const int num_cores_per_replica, |
| TPUVariableInfo* var_info) { |
| int core = 0; |
| bool use_fast_mem = false; |
| VLOG(3) << "Parse tpu variable information for " << node->name(); |
| for (const Edge* edge : node->out_edges()) { |
| if (edge->IsControlEdge()) continue; |
| Node* next = edge->dst(); |
| VLOG(3) << "Neighbor node " << next->name(); |
| // Looking through Enter/Switch/ReadVariableOp nodes. |
| while (next->IsEnter() || next->IsSwitch() || |
| next->type_string() == "ReadVariableOp") { |
| Node* new_node = nullptr; |
| for (const Edge* e : next->out_edges()) { |
| if (!e->IsControlEdge()) { |
| new_node = e->dst(); |
| break; |
| } |
| } |
| if (new_node == nullptr) break; |
| next = new_node; |
| } |
| if (next != edge->dst()) { |
| VLOG(3) << "Looked through Enter/Switch node " << next->DebugString(); |
| } |
| TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, |
| ParseShardingFromDevice(*next, num_cores_per_replica, |
| /*add_metadata=*/false)); |
| if (sharding.has_value() && sharding->tile_assignment_devices_size() > 0) { |
| core = sharding->tile_assignment_devices(0); |
| VLOG(3) << next->name() << " is placed on core " << core; |
| } |
| if (next->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) { |
| use_fast_mem = true; |
| VLOG(3) << next->name() << " has " << TPU_FAST_MEM_ATTR << " attribute"; |
| } |
| } |
| VLOG(1) << "Place " << node->name() << " to core: " << core |
| << " fast_mem: " << use_fast_mem; |
| var_info->device_ordinal = core; |
| var_info->fast_mem = use_fast_mem; |
| |
| return Status::OK(); |
| } |
| |
| // Helper to instantiate function "func" in the library "lib". |
| Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, |
| FunctionLibraryRuntime::Handle* handle) { |
| return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle); |
| } |
| |
| static constexpr const char* const kDeviceOrdinalAttr = "device_ordinal"; |
| |
| static constexpr const char* const kTPUExecuteOp = "TPUExecute"; |
| static constexpr const char* const kInfeedEnqueueOp = "InfeedEnqueue"; |
| static constexpr const char* const kInfeedEnqueueTupleOp = "InfeedEnqueueTuple"; |
| static constexpr const char* const kOutfeedDequeueOp = "OutfeedDequeue"; |
| static constexpr const char* const kOutfeedDequeueTupleOp = |
| "OutfeedDequeueTuple"; |
| static constexpr const char* const kOutfeedDequeueV2Op = "OutfeedDequeueV2"; |
| static constexpr const char* const kOutfeedDequeueTupleV2Op = |
| "OutfeedDequeueTupleV2"; |
| static constexpr const char* const kVarHandleOp = "VarHandleOp"; |
| |
| static constexpr const char* const kTPUDeviceNamePrefix = "/device:TPU:"; |
| static constexpr const int kTPUDefaultDeviceOrdinal = 0; |
| |
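| // Returns true if `op_name` is a TPUExecute, infeed enqueue, or outfeed |
| // dequeue op. |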
| bool IsSupportedTPUOp(const string& op_name) { |
| return op_name == kTPUExecuteOp || op_name == kInfeedEnqueueOp || |
| op_name == kInfeedEnqueueTupleOp || op_name == kOutfeedDequeueOp || |
| op_name == kOutfeedDequeueTupleOp || op_name == kOutfeedDequeueV2Op || |
| op_name == kOutfeedDequeueTupleV2Op; |
| } |
| |
| // Sets the sharding attributes for an XlaSharding node. |
| void SetXlaShardingNodeAttr(Node* xla_sharding_node, int num_cores_per_replica, |
| int rank, int shard_dim) { |
| auto sharding = absl::make_optional<xla::OpSharding>(); |
| sharding->set_type(xla::OpSharding::OTHER); |
| |
| std::vector<int64> dims(rank, 1LL); |
| dims[shard_dim] = num_cores_per_replica; |
| for (auto dim : dims) { |
| sharding->add_tile_assignment_dimensions(dim); |
| } |
| |
| // Sets up tile_assignment_devices. |
| for (int d = 0; d < num_cores_per_replica; ++d) { |
| sharding->add_tile_assignment_devices(d); |
| } |
| |
| xla_sharding_node->ClearAttr(kXLAShardingAttrName); |
| xla_sharding_node->ClearAttr(kXLAShardingAttrAltName); |
| xla_sharding_node->AddAttr(kXLAShardingAttrName, |
| sharding->SerializeAsString()); |
| xla_sharding_node->AddAttr(kXLAShardingAttrAltName, |
| sharding->SerializeAsString()); |
| } |
| |
| // If 'device_name' is a TPU device, set its device_ordinal to 'device_ordinal' |
| // and set '*rewritten' to true. Otherwise, do nothing. |
| Status UpdateTPUDeviceOrdinal(int device_ordinal, string* device_name, |
| bool* rewritten) { |
| DeviceNameUtils::ParsedName device; |
| if (!DeviceNameUtils::ParseFullName(*device_name, &device)) { |
| return errors::InvalidArgument("Unable to parse device name ", |
| *device_name); |
| } |
| if (device.type == DEVICE_TPU_NODE) { |
| device.id = device_ordinal; |
| *rewritten = true; |
| } |
| *device_name = DeviceNameUtils::ParsedNameToString(device); |
| return Status::OK(); |
| } |
| |
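| // Starting from `arg_node`, walks forward through its non-control out edges |
| // to find an edge whose source is outside the TPU cluster (no _tpu_replicate |
| // attribute) and whose destination is inside it. The search gives up early if |
| // the source is a GuaranteeConst node. Returns nullptr if no such edge is |
| // found. |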
| const Edge* FindHostToDeviceEdge(Node* arg_node) { |
| const Edge* candidate_edge = nullptr; |
| for (const Edge* edge : arg_node->out_edges()) |
| if (!edge->IsControlEdge()) { |
| // Find CPU -> TPU input edge. |
| const Edge* original_edge; |
| while (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr || |
| edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) { |
| const Node* new_src = edge->dst(); |
| original_edge = edge; |
| for (const Edge* new_edge : new_src->out_edges()) |
| if (!new_edge->IsControlEdge()) { |
| original_edge = edge; |
| edge = new_edge; |
| break; |
| } |
| if (original_edge == edge) break; |
| } |
| // TPU input edge: src is on CPU and dest is on TPU. |
| if (edge->src()->attrs().Find(kTpuReplicateAttr) != nullptr || |
| edge->dst()->attrs().Find(kTpuReplicateAttr) == nullptr) |
| continue; |
| // Won't work with GuaranteeConst. |
| if (edge->src()->type_string() == "GuaranteeConst") break; |
| candidate_edge = edge; |
| } |
| return candidate_edge; |
| } |
| |
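| // Inserts an Identity node as a proxy between `candidate_edge`'s source and |
| // all of its consumers inside the TPU cluster, so that the input tensor |
| // crosses the host/TPU boundary on a single edge. On success, |
| // `*tpu_input_edge` points to the new edge from the original source to the |
| // Identity proxy. |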
| Status CreateInputProxy(Graph* graph, const Edge* candidate_edge, |
| const Edge** tpu_input_edge) { |
| std::vector<const Edge*> edges_to_replace; |
| for (const Edge* input_edge : candidate_edge->src()->out_edges()) { |
| if (!input_edge->IsControlEdge() && |
| input_edge->dst()->attrs().Find(kTpuReplicateAttr) != nullptr) |
| edges_to_replace.push_back(input_edge); |
| } |
| // Build an Identity node as the proxy of the original edge source. |
| Node* input_identity_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(candidate_edge->src()->name(), "/proxy"), |
| "Identity") |
| .Input(candidate_edge->src()) |
| .Attr("T", candidate_edge->src()->output_type(0)) |
| .Attr(kTpuReplicateAttr, |
| candidate_edge->dst()->attrs().Find(kTpuReplicateAttr)->s()) |
| .Finalize(graph, &input_identity_node)); |
| // Find the tpu input edge from original source to proxy identity. |
| for (const Edge* input_edge : input_identity_node->in_edges()) |
| if (input_edge->src() == candidate_edge->src()) { |
| *tpu_input_edge = input_edge; |
| break; |
| } |
| // Replace original input edges with proxy's output. |
| for (const Edge* input_edge : edges_to_replace) { |
| graph->RemoveEdge(input_edge); |
| graph->AddEdge(input_identity_node, 0, input_edge->dst(), |
| input_edge->dst_input()); |
| } |
| return Status::OK(); |
| } |
| |
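| // Returns the _tpu_replicate cluster name shared by all TPU nodes in `graph`. |
| // Fails if nodes from more than one cluster are present. |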
| Status GetClusterName(Graph* graph, string* cluster_name) { |
| *cluster_name = ""; |
| for (const Node* node : graph->nodes()) { |
| if (node->attrs().Find(kTpuReplicateAttr) == nullptr) continue; |
| if (cluster_name->empty()) |
| *cluster_name = node->attrs().Find(kTpuReplicateAttr)->s(); |
| // When optimization is turned on, the graph should only have one TPU |
| // cluster. |
| if (*cluster_name != node->attrs().Find(kTpuReplicateAttr)->s()) |
| return errors::FailedPrecondition( |
| "Only one cluster is allowed when optimization is turned on for " |
| "TPUPartitionedCall. Found ", |
| node->attrs().Find(kTpuReplicateAttr)->s(), " and ", *cluster_name); |
| } |
| return Status::OK(); |
| } |
| |
| // Removes nodes that have no effect and directly descend from _Arg nodes. |
| // |
| // This is currently used for removing TPUReplicatedInput and XlaSharding |
| // nodes, which are always descendants of _Arg nodes. During optimization, we |
| // try to insert nodes in between _Arg nodes and their children, where some of |
| // the inserted nodes are TPU nodes. We will add the TPUReplicatedInput and |
| // XlaSharding op nodes back where necessary. |
| // |
| // Returns the number of nodes that were removed. |
| int64 RemoveDescendantNodeOfArg(Graph* graph, |
| const std::string& node_type_to_remove, |
| const std::set<std::string>& must_be_child_of) { |
| int64 nodes_removed = 0; |
| std::vector<std::pair<const Edge*, std::vector<const Edge*>>> edges_to_remove; |
| |
| for (Node* node : graph->nodes()) { |
| if (node_type_to_remove != node->type_string()) continue; |
| if (!must_be_child_of.empty()) { |
| bool has_arg_parent = false; |
| for (const Edge* edge : node->in_edges()) { |
| if (must_be_child_of.count(edge->src()->type_string()) > 0) { |
| has_arg_parent = true; |
| } |
| } |
| if (!has_arg_parent) continue; |
| } |
| nodes_removed++; |
| |
| const Edge* input_edge = nullptr; |
| std::vector<const Edge*> output_edges; |
| for (const Edge* edge : node->in_edges()) |
| if (!edge->IsControlEdge()) { |
| input_edge = edge; |
| break; |
| } |
| for (const Edge* edge : node->out_edges()) |
| if (!edge->IsControlEdge()) { |
| output_edges.push_back(edge); |
| } |
| if (input_edge != nullptr && !output_edges.empty()) |
| edges_to_remove.push_back(std::make_pair(input_edge, output_edges)); |
| } |
| for (const auto& it : edges_to_remove) { |
| for (const Edge* output_edge : it.second) { |
| graph->RemoveEdge(output_edge); |
| graph->AddEdge(it.first->src(), it.first->src_output(), |
| output_edge->dst(), output_edge->dst_input()); |
| } |
| graph->RemoveNode(it.first->dst()); |
| } |
| return nodes_removed; |
| } |
| |
| uint64 GetInputHash(OpKernelContext* ctx) { |
| uint64 input_hash = 0; // initialization for determinism. |
| // Use the number of elements to compute hash. |
| // TODO(chiachenc): use the full shape to compute the hash. |
| for (int i = 0; i < ctx->num_inputs(); ++i) { |
| VLOG(4) << "InputHash, combine input " << i |
| << ", NumElements: " << ctx->input(i).NumElements(); |
| input_hash = Hash64Combine(input_hash, ctx->input(i).NumElements()); |
| } |
| return input_hash; |
| } |
| |
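| // Builds a grouping key from the dtype and all dimensions except the last |
| // one, optionally tagging whether the last dimension is a multiple of |
| // kLastDimOfTpuInputFastPath. For example, a DT_FLOAT tensor with dims |
| // [8, 16, 256] and input_shape_opt produces a key roughly of the form |
| // "input_tensors_dtype_1_dims_8_16_last_128n" (the dtype is emitted as its |
| // enum value). |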
| string HashShapeAndType(const string& prefix, const std::vector<int>& input_dims, |
| const DataType& dtype, const bool input_shape_opt) { |
| string hash = strings::StrCat(prefix, dtype, "_dims"); |
| // We will concat at the last dimension. |
| for (int d = 0; d < input_dims.size() - 1; ++d) { |
| strings::StrAppend(&hash, "_", input_dims.at(d)); |
| } |
| |
| if (input_shape_opt) { |
| if (input_dims.back() % kLastDimOfTpuInputFastPath == 0) { |
| strings::StrAppend(&hash, "_last_", kLastDimOfTpuInputFastPath, "n"); |
| } else { |
| strings::StrAppend(&hash, "_last_other"); |
| } |
| } |
| return hash; |
| } |
| |
| // Get the information for input and output tensors (shapes, dtypes, etc). |
| Status GetInputOutputInfo( |
| Graph* graph, GraphShapeInfo& tpu_inferred_info, |
| std::map<int, InferredShape>& arg_shapes, EdgeShapes& tpu_input_shapes, |
| absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes, |
| OpKernelContext* ctx) { |
| // Search for the device-to-host or tpu-to-cpu edges. |
| for (Node* node : graph->op_nodes()) { |
| if (!node->IsArg()) continue; |
| const DataType dtype = node->attrs().Find("T")->type(); |
| const int arg_index = node->attrs().Find("index")->i(); |
| if (dtype != DT_INT32 && dtype != DT_BFLOAT16 && dtype != DT_FLOAT && |
| dtype != DT_BOOL && dtype != DT_QINT8 && dtype != DT_QUINT8) |
| continue; |
| VLOG(3) << "Argnode: " << node->DebugString(); |
| const Tensor& tensor = ctx->input(arg_index); |
| |
| // Search for the cross-device edge from arg node. |
| const Edge* candidate_edge = FindHostToDeviceEdge(node); |
| if (candidate_edge == nullptr) continue; |
| |
| // Make a proxy and get the sole tpu_input_edge for transferring the input |
| // tensor corresponding to the current _Arg node. |
| const Edge* tpu_input_edge = nullptr; |
| TF_RETURN_IF_ERROR( |
| CreateInputProxy(graph, candidate_edge, &tpu_input_edge)); |
| if (tpu_input_edge == nullptr) |
| return errors::NotFound("Couldn't find TPU input edge for", node->name()); |
| |
| // Optimize edge: original source to proxy identity. |
| VLOG(3) << "Input: " << tpu_input_edge->src()->name(); |
| std::vector<int>& input_shapes = tpu_input_shapes[tpu_input_edge]; |
| input_shapes.clear(); |
| for (int d = 0; d < tensor.dims(); ++d) { |
| input_shapes.push_back(tensor.dim_size(d)); |
| VLOG(3) << "Input Tensor: Dim[" << d << "] = " << tensor.dim_size(d); |
| } |
| tpu_input_dtypes[tpu_input_edge] = tensor.dtype(); |
| |
| // Collect shapes for non-resource-variable args. |
| PartialTensorShape partial_tensor_shape; |
| auto partial_shape = PartialTensorShape::MakePartialShape( |
| input_shapes.data(), input_shapes.size(), &partial_tensor_shape); |
| InferredShape inferred_shape = {partial_tensor_shape}; |
| arg_shapes[arg_index] = inferred_shape; |
| } |
| return Status::OK(); |
| } |
| |
| // Converts integer vectors that represent shapes into TensorShapes. |
| Status ConvertEdgeShapesToTensorShapes( |
| const std::map<std::string, std::vector<int>>& named_input_shapes, |
| std::vector<TensorShape>* shapes) { |
| shapes->resize(named_input_shapes.size()); |
| int32_t i = 0; |
| // keys in tpu_input_shapes may be stale. |
| for (const auto& iter : named_input_shapes) { |
| VLOG(2) << iter.first << ", rank: " << iter.second.size(); |
| const int64 rank = iter.second.size(); |
| std::vector<int64> dims(rank); |
| for (int64 d = 0; d < rank; ++d) { |
| VLOG(2) << " dim[" << d << "]: " << iter.second.at(d); |
| dims[d] = iter.second.at(d); |
| } |
| TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, &(*shapes)[i])); |
| i++; |
| } |
| return Status::OK(); |
| } |
| |
| // Computes the TF fingerprint from the TPUCompileOp or _TPUCompileMlirOp |
| // metadata and registers it with the TpuFingerprintLookup resource. |
| Status MaybeRegisterFingerprint( |
| Graph* graph, |
| const std::map<std::string, std::vector<int>>& named_input_shapes, |
| uint64 input_hash) { |
| // Find the compiler metadata. |
| tpu::TPUCompileMetadataProto metadata_proto; |
| std::map<std::string, std::vector<int>> inputs_to_keep; |
| int num_dynamic_shapes = -1; |
| tensorflow::uint64 fingerprint = 0; |
| |
| for (Node* node : graph->op_nodes()) { |
| if (node->type_string() == "TPUCompile" || |
| node->type_string() == "_TPUCompileMlir") { |
| num_dynamic_shapes = node->attrs().Find("NumDynamicShapes")->i(); |
| if (num_dynamic_shapes <= 0) { |
| break; |
| } |
| int visited = 0; |
| // TPUCompileOp/_TPUCompileMlirOp take Shape nodes as inputs. |
| // The number of Shape nodes matches the number of dynamic shaped inputs. |
| // The Shape nodes come from the input nodes: |
| // [TPU Input] --> [Input Shape] --> [TPUCompileOp] |
| for (auto in_node : node->in_nodes()) { |
| if (in_node->type_string() != "Shape") { |
| continue; |
| } |
| for (auto input_node : in_node->in_nodes()) { |
| auto iter = named_input_shapes.find(input_node->name()); |
| if (iter != named_input_shapes.end()) { |
| inputs_to_keep[iter->first] = iter->second; |
| } |
| } |
| visited++; |
| if (visited == num_dynamic_shapes) { |
| break; |
| } |
| } |
| std::string metadata = node->attrs().Find("metadata")->s(); |
| metadata_proto.ParseFromString(metadata); |
| |
| if (node->type_string() == "_TPUCompileMlir") { |
| std::string mlir_module = node->attrs().Find("mlir_module")->s(); |
| fingerprint = tensorflow::Fingerprint64(mlir_module); |
| } else { |
| fingerprint = metadata_proto.function_library_fingerprint(); |
| } |
| |
| break; |
| } |
| } |
| VLOG(2) << "inputs_to_keep size: " << inputs_to_keep.size(); |
| if (inputs_to_keep.size() != num_dynamic_shapes) { |
| VLOG(2) << "Cannot match all inputs shapes. Skip fingerprint registration."; |
| return Status::OK(); |
| } |
| |
| std::vector<TensorShape> input_shapes; |
| TF_RETURN_IF_ERROR( |
| ConvertEdgeShapesToTensorShapes(inputs_to_keep, &input_shapes)); |
| |
| std::vector<TensorShape> arg_shapes; |
| auto status = |
| tpu::ComputeArgumentShapes(metadata_proto, input_shapes, &arg_shapes); |
| if (!status.ok()) { |
| VLOG(2) << status.error_message(); |
| return Status::OK(); |
| } |
| uint64 tf_fingerprint = |
| tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes); |
| VLOG(2) << "fingerprint: " << fingerprint; |
| VLOG(2) << "TF fingerprint: " << tf_fingerprint; |
| |
| ResourceMgr* rm = GetTPUConfigResourceMgr(); |
| tpu::TpuFingerprintLookup* fingerprint_lookup; |
| TF_RETURN_IF_ERROR(rm->Lookup<tpu::TpuFingerprintLookup>( |
| rm->default_container(), tpu::kFingerprintLookupResourceName, |
| &fingerprint_lookup)); |
| fingerprint_lookup->RegisterKeyAndIntermediatePair(input_hash, |
| tf_fingerprint); |
| return Status::OK(); |
| } |
| |
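| // Scans `graph` for _Arg -> TPUReplicatedInput -> XlaSharding chains. For |
| // each chain found, records the XlaSharding node's attributes and the |
| // TPUReplicatedInput node's dtype and metadata control input, keyed by the |
| // _Arg node's name. Returns true if any such sharded input was found. |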
| bool FindTpuReplicatedInputAndXlaSharding( |
| const Graph* graph, XlaShardingInfoMap& xla_sharding_ops, |
| TpuReplicatedInputInfoMap& tpu_replicated_input_ops) { |
| bool xla_spmd_input_sharded = false; |
| // Detect whether there is XLA sharding on the inputs; if there is, then |
| // we cannot remove the replicated inputs or the XlaSharding ops. |
| for (Node* xla_sharding_node : graph->nodes()) { |
| if (xla_sharding_node->type_string() == "XlaSharding") { |
| for (const Edge* edge : xla_sharding_node->in_edges()) { |
| if (edge->src()->type_string() == "TPUReplicatedInput") { |
| Node* tpu_replicated_input_node = edge->src(); |
| Node* tpu_replicated_metadata_node = nullptr; |
| for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) { |
| if (input_edge->IsControlEdge()) { |
| tpu_replicated_metadata_node = input_edge->src(); |
| } |
| } |
| |
| for (const Edge* input_edge : tpu_replicated_input_node->in_edges()) { |
| if (input_edge->src()->type_string() == "_Arg") { |
| Node* arg_node = input_edge->src(); |
| |
| xla_sharding_ops[arg_node->name()] = std::make_tuple( |
| xla_sharding_node->attrs().Find("T")->type(), |
| xla_sharding_node->attrs().Find("sharding")->s(), |
| xla_sharding_node->attrs().Find("_tpu_replicate")->s()); |
| |
| tpu_replicated_input_ops[arg_node->name()] = std::make_tuple( |
| tpu_replicated_input_node->attrs().Find("T")->type(), |
| tpu_replicated_metadata_node); |
| |
| VLOG(2) << "Detected input is sharded. XlaSharding node: " |
| << xla_sharding_node->DebugString() |
| << ", TPUReplicatedInput node: " |
| << edge->src()->DebugString() |
| << ", _Arg node: " << arg_node->DebugString(); |
| xla_spmd_input_sharded = true; |
| break; |
| } |
| } |
| } |
| } |
| } |
| } |
| return xla_spmd_input_sharded; |
| } |
| |
| } // end namespace |
| |
| namespace tpu_functional_internal { |
| |
| // An optimization pass that separates tensors to leverage the fast path in |
| // TPU input preparation. The algorithm is as follows: |
| // (1) Group all tensors that have the same dimensions except the last |
| // dimension. A group of tensors will be concatenated along the last dimension |
| // in a later pass. |
| // (2) Check all groups of tensors and find groups whose dimensions after |
| // concat cannot leverage the fast path. |
| // (3) For groups of tensors that don't leverage the fast path, split tensors |
| // into two sub-groups such that one sub-group of tensors can leverage the fast |
| // path. |
| // An exception to (2) is when the concatenated tensors are small; separating |
| // such tensors would only introduce data-transfer overhead to the device. |
| // This optimization takes effect when both --input_shape_opt and |
| // --group_tensors_for_packing are true. |
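| // For example, with kLastDimOfTpuInputFastPath = 128 and |
| // kOtherDimOfTpuInputFastPath = 8, two DT_FLOAT tensors of shapes [8, 100] |
| // and [8, 28] fall into the same group (same leading dimensions and dtype); |
| // their concatenated shape [8, 128] already satisfies the fast path, so the |
| // group is kept as is. If a group does not already satisfy the fast path (and |
| // is not small), it is re-split by whether each tensor's own last dimension |
| // is a multiple of kLastDimOfTpuInputFastPath. |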
| GroupedEdges GroupTensorsForInputPacking( |
| const EdgeShapes& tpu_input_shapes, |
| const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes, |
| bool input_shape_opt, bool group_tensors_for_packing) { |
| GroupedEdges grouped_input_edges; |
| for (const auto& iter : tpu_input_shapes) { |
| if (iter.second.empty()) continue; |
| DataType dtype = tpu_input_dtypes.find(iter.first)->second; |
| string hash_key = HashShapeAndType("input_tensors_dtype_", iter.second, |
| dtype, /*input_shape_opt*/ false); |
| grouped_input_edges[hash_key].push_back(iter.first); |
| } |
| // Apply grouping when both are true. |
| if (!input_shape_opt || !group_tensors_for_packing) |
| return grouped_input_edges; |
| |
| GroupedEdges grouped_input_edges_opt; |
| for (const auto& iter : grouped_input_edges) { |
| int sum_last_dim = 0; |
| int product_other_dims = 0; |
| VLOG(3) << "group name: " << iter.first; |
| for (const auto& edge : iter.second) { |
| const std::vector<int>& input_shapes = |
| tpu_input_shapes.find(edge)->second; |
| sum_last_dim += input_shapes.back(); |
| if (product_other_dims == 0) { |
| product_other_dims = 1; |
| for (int d = 0; d < input_shapes.size() - 1; ++d) |
| product_other_dims *= input_shapes.at(d); |
| } |
| } |
| VLOG(3) << "sum_last_dim: " << sum_last_dim; |
| VLOG(3) << "product_other_dims: " << product_other_dims; |
| // Already uses fast path, skip further grouping. |
| if ((sum_last_dim % kLastDimOfTpuInputFastPath) == 0 && |
| (product_other_dims % kOtherDimOfTpuInputFastPath) == 0) { |
| grouped_input_edges_opt[iter.first] = iter.second; |
| continue; |
| } |
| // Tensors are small, skip further grouping. |
| if ((sum_last_dim * product_other_dims) < |
| (kLastDimOfTpuInputFastPath * kOtherDimOfTpuInputFastPath)) { |
| grouped_input_edges_opt[iter.first] = iter.second; |
| continue; |
| } |
| VLOG(3) << "Splitting tensors."; |
| for (const auto& edge : iter.second) { |
| auto tpu_input_shape = tpu_input_shapes.find(edge)->second; |
| string hash_key = |
| HashShapeAndType("input_tensors_dtype_", tpu_input_shape, |
| tpu_input_dtypes.find(edge)->second, |
| /*input_shape_opt*/ true); |
| grouped_input_edges_opt[hash_key].push_back(edge); |
| } |
| } |
| return grouped_input_edges_opt; |
| } |
| |
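| // Groups the edges feeding TPUReplicatedOutput nodes by the inferred shape |
| // and dtype of the tensor they carry, filling `tpu_output_shapes` as a side |
| // effect. Each group is a candidate for the concat-on-TPU / split-on-CPU |
| // output packing pass below. |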
| GroupedEdges GroupTensorsForOutputPacking(Graph* graph, |
| EdgeShapes& tpu_output_shapes, |
| GraphShapeInfo* shape_info) { |
| GroupedEdges shape_to_output; |
| for (const Edge* edge : graph->edges()) { |
| if (edge->IsControlEdge()) continue; |
| |
| // TPU output edge: src is on TPU and dst is on CPU. |
| if (edge->dst()->type_string() != "TPUReplicatedOutput") continue; |
| if (!shape_info->count(edge->src()->name())) continue; |
| |
| // output shapes for hashing |
| std::vector<int>& output_shapes = tpu_output_shapes[edge]; |
| output_shapes.clear(); |
| |
| int output_id = edge->src_output(); |
| auto inferred_shape_vec = shape_info->at(edge->src()->name()); |
| |
| for (int d : inferred_shape_vec.at(output_id).shape.dim_sizes()) { |
| output_shapes.push_back(d); |
| } |
| |
| // Hash Shape and Types. |
| DataType dtype = edge->src()->input_type(output_id); |
| string hash_key = |
| HashShapeAndType("output_tensors_dtype_", output_shapes, dtype, false); |
| |
| shape_to_output[hash_key].push_back(edge); |
| } |
| return shape_to_output; |
| } |
| |
| // Concatenates input tensors on CPU along the last dimension if all other |
| // dimensions are the same, and split them on TPU to reduce input overhead. |
| // `tpu_input_shapes` maps an edge to the shape of its output tensor. |
| // `grouped_input_edges` maps tensor name to all edges output from this tensor. |
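| // Schematically, the rewrite for each group is: |
| //   tensors --> ConcatV2 (CPU) --> [TPUReplicatedInput --> XlaSharding] --> |
| //   SplitV (TPU) --> original TPU consumers, |
| // where the TPUReplicatedInput/XlaSharding pair is only inserted when the |
| // inputs were sharded for XLA SPMD. |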
| Status CreateConcatAndSplitNodesForInputTensor( |
| Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes, |
| const absl::flat_hash_map<std::string, std::vector<const Edge*>>& |
| grouped_input_edges, |
| int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded, |
| const XlaShardingInfoMap& xla_sharding_info, |
| const TpuReplicatedInputInfoMap& tpu_replicated_input_info) { |
| for (const auto& iter : grouped_input_edges) { |
| std::vector<int> last_dim_vec; |
| std::vector<NodeBuilder::NodeOut> concat_nodeouts; |
| absl::flat_hash_map<std::string, int> tensor_to_split_output; |
| int rank; |
| DataType dtype = DT_INVALID; |
| std::string src_name; |
| for (const Edge* edge : iter.second) { |
| src_name = edge->src()->name(); |
| string tensor_name = |
| absl::StrCat(edge->src()->name(), ":", edge->src_output()); |
| // Create Concat / Split pair for a tensor if not exist yet. |
| if (tensor_to_split_output.contains(tensor_name)) continue; |
| tensor_to_split_output[tensor_name] = concat_nodeouts.size(); |
| concat_nodeouts.push_back( |
| NodeBuilder::NodeOut(edge->src(), edge->src_output())); |
| dtype = edge->src()->output_type(edge->src_output()); |
| rank = tpu_input_shapes->at(edge).size(); |
| last_dim_vec.push_back(tpu_input_shapes->at(edge).back()); |
| } |
| |
| const int num_tensors = tensor_to_split_output.size(); |
| VLOG(3) << iter.first << " num_tensors: " << num_tensors; |
| if (num_tensors < minimum_input_tensors_packing) { |
| VLOG(3) << "skip concat/split " << iter.first; |
| continue; |
| } |
| |
| Node* concat_axis_node = nullptr; |
| TensorShape t_shape; |
| Tensor dim_tensor(DT_INT32, t_shape); |
| // Concat and Split at the last dim. |
| dim_tensor.flat<int>()(0) = rank - 1; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", dim_tensor) |
| .Finalize(graph, &concat_axis_node)); |
| |
| Node* concat_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2") |
| .Input(concat_nodeouts) |
| .Input(concat_axis_node) |
| .Attr("T", dtype) |
| .Attr("Tidx", DT_INT32) |
| .Attr("N", num_tensors) |
| .Finalize(graph, &concat_node)); |
| |
| Node* split_dim_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", dim_tensor) |
| .Attr(kTpuReplicateAttr, cluster_name) |
| .Finalize(graph, &split_dim_node)); |
| |
| Node* split_vec_node = nullptr; |
| TensorShape split_vec_shape; |
| split_vec_shape.AddDim(1); |
| split_vec_shape.set_dim(0, last_dim_vec.size()); |
| |
| Tensor split_vec_tensor(DT_INT32, split_vec_shape); |
| for (int i = 0; i < last_dim_vec.size(); ++i) { |
| split_vec_tensor.flat<int>()(i) = last_dim_vec[i]; |
| } |
| VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString(); |
| |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", split_vec_tensor) |
| .Attr(kTpuReplicateAttr, cluster_name) |
| .Finalize(graph, &split_vec_node)); |
| |
| Node* split_node = nullptr; |
| Node* input_to_split_node = concat_node; |
| Node* output_from_concat_node = nullptr; |
| if (xla_spmd_input_sharded && |
| tpu_replicated_input_info.count(src_name) > 0 && |
| xla_sharding_info.count(src_name) > 0) { |
| // Create new TPUReplicatedInput and XLAShardingOp nodes |
| // |
| // Rewrite the graph from: |
| // Concat -> Split |
| // to |
| // Concat -> TPUReplicatedInput -> XlaSharding -> Split |
| Node* tpu_replicated_input = nullptr; |
| Node* xla_sharding_op = nullptr; |
| |
| std::vector<NodeBuilder::NodeOut> replicated_input; |
| replicated_input.push_back(NodeBuilder::NodeOut(concat_node)); |
| |
| // TODO(b/183060455): Add TPUReplicatedInput to all graphs. |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/TPUReplicatedInput"), |
| "TPUReplicatedInput") |
| .Input(replicated_input) |
| .ControlInput(std::get<1>(tpu_replicated_input_info.at(src_name))) |
| .Attr("N", 1) |
| .Attr("T", std::get<0>(tpu_replicated_input_info.at(src_name))) |
| .Attr("index", -1) |
| .Attr("is_mirrored_variable", false) |
| .Attr("is_packed", false) |
| .Finalize(graph, &tpu_replicated_input)); |
| VLOG(2) << "Created new TPUReplicatedInput node " |
| << tpu_replicated_input->DebugString(); |
| |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/XlaSharding"), |
| "XlaSharding") |
| .Input(tpu_replicated_input) |
| .Attr("T", std::get<0>(xla_sharding_info.at(src_name))) |
| .Attr("sharding", std::get<1>(xla_sharding_info.at(src_name))) |
| .Attr("_XlaSharding", std::get<1>(xla_sharding_info.at(src_name))) |
| .Attr("_tpu_replicate", |
| std::get<2>(xla_sharding_info.at(src_name))) |
| .Finalize(graph, &xla_sharding_op)); |
| VLOG(2) << "Created new XLA sharding node " |
| << xla_sharding_op->DebugString(); |
| |
| input_to_split_node = xla_sharding_op; |
| output_from_concat_node = tpu_replicated_input; |
| } |
| // Update the `tpu_input_shapes` mapping: Add the new edge |
| // from concat to split. |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV") |
| .Input(input_to_split_node) |
| .Input(split_vec_node) |
| .Input(split_dim_node) |
| .Attr("T", dtype) |
| .Attr("num_split", num_tensors) |
| .Attr(kTpuReplicateAttr, cluster_name) |
| .Finalize(graph, &split_node)); |
| |
| if (output_from_concat_node == nullptr) |
| output_from_concat_node = split_node; |
| |
| const Edge* concat_to_split = nullptr; |
| for (const Edge* edge : concat_node->out_edges()) |
| if (edge->dst() == output_from_concat_node) { |
| concat_to_split = edge; |
| break; |
| } |
| if (rank > 1) { |
| for (int d = 0; d < rank - 1; ++d) |
| (*tpu_input_shapes)[concat_to_split].push_back( |
| tpu_input_shapes->at(iter.second.back()).at(d)); |
| } |
| (*tpu_input_shapes)[concat_to_split].push_back( |
| std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0)); |
| |
| // Connect split node to original tensor output. |
| for (const Edge* edge : iter.second) { |
| string tensor_name = |
| absl::StrCat(edge->src()->name(), ":", edge->src_output()); |
| int output_index = tensor_to_split_output.at(tensor_name); |
| graph->RemoveEdge(edge); |
| graph->AddEdge(split_node, output_index, edge->dst(), edge->dst_input()); |
| // Update the `tpu_input_shapes` mapping: Remove old edges. |
| tpu_input_shapes->erase(edge); |
| } |
| VLOG(3) << "Concat node: " << concat_node->DebugString(); |
| } |
| return Status::OK(); |
| } |
| |
| // Concatenates input tensors on TPU along the last dimension if all other |
| // dimensions are the same, and split them on CPU to reduce outfeed overhead. |
| // `tpu_inferred_info` maps an edge to the inferred shape of its output tensor. |
| // `shape_to_output` maps tensor name to all edges output from this tensor. |
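| // Schematically, the rewrite for each group is: |
| //   TPU outputs --> ConcatV2 (TPU) --> TPUReplicatedOutput --> SplitV (CPU) |
| //   --> original consumers, |
| // and the original TPUReplicatedOutput nodes are removed. |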
| Status CreateConcatAndSplitNodesForOutputTensor( |
| Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes, |
| GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output, |
| int32_t minimum_output_tensors_packing) { |
| for (const auto& iter : shape_to_output) { |
| std::vector<int> last_dim_vec; |
| std::vector<NodeBuilder::NodeOut> concat_nodeouts; |
| absl::flat_hash_map<std::string, int> tensor_to_split_output; |
| int rank; |
| DataType dtype = DT_INVALID; |
| for (const Edge* edge : iter.second) { |
| string tensor_name = |
| absl::StrCat(edge->src()->name(), ":", edge->src_output()); |
| |
| // Create Concat / Split pair for a tensor if not exist yet. |
| if (tensor_to_split_output.contains(tensor_name)) continue; |
| tensor_to_split_output[tensor_name] = concat_nodeouts.size(); |
| |
| concat_nodeouts.push_back( |
| NodeBuilder::NodeOut(edge->src(), edge->src_output())); |
| dtype = edge->src()->output_type(edge->src_output()); |
| rank = tpu_output_shapes->at(edge).size(); |
| last_dim_vec.push_back(tpu_output_shapes->at(edge).back()); |
| } |
| |
| const int num_tensors = tensor_to_split_output.size(); |
| if (num_tensors < minimum_output_tensors_packing) { |
| VLOG(3) << "skip concat/split " << iter.first; |
| continue; |
| } |
| |
| Node* concat_axis_node = nullptr; |
| TensorShape t_shape; |
| Tensor dim_tensor(DT_INT32, t_shape); |
| // Concat and Split at the last dim. |
| dim_tensor.flat<int>()(0) = rank - 1; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/concat/axis"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", dim_tensor) |
| .Attr(kTpuReplicateAttr, cluster_name) |
| .Finalize(graph, &concat_axis_node)); |
| |
| Node* concat_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/concat"), "ConcatV2") |
| .Input(concat_nodeouts) |
| .Input(concat_axis_node) |
| .Attr("T", dtype) |
| .Attr("Tidx", DT_INT32) |
| .Attr("N", num_tensors) |
| .Attr(kTpuReplicateAttr, cluster_name) |
| .Finalize(graph, &concat_node)); |
| |
| Node* tpu_replicated_output_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/tpu_replicated_output"), |
| "TPUReplicatedOutput") |
| .Input(concat_node) |
| .Attr("T", dtype) |
| .Attr("num_replicas", 1) |
| .Finalize(graph, &tpu_replicated_output_node)); |
| |
| Node* split_dim_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split/split_dim"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", dim_tensor) |
| .Finalize(graph, &split_dim_node)); |
| |
| Node* split_vec_node = nullptr; |
| TensorShape split_vec_shape; |
| split_vec_shape.AddDim(1); |
| split_vec_shape.set_dim(0, last_dim_vec.size()); |
| |
| Tensor split_vec_tensor(DT_INT32, split_vec_shape); |
| for (int i = 0; i < last_dim_vec.size(); ++i) { |
| split_vec_tensor.flat<int>()(i) = last_dim_vec[i]; |
| } |
| VLOG(3) << "split_vec_tensor: " << split_vec_tensor.DebugString(); |
| |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split/vec"), "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", split_vec_tensor) |
| .Finalize(graph, &split_vec_node)); |
| |
| Node* split_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(strings::StrCat(iter.first, "/split"), "SplitV") |
| .Input(tpu_replicated_output_node) |
| .Input(split_vec_node) |
| .Input(split_dim_node) |
| .Attr("T", dtype) |
| .Attr("num_split", num_tensors) |
| .Finalize(graph, &split_node)); |
| |
| // Update the `tpu_out_shapes` mapping: Add the new edge |
| // from concat to split. |
| const Edge* concat_to_split = nullptr; |
| for (const Edge* edge : concat_node->out_edges()) |
| if (edge->dst() == split_node) { |
| concat_to_split = edge; |
| break; |
| } |
| |
| if (rank > 1) (*tpu_output_shapes)[concat_to_split].push_back(-1); |
| for (int d = 1; d < rank - 1; ++d) |
| (*tpu_output_shapes)[concat_to_split].push_back( |
| tpu_output_shapes->at(iter.second.back()).at(d)); |
| (*tpu_output_shapes)[concat_to_split].push_back( |
| std::accumulate(last_dim_vec.begin(), last_dim_vec.end(), 0)); |
| |
| for (const Edge* edge : iter.second) { |
| // 1. Find old TPUReplicatedOutput output edges. |
| std::vector<const Edge*> output_edge_vec; |
| for (const Edge* output_edge : edge->dst()->out_edges()) |
| output_edge_vec.push_back(output_edge); |
| |
| string tensor_name = |
| absl::StrCat(edge->src()->name(), ":", edge->src_output()); |
| int output_index = tensor_to_split_output.at(tensor_name); |
| VLOG(3) << "output_index: " << output_index; |
| |
| // Connect split node to original tensor output. |
| for (const Edge* output_edge : output_edge_vec) { |
| VLOG(3) << "output_edge" << output_edge->DebugString(); |
| graph->RemoveEdge(output_edge); |
| graph->AddEdge(split_node, output_index, output_edge->dst(), |
| output_edge->dst_input()); |
| // Update the `tpu_output_shapes` mapping: Remove old edges. |
| tpu_output_shapes->erase(output_edge); |
| } |
| graph->RemoveNode(edge->dst()); |
| } |
| VLOG(3) << "Concat node: " << concat_node->DebugString(); |
| } |
| return Status::OK(); |
| } |
| |
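| // For each remaining TPU input edge, inserts a pair of Reshape nodes: one on |
| // the host that flattens the tensor to rank 1, and one inside the TPU cluster |
| // that recovers the original shape: |
| //   src --> Reshape (flatten, CPU) --> ... --> Reshape (recover, TPU) --> dst |
| // If the input flows through TPUReplicatedInput and XlaSharding nodes, the |
| // recovering Reshape is placed after the XlaSharding node and the sharding |
| // attribute is rewritten to shard the flattened rank-1 tensor. |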
| Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name, |
| EdgeShapes* tpu_input_shapes, |
| int num_cores_per_replica) { |
| std::vector<const Edge*> tpu_input_edges_original; |
| for (const auto& it : *tpu_input_shapes) |
| if (!it.second.empty()) tpu_input_edges_original.push_back(it.first); |
| for (const Edge* edge : tpu_input_edges_original) { |
| VLOG(3) << "Reshape input: " << edge->DebugString(); |
| |
| // Check if there is a TPUReplicatedInput and XlaSharding in the middle |
| bool xla_sharded_input = false; |
| Node* xla_sharding_node = nullptr; |
| if (edge->dst()->type_string() == "TPUReplicatedInput" && |
| edge->dst()->out_nodes().begin()->type_string() == "XlaSharding") { |
| VLOG(3) << "Detected TPUReplicatedInput " << edge->dst()->DebugString() |
| << " and XlaSharding " |
| << edge->dst()->out_nodes().begin()->DebugString() |
| << ", setting xla_sharded_input = true"; |
| xla_sharded_input = true; |
| xla_sharding_node = *(edge->dst()->out_nodes().begin()); |
| } |
| |
| // 1. Build Reshape node for flatten. |
| |
| // 1.1 Build Const node for shape |
| Node* flatten_reshape_shape_node = nullptr; |
| Tensor flattened_input_shape_tensor; |
| flattened_input_shape_tensor = |
| Tensor(DT_INT32, TensorShape({static_cast<int64>(1)})); |
| flattened_input_shape_tensor.flat<int>()(0) = -1; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape/shape"), |
| "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", flattened_input_shape_tensor) |
| .Finalize(graph, &flatten_reshape_shape_node)); |
| |
| // 1.2 Build Reshape node for flatten. |
| Node* flatten_reshape_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(absl::StrCat(edge->src()->name(), "/flatten/Reshape"), |
| "Reshape") |
| .Input(edge->src(), edge->src_output()) |
| .Input(flatten_reshape_shape_node) |
| .Attr("T", edge->src()->output_type(edge->src_output())) |
| .Attr("Tshape", DT_INT32) |
| .Finalize(graph, &flatten_reshape_node)); |
| |
| // 2. Build Reshape node for recover. |
| |
| // 2.1 Build Const node for shape. |
| Node* recover_reshape_shape_node = nullptr; |
| Tensor original_input_shape_tensor( |
| DT_INT32, |
| TensorShape({static_cast<int64>(tpu_input_shapes->at(edge).size())})); |
| original_input_shape_tensor.flat<int>()(0) = -1; |
| for (int d = 1; d < tpu_input_shapes->at(edge).size(); ++d) |
| original_input_shape_tensor.flat<int>()(d) = |
| tpu_input_shapes->at(edge).at(d); |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape/shape"), |
| "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", original_input_shape_tensor) |
| .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU. |
| .Finalize(graph, &recover_reshape_shape_node)); |
| |
| // 2.2 Build Reshape node for recover. |
| Node* recover_reshape_input_node = flatten_reshape_node; |
| const Edge* original_recover_reshape_input_edge = nullptr; |
| if (xla_sharded_input) { |
| // We want to find the node after the XlaSharding node |
| original_recover_reshape_input_edge = |
| *(edge->dst()->out_nodes().begin()->out_edges().begin()); |
| recover_reshape_input_node = *(edge->dst()->out_nodes().begin()); |
| VLOG(3) << "Recover reshape input node: " |
| << recover_reshape_input_node->DebugString() |
| << ", recover reshape input edge: " |
| << original_recover_reshape_input_edge->DebugString(); |
| } |
| |
| Node* recover_reshape_node = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(absl::StrCat(edge->src()->name(), "/recover/Reshape"), |
| "Reshape") |
| .Input(recover_reshape_input_node) |
| .Input(recover_reshape_shape_node) |
| .Attr("T", edge->src()->output_type(edge->src_output())) |
| .Attr("Tshape", DT_INT32) |
| .Attr(kTpuReplicateAttr, cluster_name) // This node is on TPU. |
| .Finalize(graph, &recover_reshape_node)); |
| |
| // 3. Rewrite XlaSharding attribute if necessary |
| if (xla_sharding_node != nullptr) { |
| // The flattened tensor always has rank = 1 and we want to shard the only |
| // dimension (0). |
| SetXlaShardingNodeAttr(xla_sharding_node, num_cores_per_replica, 1, 0); |
| } |
| |
| // 4. Connect / disconnect nodes. |
| if (xla_sharded_input) { |
| graph->AddEdge(flatten_reshape_node, 0, edge->dst(), edge->dst_input()); |
| } |
| |
| if (original_recover_reshape_input_edge != nullptr) { |
| graph->AddEdge(recover_reshape_node, 0, |
| original_recover_reshape_input_edge->dst(), |
| original_recover_reshape_input_edge->dst_input()); |
| } else { |
| graph->AddEdge(recover_reshape_node, 0, edge->dst(), edge->dst_input()); |
| } |
| |
| graph->RemoveEdge(edge); |
| if (original_recover_reshape_input_edge != nullptr) { |
| graph->RemoveEdge(original_recover_reshape_input_edge); |
| } |
| |
| // 5. Update EdgeShapes. |
| int dimension = 1; |
| for (auto& it : (*tpu_input_shapes)[edge]) { |
| dimension *= it; |
| } |
| VLOG(3) << "Dimension after reshape: " << dimension; |
| for (const Edge* out_edge : flatten_reshape_node->out_edges()) { |
| if (out_edge->dst() == recover_reshape_node) { |
| (*tpu_input_shapes)[out_edge].push_back(dimension); |
| tpu_input_shapes->erase(edge); |
| break; |
| } |
| } |
| VLOG(3) << "Reshape optimization done for " << edge->src()->name(); |
| } |
| return Status::OK(); |
| } |
| } // namespace tpu_functional_internal |
| |
| void TPUPartitionedCallOp::ComputeAsync(OpKernelContext* ctx, |
| DoneCallback done) { |
| Status init_status; |
| absl::call_once(once_, [&]() { |
| library_runtime_ = ctx->function_library(); |
| if (library_runtime_ == nullptr) { |
| init_status = errors::Internal("No function library is provided."); |
| return; |
| } |
| flib_def_ = std::make_unique<FunctionLibraryDefinition>( |
| *library_runtime_->GetFunctionLibraryDefinition()); |
| device_mgr_ = library_runtime_->device_mgr(); |
| for (auto d : device_mgr_->ListDevices()) { |
| device_set_.AddDevice(d); |
| } |
| |
| DeviceNameUtils::ParsedName tpu_device_name; |
| tpu_device_name.has_type = true; |
| tpu_device_name.type = "TPU"; |
| device_set_.FindMatchingDevices(tpu_device_name, &tpu_devices_); |
| }); |
| OP_REQUIRES_OK_ASYNC(ctx, init_status, done); |
| |
| // Initialize the ordinal selector with information from the graph if it is |
| // the first time we are running this op. |
| absl::call_once(ordinal_selector_once_, [&]() { |
| std::unique_ptr<Graph> graph(new Graph(flib_def_.get())); |
| int num_cores_per_replica = 1; |
| bool enable_spmd_xla_partitioning = false; |
| { |
| absl::MutexLock l(&mu_); |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| GetGraphFromFunction(graph.get(), /*device_ordinal=*/0, |
| &num_cores_per_replica, |
| &enable_spmd_xla_partitioning), |
| done); |
| } |
| if (enable_spmd_xla_partitioning) { |
| ordinal_selector_ = |
| std::make_shared<tpu::TPUOrdinalSelector>(num_cores_per_replica); |
| } else { |
| ordinal_selector_ = std::make_shared<tpu::TPUOrdinalSelector>(); |
| } |
| |
| metrics::RecordTPUXlaSpmdCoresPerReplica(num_cores_per_replica); |
| }); |
| uint64 input_hash = GetInputHash(ctx); |
| int64_t ordinal_selector_req_id = -1; |
| // Select a TPU core. |
| absl::ReleasableMutexLock lock(&mu_); |
| int32_t device_ordinal = 0; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| GetTpuCoreOrdinal(ctx, input_hash, &ordinal_selector_req_id, |
| &device_ordinal), |
| done); |
| uint64 cache_hash = Hash64Combine(input_hash, device_ordinal); |
| |
| const std::vector<DeviceAndFHandle>* functions; |
| |
| bool cache_miss = !partition_cache_.count(cache_hash); |
| if (cache_miss) { |
| VLOG(3) << "Cache Miss: partitioning function " << func_.name() |
| << " cache_hash: " << cache_hash |
| << " device_ordinal: " << device_ordinal; |
| std::unique_ptr<Graph> graph(new Graph(flib_def_.get())); |
| int num_cores_per_replica = 1; |
| bool enable_spmd_xla_partitioning = false; |
| OP_REQUIRES_OK_ASYNC(ctx, |
| GetGraphFromFunction(graph.get(), device_ordinal, |
| &num_cores_per_replica, |
| &enable_spmd_xla_partitioning), |
| done); |
| |
| VLOG(1) << DumpGraphToFile("before_input_output_optimizations", *graph, |
| flib_def_.get()); |
| |
| std::map<std::string, std::vector<int>> named_input_shapes; |
| OP_REQUIRES_OK_ASYNC(ctx, |
| OptimizeTpuInputOutputTensors( |
| graph.get(), enable_spmd_xla_partitioning, |
| num_cores_per_replica, named_input_shapes, ctx), |
| done); |
| |
| VLOG(1) << DumpGraphToFile( |
| "before_replace_resource_args_with_var_handle_ops", *graph, |
| flib_def_.get()); |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| ReplaceResourceArgsWithVarHandleOps(graph.get(), ctx, device_ordinal, |
| num_cores_per_replica, |
| enable_spmd_xla_partitioning), |
| done); |
| |
| VLOG(1) << DumpGraphToFile( |
| "after_replace_resource_args_with_var_handle_ops", *graph, |
| flib_def_.get()); |
| |
| // Graph rewrite passes. |
| GraphOptimizationPassOptions optimization_options; |
| // TODO(akshayka): Thread the SessionOptions into this kernel, or make |
| // it possible to specify the relevant options via attributes. |
| SessionOptions session_options; |
| session_options.config.mutable_experimental() |
| ->set_xla_fusion_autotuner_thresh(autotuner_thresh_); |
| |
| session_options.env = ctx->env(); |
| optimization_options.session_handle = ctx->session_handle(); |
| optimization_options.session_options = &session_options; |
| optimization_options.graph = &graph; |
| optimization_options.flib_def = flib_def_.get(); |
| optimization_options.device_set = &device_set_; |
| OP_REQUIRES_OK_ASYNC( |
| ctx, PlacementHelper(device_set_, optimization_options, func_.name()), |
| done); |
| |
| if (!enable_spmd_xla_partitioning || num_cores_per_replica == 1) { |
| OP_REQUIRES_OK_ASYNC( |
| ctx, |
| MaybeRegisterFingerprint(graph.get(), named_input_shapes, input_hash), |
| done); |
| } |
| // `subgraphs` maps from device names to functions. |
| std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs; |
| optimization_options.graph = nullptr; |
| optimization_options.device_set = nullptr; |
| optimization_options.partition_graphs = &subgraphs; |
| VLOG(1) << DumpGraphToFile("before_partition_helper.pbtxt", *graph, |
| flib_def_.get()); |
| OP_REQUIRES_OK_ASYNC(ctx, |
| PartitionHelper(device_set_, optimization_options, |
| graph.get(), &subgraphs), |
| done); |
| OP_REQUIRES_OK_ASYNC(ctx, |
| InstantiateFunctionsFromSubgraphs( |
| device_set_, device_ordinal, cache_hash, |
| num_cores_per_replica, std::move(subgraphs)), |
| done); |
| } |
| functions = &partition_cache_[cache_hash]; |
| lock.Release(); |
| |
| ExecuteFunctions(*functions, ctx, device_ordinal, ordinal_selector_req_id, |
| std::move(done)); |
| } |
| |
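| // Reads the device ordinal from the op's "device_ordinal" input. If it is |
| // the reserved value for deferred core selection, asks the ordinal selector |
| // to pick a core based on `input_hash`. |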
| Status TPUPartitionedCallOp::GetTpuCoreOrdinal(OpKernelContext* ctx, |
| uint64 input_hash, |
| int64_t* ordinal_selector_req_id, |
| int32_t* core_ordinal) { |
| profiler::TraceMe trace_me("TPUPartitionedCallOp-GetTpuCoreOrdinal"); |
| const Tensor* device_ordinal_t; |
| TF_RETURN_IF_ERROR(ctx->input(kDeviceOrdinalAttr, &device_ordinal_t)); |
| int device_ordinal = device_ordinal_t->scalar<int>()(); |
| if (device_ordinal == tpu::kDeferredCoreSelectionReserved) { |
| device_ordinal = |
| ordinal_selector_->GetOrdinal(input_hash, ordinal_selector_req_id); |
| } |
| *core_ordinal = device_ordinal; |
| return Status::OK(); |
| } |
| |
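| // Copies the current host value of `var` onto TPU device `device_ordinal` by |
| // building and running a small initialization graph: the variable handle node |
| // from `ndef`, a Const (or _TPUConst in fast memory when `fast_mem` is set) |
| // holding the value, and an AssignVariableOp. The temporary init function is |
| // removed from the function library afterwards. |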
| Status TPUPartitionedCallOp::InitializeVarOnTPU( |
| OpKernelContext* ctx, const core::RefCountPtr<Var>& var, NodeDef* ndef, |
| int device_ordinal, bool fast_mem) { |
| const string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal); |
| Status status; |
| std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global())); |
| Node* init_handle = init_graph->AddNode(*ndef, &status); |
| TF_RETURN_IF_ERROR(status); |
| init_handle->set_assigned_device_name(device); |
| |
| NodeDef init_const_ndef; |
| init_const_ndef.set_name("initial_value"); |
| if (fast_mem) { |
| init_const_ndef.set_op("_TPUConst"); |
| AddNodeAttr("memory_space", "FastMem", &init_const_ndef); |
| } else { |
| init_const_ndef.set_op("Const"); |
| } |
| init_const_ndef.set_device(device); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef); |
| AddNodeAttr("value", *var->tensor(), &init_const_ndef); |
| |
| Node* init_const = init_graph->AddNode(init_const_ndef, &status); |
| TF_RETURN_IF_ERROR(status); |
| |
| NodeDef assign_node_def; |
| assign_node_def.set_name("Assign"); |
| assign_node_def.set_op("AssignVariableOp"); |
| assign_node_def.set_device(device); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def); |
| Node* init_assign = init_graph->AddNode(assign_node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| |
| init_graph->AddEdge(init_handle, 0, init_assign, 0); |
| init_graph->AddEdge(init_const, 0, init_assign, 1); |
| FHandle fhandle; |
| const string fname = |
| strings::StrCat(ndef->name(), "_init_ord_", device_ordinal); |
| |
| TF_RETURN_IF_ERROR( |
| InstantiatePartition(*init_graph, fname, device, &fhandle, nullptr)); |
| |
| FunctionLibraryRuntime::Options opts; |
| opts.step_container = ctx->step_container(); |
| opts.cancellation_manager = ctx->cancellation_manager(); |
| opts.stats_collector = ctx->stats_collector(); |
| |
| // Blocking on threads in the same thread pool is disallowed because |
| // concurrent warm-up requests can exhaust the default thread pool. |
| // Create a new thread pool to initialize variables on TPU. |
| std::function<void(std::function<void()>)> runner = |
| [this](std::function<void()> fn) { pool_.Schedule(fn); }; |
| opts.runner = &runner; |
| |
| opts.source_device = local_device_name_; |
| PrivateIntraProcessRendezvous rendez(device_mgr_); |
| opts.rendezvous = &rendez; |
| opts.remote_execution = true; |
| |
| std::vector<Tensor> dummy_args; |
| std::vector<Tensor>* dummy_rets = new std::vector<Tensor>; |
| Notification done; |
| profiler::TraceMe trace_me("TPUPartitionedCallOp-InitializeVarOnTPU"); |
| library_runtime_->Run(opts, fhandle, dummy_args, dummy_rets, |
| [dummy_rets, &done, ctx](const Status& status) { |
| if (!status.ok()) { |
| ctx->SetStatus(status); |
| } |
| delete dummy_rets; |
| done.Notify(); |
| }); |
| done.WaitForNotification(); |
| // We don't actually want the variable initialization functions |
| // in the function library definition and the function library |
| // runtime, because flib_def_ is used for the graph rewrite passes. |
| // The TPU distributed rewrite pass computes a fingerprint for |
| // flib_def_, which will throw a length error if there are |
| // many variables whose initialization functions are added |
| // to the library definition. |
| TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fname)); |
| TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(fhandle)); |
| return Status::OK(); |
| } |
| |
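| // Initializes a variable that is sharded across `ndefs.size()` TPU cores |
| // starting at `device_ordinal`. The host value is split along `split_dim` on |
| // CPU (or used as-is when `split_dim` < 0) and each piece is assigned to the |
| // per-core variable handle built from `ndefs`. |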
| Status TPUPartitionedCallOp::InitializeShardedVarOnTPU( |
| OpKernelContext* ctx, const core::RefCountPtr<Var>& var, |
| std::vector<NodeDef>& ndefs, int split_dim, int device_ordinal) { |
| std::unique_ptr<Graph> init_graph(new Graph(OpRegistry::Global())); |
| int num_cores = ndefs.size(); |
| string cpu_device = "/device:CPU:0"; |
| |
| Status status; |
| std::vector<std::string> devices; |
| std::vector<Node*> init_handles; |
| for (int i = 0; i < num_cores; i++) { |
| Node* init_handle = init_graph->AddNode(ndefs[i], &status); |
| TF_RETURN_IF_ERROR(status); |
| string device = strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + i); |
| init_handle->set_assigned_device_name(device); |
| devices.push_back(device); |
| init_handles.push_back(init_handle); |
| } |
| |
| NodeDef init_const_ndef; |
| init_const_ndef.set_name("initial_value"); |
| init_const_ndef.set_op("Const"); |
| init_const_ndef.set_device(cpu_device); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &init_const_ndef); |
| AddNodeAttr("value", *var->tensor(), &init_const_ndef); |
| Node* init_const = init_graph->AddNode(init_const_ndef, &status); |
| init_const->set_assigned_device_name(cpu_device); |
| TF_RETURN_IF_ERROR(status); |
| |
| Node* assign_value_node = init_const; |
| // If the variable is sharded, we will insert a "Split" node between the |
| // initial value and the AssignVariableOps, so the variable on each TPU device |
| // gets assigned its split of the value. |
| // |
| // initial_value--Split--AssignVariableOp ("/device:TPU:0") |
| // | |
| // AssignVariableOp ("/device:TPU:1") |
| if (split_dim >= 0) { |
| // Add a split dimension node. |
| NodeDef split_dim_def; |
| split_dim_def.set_name("initial_value_split_dim"); |
| split_dim_def.set_op("Const"); |
| split_dim_def.set_device(cpu_device); |
| AddNodeAttr("dtype", DT_INT32, &split_dim_def); |
| TensorProto tensor_proto; |
| tensor_proto.set_dtype(DT_INT32); |
| tensor_proto.add_int_val(split_dim); |
| TensorShape shape({}); |
| shape.AsProto(tensor_proto.mutable_tensor_shape()); |
| AddNodeAttr("value", tensor_proto, &split_dim_def); |
| Node* split_dim_node = init_graph->AddNode(split_dim_def, &status); |
| split_dim_node->set_assigned_device_name(cpu_device); |
| TF_RETURN_IF_ERROR(status); |
| |
| // Add a split node. |
| NodeDef split_def; |
| int split_num = ndefs.size(); |
| split_def.set_name("initial_value_split"); |
| split_def.set_op("Split"); |
| split_def.set_device(cpu_device); |
| AddNodeAttr("num_split", split_num, &split_def); |
| AddNodeAttr("T", var->tensor()->dtype(), &split_def); |
| split_def.add_input(absl::StrCat(split_dim_node->name(), ":0")); |
| split_def.add_input(absl::StrCat(init_const->name(), ":0")); |
| Node* split_node = init_graph->AddNode(split_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| split_node->set_assigned_device_name(cpu_device); |
| |
| init_graph->AddEdge(split_dim_node, 0, split_node, 0); |
| init_graph->AddEdge(init_const, 0, split_node, 1); |
| |
| assign_value_node = split_node; |
| } |
| |
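| // Create one AssignVariableOp per core. When the variable is sharded, each |
| // assign consumes slice i of the Split node's output; otherwise every core |
| // receives the full, replicated initial value. |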
| for (int i = 0; i < num_cores; i++) { |
| NodeDef assign_node_def; |
| assign_node_def.set_name(absl::StrCat("Assign_", i)); |
| assign_node_def.set_op("AssignVariableOp"); |
| assign_node_def.set_device(devices[i]); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &assign_node_def); |
| Node* init_assign = init_graph->AddNode(assign_node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| init_assign->set_assigned_device_name(devices[i]); |
| |
| init_graph->AddEdge(init_handles[i], 0, init_assign, 0); |
| if (split_dim >= 0) { |
| init_graph->AddEdge(assign_value_node, i, init_assign, 1); |
| } else { |
| init_graph->AddEdge(assign_value_node, 0, init_assign, 1); |
| } |
| } |
| |
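| // Partition the initialization graph by assigned device and instantiate one |
| // function per partition, so each shard can be run on its target device. |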
| GraphOptimizationPassOptions optimization_options; |
| SessionOptions session_options; |
| session_options.env = ctx->env(); |
| optimization_options.session_handle = ctx->session_handle(); |
| optimization_options.session_options = &session_options; |
| optimization_options.flib_def = flib_def_.get(); |
| optimization_options.graph = nullptr; |
| optimization_options.device_set = nullptr; |
| std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs; |
| optimization_options.partition_graphs = &subgraphs; |
| TF_RETURN_IF_ERROR(PartitionHelper(device_set_, optimization_options, |
| init_graph.get(), &subgraphs)); |
| |
| std::vector<DeviceAndFHandle> functions; |
| std::vector<std::string> function_names; |
| for (auto& pair : subgraphs) { |
| string target = pair.first; |
| Device* device; |
| TF_RETURN_IF_ERROR( |
| library_runtime_->device_mgr()->LookupDevice(target, &device)); |
| Graph* subgraph = pair.second.get(); |
| string function_name = flib_def_->UniqueFunctionName( |
| strings::StrCat(func_.name(), "_hash_", pair.first)); |
| function_names.push_back(function_name); |
| FHandle handle; |
| TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target, |
| &handle, nullptr)); |
| functions.push_back(DeviceAndFHandle{.device = target, .handle = handle}); |
| } |
| |
| FunctionLibraryRuntime::Options opts; |
| |
| // Blocking on threads in the same thread pool is disallowed because |
| // concurrent warm-up requests can exhaust the default thread pool. |
| // Create a new thread pool to initialize variables on TPU. |
| std::function<void(std::function<void()>)> runner = |
| [this](std::function<void()> fn) { pool_.Schedule(fn); }; |
| opts.runner = &runner; |
| |
| opts.step_container = ctx->step_container(); |
| opts.cancellation_manager = ctx->cancellation_manager(); |
| opts.stats_collector = ctx->stats_collector(); |
| opts.source_device = local_device_name_; |
| opts.run_all_kernels_inline = ctx->run_all_kernels_inline(); |
| |
| OpInputList arguments; |
| TF_RETURN_IF_ERROR(ctx->input_list("args", &arguments)); |
| |
| auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_); |
| opts.rendezvous = rendez; |
| |
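| // Run every per-device initialization function and wait for all of them to |
| // complete before removing the temporary functions and releasing handles. |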
| BlockingCounter bcount(functions.size()); |
| for (const DeviceAndFHandle& entry : functions) { |
| const string& target_device = entry.device; |
| FHandle handle = entry.handle; |
| |
| TF_RETURN_IF_ERROR( |
| ShouldUseRemoteExecutionForFn(target_device, &(opts.remote_execution))); |
| std::vector<Tensor> dummy_args; |
| std::vector<Tensor>* dummy_rets = new std::vector<Tensor>; |
| |
| profiler::TraceMe trace_me( |
| "TPUPartitionedCallOp-InitializeShardedVarOnTPU"); |
| library_runtime_->Run(opts, handle, dummy_args, dummy_rets, |
| [dummy_rets, &bcount, ctx](const Status& status) { |
| if (!status.ok()) { |
| ctx->SetStatus(status); |
| } |
| delete dummy_rets; |
| bcount.DecrementCount(); |
| }); |
| } |
| bcount.Wait(); |
| |
| for (int i = 0; i < functions.size(); i++) { |
| TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(function_names[i])); |
| TF_RETURN_IF_ERROR(library_runtime_->ReleaseHandle(functions[i].handle)); |
| } |
| return Status::OK(); |
| } |
| |
| bool TPUPartitionedCallOp::IsInputToTPUReplicate(Node* node) { |
| for (Node* successor : node->out_nodes()) { |
| if (successor->attrs().Find(kTpuReplicateAttr) != nullptr) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| Status TPUPartitionedCallOp::ReplaceResourceArgsWithVarHandleOps( |
| Graph* graph, OpKernelContext* ctx, int device_ordinal, |
| int num_cores_per_replica, bool enable_spmd_xla_partitioning) { |
| // Currently variable deduplication is not supported for XLA SPMD |
| // partitioning. It is possible that it could be supported in the future. |
| const bool enable_variable_deduplication = |
| runtime_params_.enable_variable_deduplication; |
| if (enable_spmd_xla_partitioning && enable_variable_deduplication) { |
| // If enable_spmd_xla_partitioning is true, the user set the |
| // enable_auto_xla_input_sharding flag. Only one of the two flags can be |
| // set safely, so return an error. |
| return errors::InvalidArgument( |
| "The following flags are incompatible: enable_auto_xla_input_sharding " |
| "and enable_variable_deduplication. Only enable one of the flags."); |
| } |
| std::vector<Node*> tpu_resource_args; |
| std::vector<int> arg_indices; |
| absl::flat_hash_map<const Node*, xla::OpSharding> variable_to_xla_sharding; |
| for (Node* node : graph->op_nodes()) { |
| if (node->IsArg()) { |
| const AttrValue* attr_value; |
| TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value)); |
| DataType dtype = attr_value->type(); |
| if (dtype == DT_RESOURCE && IsInputToTPUReplicate(node)) { |
| // If this resource arg is used by a TPU computation, |
| // we need to create a TPU version of the variable. |
| TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value)); |
| int index = attr_value->i(); |
| tpu_resource_args.push_back(node); |
| arg_indices.push_back(index); |
| replaced_input_indices_[index] = true; |
| } |
| } |
| } |
| |
| VLOG(3) << "tpu_resource_args.size(): " << tpu_resource_args.size(); |
| // Create a mapping from ResourceHandle to variable node. When a |
| // ResourceHandle backs several variable nodes, the variable nodes refer to |
| // the same underlying resource. In that case, only one variable node needs |
| // to be mirrored to the TPU for that resource. |
| absl::flat_hash_map<uint64, Node*> tpu_variables; |
| for (int i = 0; i < tpu_resource_args.size(); i++) { |
| Node* node = tpu_resource_args[i]; |
| ResourceHandle handle = HandleFromInput(ctx, arg_indices[i]); |
| |
| if (num_cores_per_replica > 1 && enable_spmd_xla_partitioning) { |
| TF_RETURN_IF_ERROR(ReplaceAndPartitionXLAShardingVariable( |
| graph, ctx, device_ordinal, handle, node, num_cores_per_replica)); |
| continue; |
| } |
| TPUVariableInfo var_info(/*device_ordinal_id=*/0, /*use_fast_mem=*/false); |
| TF_RETURN_IF_ERROR( |
| ParseTPUVariableInfor(node, num_cores_per_replica, &var_info)); |
| // Only respect the graph's placement when model parallelism is enabled. |
| if (num_cores_per_replica > 1) device_ordinal = var_info.device_ordinal; |
| |
| const uint64 handle_fp = |
| Fingerprint64(strings::StrCat(handle.container(), handle.name())); |
| if (enable_variable_deduplication && tpu_variables.contains(handle_fp) && |
| num_cores_per_replica == 1) { |
| Node* tpu_variable = tpu_variables.at(handle_fp); |
| std::vector<Node*> dst_nodes; |
| std::vector<int> src_indices; |
| std::vector<int> dst_indices; |
| for (const Edge* edge : node->out_edges()) { |
| dst_nodes.push_back(edge->dst()); |
| src_indices.push_back(edge->src_output()); |
| dst_indices.push_back(edge->dst_input()); |
| } |
| graph->RemoveNode(node); |
| for (int i = 0; i < dst_nodes.size(); i++) { |
| graph->AddEdge(tpu_variable, src_indices[i], dst_nodes[i], |
| dst_indices[i]); |
| } |
| } else { |
| uint64 fp = |
| Fingerprint64(strings::StrCat(handle.container(), handle.name(), i)); |
| NodeDef ndef; |
| ndef.set_name(strings::StrCat(handle.name(), fp)); |
| ndef.set_op(kVarHandleOp); |
| if (num_cores_per_replica > 1) { |
| ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, device_ordinal)); |
| } else { |
| // Assign this new VarHandleOp to TPU:0 so the partitioner only |
| // partitions the graph into two subgraphs, one on CPU and one on TPU. |
| // The actual device ordinal on which this VarHandleOp runs is assigned |
| // after partitioning (in SetDeviceOrdinal). |
| ndef.set_device( |
| strings::StrCat(kTPUDeviceNamePrefix, kTPUDefaultDeviceOrdinal)); |
| } |
| |
| // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node |
| // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the |
| // shared_name of the variable on CPU and "x" is the rewritten device |
| // ordinal. |
| const string sname = |
| strings::StrCat(handle.name(), "_tpu_", device_ordinal); |
| AddNodeAttr("shared_name", sname, &ndef); |
| const string cname = ctx->resource_manager()->default_container(); |
| AddNodeAttr("container", cname, &ndef); |
| core::RefCountPtr<Var> var; |
| TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var)); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &ndef); |
| TensorShapeProto proto; |
| var->tensor()->shape().AsProto(&proto); |
| AddNodeAttr("shape", proto, &ndef); |
| Status status; |
| Node* new_node = graph->AddNode(ndef, &status); |
| TF_RETURN_IF_ERROR(status); |
| std::vector<const Edge*> in_edges(node->in_edges().begin(), |
| node->in_edges().end()); |
| for (const Edge* edge : in_edges) { |
| graph->AddEdge(edge->src(), edge->src_output(), new_node, |
| edge->dst_input()); |
| } |
| std::vector<Node*> dst_nodes; |
| std::vector<int> src_indices; |
| std::vector<int> dst_indices; |
| for (const Edge* edge : node->out_edges()) { |
| dst_nodes.push_back(edge->dst()); |
| src_indices.push_back(edge->src_output()); |
| dst_indices.push_back(edge->dst_input()); |
| } |
| graph->RemoveNode(node); |
| for (int i = 0; i < dst_nodes.size(); i++) { |
| graph->AddEdge(new_node, src_indices[i], dst_nodes[i], dst_indices[i]); |
| } |
| // Don't initialize the variable on TPU if it has already been done |
| // for this ordinal. |
| if (seen_ordinals_.contains(device_ordinal)) continue; |
| |
| Device* d; |
| TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice( |
| strings::StrCat(kTPUDeviceNamePrefix, device_ordinal), &d)); |
| Var* tpu_var; |
| status = d->resource_manager()->Lookup(cname, sname, &tpu_var); |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR(InitializeVarOnTPU(ctx, var, &ndef, device_ordinal, |
| var_info.fast_mem)); |
| } |
| tpu_variables[handle_fp] = new_node; |
| } |
| } |
| |
| // Adjust the index attr of the remaining (non-resource) arg nodes. |
| int new_index = 0; |
| for (Node* node : graph->op_nodes()) { |
| if (node->IsArg()) { |
| node->ClearAttr("index"); |
| node->AddAttr("index", new_index); |
| new_index++; |
| } |
| } |
| |
| seen_ordinals_.insert(device_ordinal); |
| |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::ReplaceAndPartitionXLAShardingVariable( |
| Graph* graph, OpKernelContext* ctx, int device_ordinal, |
| ResourceHandle& handle, Node* variable, int num_cores_per_replica) { |
| TF_ASSIGN_OR_RETURN( |
| auto sharding, |
| GetShardingFromNodeDef(variable->def(), /*add_metadata=*/false)); |
| xla::OpSharding xla_sharding; |
| bool is_var_sharded = false; |
| if (sharding.has_value() && |
| sharding.value().type() == xla::OpSharding::OTHER) { |
| xla_sharding = sharding.value(); |
| is_var_sharded = true; |
| } else { |
| xla_sharding.set_type(xla::OpSharding::REPLICATED); |
| is_var_sharded = false; |
| } |
| VLOG(3) << "Replace and partition variable " << variable->name() |
| << " with xla_sharding: " << xla_sharding.DebugString(); |
| |
| core::RefCountPtr<Var> var; |
| TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &var)); |
| |
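| // Derive the (single) split dimension and the number of splits from the |
| // tile assignment; only one dimension may have more than one tile. |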
| int split_dim = -1; |
| int split_size = 0; |
| if (is_var_sharded) { |
| for (int dim = 0; dim < xla_sharding.tile_assignment_dimensions_size(); |
| dim++) { |
| if (xla_sharding.tile_assignment_dimensions(dim) > 1) { |
| if (split_dim != -1) { |
| return errors::InvalidArgument( |
| "Currently we only support inference with one split dimension, " |
| "but got sharding: ", |
| xla_sharding.DebugString()); |
| } |
| split_dim = dim; |
| split_size = xla_sharding.tile_assignment_dimensions(dim); |
| } |
| } |
| } |
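| // Create one VarHandleOp per core of the replica. When the variable is |
| // sharded, the handle's shape has the split dimension reduced to |
| // dim_size / split_size. |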
| const string cname = ctx->resource_manager()->default_container(); |
| std::vector<Node*> per_core_vars; |
| for (int core_index = device_ordinal; |
| core_index < (device_ordinal + num_cores_per_replica); core_index++) { |
| NodeDef ndef; |
| uint64 fp = Fingerprint64( |
| strings::StrCat(handle.container(), handle.name(), "_", core_index)); |
| ndef.set_name(strings::StrCat(handle.name(), fp)); |
| ndef.set_op(kVarHandleOp); |
| ndef.set_device(strings::StrCat(kTPUDeviceNamePrefix, core_index)); |
| |
| // Replace each _Arg node of type DT_RESOURCE that goes into a TPU node |
| // by a VarHandleOp on TPU with shared_name "v_tpu_x" where "v" is the |
| // shared_name of the variable on CPU and "x" is the rewritten device |
| // ordinal. |
| const string sname = strings::StrCat(handle.name(), "_tpu_", core_index); |
| AddNodeAttr("shared_name", sname, &ndef); |
| AddNodeAttr("container", cname, &ndef); |
| AddNodeAttr("dtype", var->tensor()->dtype(), &ndef); |
| |
| TensorShapeProto proto; |
| var->tensor()->shape().AsProto(&proto); |
| |
| if (is_var_sharded) { |
| int dim_size = proto.dim(split_dim).size(); |
| if (dim_size % split_size != 0) { |
| return errors::InvalidArgument("dimension size ", dim_size, |
| " is not divisible by split size ", |
| split_size); |
| } |
| proto.mutable_dim(split_dim)->set_size(dim_size / split_size); |
| } |
| AddNodeAttr("shape", proto, &ndef); |
| |
| Status status; |
| Node* new_node = graph->AddNode(ndef, &status); |
| TF_RETURN_IF_ERROR(status); |
| per_core_vars.push_back(new_node); |
| } |
| |
| // Insert TPUPartitionedInput op. |
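| // The rewritten variable is then consumed through the following chain: |
| //   per-core VarHandleOps -> TPUPartitionedInput -> TPUReplicatedInput |
| //   -> original consumers of the variable. |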
| NodeDefBuilder builder(absl::StrCat(handle.name(), "/tpu_partitioned_input"), |
| "TPUPartitionedInput"); |
| builder.Attr("N", num_cores_per_replica); |
| builder.Attr("T", DT_RESOURCE); |
| builder.Attr("partition_dim", split_dim); |
| builder.Attr("_XlaSharding", xla_sharding.SerializeAsString()); |
| std::vector<NodeDefBuilder::NodeOut> inputs; |
| inputs.reserve(num_cores_per_replica); |
| for (int core_index = 0; core_index < num_cores_per_replica; core_index++) { |
| inputs.push_back({per_core_vars[core_index]->name(), 0, DT_RESOURCE}); |
| } |
| builder.Input(inputs); |
| NodeDef node_def; |
| TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); |
| Status s; |
| Node* tpu_partitioned_input_node = graph->AddNode(node_def, &s); |
| if (!s.ok()) { |
| return s; |
| } |
| |
| for (int core_index = 0; core_index < num_cores_per_replica; core_index++) { |
| graph->AddEdge(per_core_vars[core_index], 0, tpu_partitioned_input_node, |
| core_index); |
| } |
| |
| // Insert TPUReplicatedInput op. |
| NodeDefBuilder replicated_builder( |
| absl::StrCat(handle.name(), "/tpu_replicated_input"), |
| "TPUReplicatedInput"); |
| replicated_builder.Attr("N", 1); |
| replicated_builder.Attr("T", DT_RESOURCE); |
| replicated_builder.Attr("is_mirrored_variable", true); |
| std::vector<NodeDefBuilder::NodeOut> replicated_inputs; |
| replicated_inputs.push_back( |
| {tpu_partitioned_input_node->name(), 0, DT_RESOURCE}); |
| replicated_builder.Input(replicated_inputs); |
| NodeDef replicated_node_def; |
| TF_RETURN_IF_ERROR(replicated_builder.Finalize(&replicated_node_def)); |
| Status replicated_s; |
| Node* tpu_replicated_input_node = |
| graph->AddNode(replicated_node_def, &replicated_s); |
| if (!replicated_s.ok()) { |
| return replicated_s; |
| } |
| graph->AddEdge(tpu_partitioned_input_node, 0, tpu_replicated_input_node, 0); |
| |
| // Connect the TPUReplicatedInput node to the previous output nodes of the |
| // variable, and remove the variable node. |
| std::vector<Node*> dst_nodes; |
| std::vector<int> src_indices; |
| std::vector<int> dst_indices; |
| for (const Edge* edge : variable->out_edges()) { |
| dst_nodes.push_back(edge->dst()); |
| src_indices.push_back(edge->src_output()); |
| dst_indices.push_back(edge->dst_input()); |
| } |
| for (int i = 0; i < dst_nodes.size(); i++) { |
| graph->AddEdge(tpu_replicated_input_node, src_indices[i], dst_nodes[i], |
| dst_indices[i]); |
| } |
| |
| graph->RemoveNode(variable); |
| |
| std::vector<NodeDef> ndefs; |
| Status status; |
| for (int core_index = 0; core_index < num_cores_per_replica; core_index++) { |
| Device* d; |
| TF_RETURN_IF_ERROR(library_runtime_->device_mgr()->LookupDevice( |
| strings::StrCat(kTPUDeviceNamePrefix, device_ordinal + core_index), |
| &d)); |
| string sname; |
| const NodeDef& ndef = per_core_vars[core_index]->def(); |
| TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &sname)); |
| ndefs.push_back(ndef); |
| Var* tpu_var; |
| status = d->resource_manager()->Lookup(cname, sname, &tpu_var); |
| } |
| |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR( |
| InitializeShardedVarOnTPU(ctx, var, ndefs, split_dim, device_ordinal)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::InferShapesWithResourceVar( |
| Graph* graph, OpKernelContext* ctx, |
| std::map<int, InferredShape>& arg_shapes, |
| GraphShapeInfo* tpu_inferred_info) { |
| auto shape_inference_graph_interim = |
| absl::make_unique<Graph>(graph->flib_def()); |
| CopyGraph(*graph, shape_inference_graph_interim.get()); |
| |
| for (Node* node : shape_inference_graph_interim->nodes()) { |
| if (node->type_string() != "_Arg" || |
| node->attrs().Find("T")->type() != DT_RESOURCE) |
| continue; |
| |
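| // Rewire consumers of each ReadVariableOp directly to the resource _Arg |
| // and drop the ReadVariableOp, so shape inference can use the variable's |
| // shape recorded on the argument below. |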
| std::vector<std::function<void()>> to_remove; |
| |
| for (const Edge* out_edge : node->out_edges()) { |
| Node* read_node = out_edge->dst(); |
| if (read_node->type_string() != "ReadVariableOp") continue; |
| |
| for (const Edge* variable_edge : read_node->out_edges()) { |
| // We are delaying these modifications as we cannot do in-place |
| // modification of EdgeSets. |
| to_remove.push_back( |
| [variable_edge, graph = shape_inference_graph_interim.get(), node] { |
| Node* dst = variable_edge->dst(); |
| graph->RemoveEdge(variable_edge); |
| graph->AddEdge(node, variable_edge->src_output(), dst, |
| variable_edge->dst_input()); |
| }); |
| } |
| to_remove.push_back( |
| [graph = shape_inference_graph_interim.get(), out_edge, read_node] { |
| graph->RemoveEdge(out_edge); |
| graph->RemoveNode(read_node); |
| }); |
| } |
| |
| for (auto& func : to_remove) { |
| func(); |
| } |
| |
| int resource_arg_index = node->attrs().Find("index")->i(); |
| |
| // Get resource variable tensor |
| core::RefCountPtr<Var> variable; |
| const ResourceHandle& handle = HandleFromInput(ctx, resource_arg_index); |
| TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &variable)); |
| |
| const Tensor* variable_tensor = variable->tensor(); |
| std::vector<int> variable_tensor_vec; |
| |
| variable_tensor_vec.reserve(variable_tensor->dims()); |
| for (int d = 0; d < variable_tensor->dims(); ++d) { |
| variable_tensor_vec.push_back(variable_tensor->dim_size(d)); |
| } |
| |
| PartialTensorShape partial_tensor_shape; |
| auto partial_shape = PartialTensorShape::MakePartialShape( |
| variable_tensor_vec.data(), variable_tensor_vec.size(), |
| &partial_tensor_shape); |
| InferredShape inferred_shape = {partial_tensor_shape}; |
| arg_shapes.emplace(resource_arg_index, inferred_shape); |
| } |
| |
| TF_RETURN_IF_ERROR(tensorflow::InferShapes( |
| shape_inference_graph_interim.get(), arg_shapes, |
| &shape_inference_graph_interim->flib_def(), tpu_inferred_info)); |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::ShardInputsWithXlaSharding( |
| Graph* graph, int num_cores_per_replica, OpKernelContext* ctx) { |
| for (Node* replicated_input_node : graph->nodes()) { |
| if (replicated_input_node->type_string() != "TPUReplicatedInput") continue; |
| |
| Node* arg_node; |
| auto input_node_status = replicated_input_node->input_node(0, &arg_node); |
| if (!input_node_status.ok()) { |
| VLOG(2) << "Skip because cannot retrieve input node 0 of " |
| << replicated_input_node->name() << " because " |
| << input_node_status.ToString(); |
| continue; |
| } |
| |
| // Check whether this TPUReplicatedInput qualifies: it must have an _Arg |
| // as input and must not already feed an XlaSharding op. If so, try to |
| // shard the input automatically. |
| // |
| // In short, we want to see the following graph: |
| // _Arg -> TPUReplicatedInput -> (not XlaSharding op) |
| // and transform it to: |
| // _Arg -> TPUReplicatedInput -> XlaSharding -> (not XlaSharding op) |
| if (arg_node->IsArg() && |
| replicated_input_node->out_nodes().begin()->type_string() != |
| "XlaSharding") { |
| int arg_id; |
| if (!absl::SimpleAtoi(absl::StripPrefix(arg_node->name(), "arg_"), |
| &arg_id)) { |
| VLOG(3) << "Skip auto-sharding because we are unable to extract " |
| "argument number from " |
| << arg_node->name(); |
| continue; |
| } |
| |
| auto shape = ctx->input(arg_id).shape(); |
| |
| VLOG(3) << "Identified arg node " << arg_node->DebugString() |
| << " for TPUReplicatedInput " |
| << replicated_input_node->DebugString(); |
| VLOG(3) << "Shape within TPUReplicatedInput is: " << shape.DebugString(); |
| |
| int rank = shape.dims(); |
| int shard_dim = |
| (runtime_params_.auto_xla_input_sharding_dim + rank) % rank; |
| |
| if (shape.dim_size(shard_dim) % num_cores_per_replica != 0) { |
| VLOG(3) << "Skip auto-sharding " << replicated_input_node->name() |
| << " because the specified sharding dimension " << shard_dim |
| << " cannot be evenly split by " << num_cores_per_replica; |
| continue; |
| } |
| |
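| // Build an OpSharding proto of type OTHER that tiles the tensor along |
| // shard_dim across num_cores_per_replica cores. For example, with rank 2, |
| // shard_dim 0 and two cores, tile_assignment_dimensions is [2, 1] and |
| // tile_assignment_devices is [0, 1]. |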
| auto sharding = absl::make_optional<xla::OpSharding>(); |
| sharding->set_type(xla::OpSharding::OTHER); |
| |
| // Sets up tile_assignment_dimensions. |
| std::vector<int64> dims(rank, 1LL); |
| dims[shard_dim] = num_cores_per_replica; |
| for (auto dim : dims) { |
| sharding->add_tile_assignment_dimensions(dim); |
| } |
| |
| // Sets up tile_assignment_devices. |
| for (int d = 0; d < num_cores_per_replica; ++d) { |
| sharding->add_tile_assignment_devices(d); |
| } |
| |
| std::vector<const Edge*> edges_to_remove; |
| for (const Edge* edge : replicated_input_node->out_edges()) { |
| if (edge->IsControlEdge()) continue; |
| edges_to_remove.push_back(edge); |
| } |
| |
| // Create XlaSharding Op. |
| Node* sharding_op = nullptr; |
| TF_RETURN_IF_ERROR( |
| NodeBuilder(absl::StrCat(replicated_input_node->name(), "/sharding"), |
| "XlaSharding") |
| .Input(replicated_input_node, 0) |
| .Attr("T", replicated_input_node->output_type(0)) |
| .Attr(kXLAShardingAttrName, sharding->SerializeAsString()) |
| .Attr(kXLAShardingAttrAltName, sharding->SerializeAsString()) |
| .Attr("_tpu_replicate", "cluster") |
| .Finalize(graph, &sharding_op)); |
| for (const Edge* edge : edges_to_remove) { |
| VLOG(3) << "XlaSharding op creation output edge " |
| << edge->DebugString(); |
| graph->RemoveEdge(edge); |
| graph->AddEdge(sharding_op, 0, edge->dst(), edge->dst_input()); |
| } |
| |
| VLOG(3) << "Auto shard " << replicated_input_node->name() << " by dim " |
| << shard_dim << " into " << num_cores_per_replica << " slices"; |
| |
| VLOG(3) << "Created XlaSharding Op " << sharding_op->DebugString(); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| // OptimizeTpuInputOutputTensors does the following things: |
| // (1) Detect input arguments, and add an XlaSharding op to each argument if |
| // enable_auto_xla_input_sharding is turned on. |
| // (2) Pack multiple input tensors into one tensor by a concat to avoid PCIe |
| // transfer overheads for small tensors. |
| // (3) Reshape input tensors to R1 to leverage the fast path in TPU input |
| // preparation done by the runtime. |
| // (4) Pack multiple output tensors into one tensor by a concat. |
| // |
| // (1) is controlled by --enable_auto_xla_input_sharding and |
| // --auto_xla_input_sharding_dim |
| // (2) and (3) are controlled by flags --minimum_input_tensors_packing |
| // and --input_shape_opt, respectively, while (4) is controlled by |
| // --minimum_output_tensors_packing. |
| Status TPUPartitionedCallOp::OptimizeTpuInputOutputTensors( |
| Graph* graph, bool enable_spmd_xla_partitioning, int num_cores_per_replica, |
| std::map<std::string, std::vector<int>>& named_input_shapes, |
| OpKernelContext* ctx) { |
| if (runtime_params_.enable_auto_xla_input_sharding) { |
| VLOG(2) << DumpGraphToFile("before_enable_auto_xla_input_sharding", *graph, |
| flib_def_.get()); |
| |
| TF_RETURN_IF_ERROR( |
| ShardInputsWithXlaSharding(graph, num_cores_per_replica, ctx)); |
| } |
| |
| GraphShapeInfo tpu_inferred_info; |
| std::map<int, InferredShape> arg_shapes; |
| EdgeShapes tpu_input_shapes; |
| absl::flat_hash_map<const Edge*, DataType> tpu_input_dtypes; |
| |
| // Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op. |
| XlaShardingInfoMap xla_sharding_ops; |
| |
| // Contains attrs "T", and a pointer to tpu_replicated_metadata for the |
| // control dependency. |
| TpuReplicatedInputInfoMap tpu_replicated_input_ops; |
| |
| bool xla_spmd_input_sharded = false; |
| |
| if (enable_spmd_xla_partitioning) { |
| xla_spmd_input_sharded = FindTpuReplicatedInputAndXlaSharding( |
| graph, xla_sharding_ops, tpu_replicated_input_ops); |
| } |
| |
| VLOG(1) << "xla_spmd_input_sharded: " << xla_spmd_input_sharded; |
| VLOG(2) << DumpGraphToFile("before_remove_descendant_nodes", *graph, |
| flib_def_.get()); |
| |
| if (!xla_spmd_input_sharded || |
| runtime_params_.minimum_input_tensors_packing > 1 || |
| runtime_params_.enable_auto_xla_input_sharding) { |
| // Currently we remove `TPUReplicatedInput` nodes when the input tensors are |
| // not sharded, when input tensor packing optimization is enabled, or when |
| // auto XLA input sharding is enabled. |
| // |
| // In all these cases, we want to remove both the TPUReplicatedInput and |
| // XlaSharding ops, or else downstream rewrites will be confused. |
| RemoveDescendantNodeOfArg(graph, "TPUReplicatedInput", |
| /*must_be_child_of=*/{}); |
| } |
| |
| if (xla_spmd_input_sharded) { |
| // We are setting must_be_child_of to {"_Arg"} because we do not want |
| // to remove other XlaSharding ops that might be in the graph. We only |
| // want the XlaSharding ops that are directly attached to the input |
| // arguments to be removed. |
| RemoveDescendantNodeOfArg(graph, "XlaSharding", |
| /*must_be_child_of=*/{"_Arg"}); |
| } |
| |
| VLOG(2) << DumpGraphToFile("before_get_input_output_info", *graph, |
| flib_def_.get()); |
| |
| TF_RETURN_IF_ERROR(GetInputOutputInfo(graph, tpu_inferred_info, arg_shapes, |
| tpu_input_shapes, tpu_input_dtypes, |
| ctx)); |
| |
| VLOG(2) << DumpGraphToFile("before_optimize_tpu_input_output_tensors", *graph, |
| flib_def_.get()); |
| |
| string cluster_name; |
| TF_RETURN_IF_ERROR(GetClusterName(graph, &cluster_name)); |
| |
| if (runtime_params_.minimum_output_tensors_packing > 1) { |
| // Copy graph to shape_inference_graph |
| EdgeShapes tpu_output_shapes; |
| TF_RETURN_IF_ERROR( |
| InferShapesWithResourceVar(graph, ctx, arg_shapes, &tpu_inferred_info)); |
| |
| // Find TPU -> CPU output edges. |
| GroupedEdges shape_to_output = |
| tpu_functional_internal::GroupTensorsForOutputPacking( |
| graph, tpu_output_shapes, &tpu_inferred_info); |
| |
| TF_RETURN_IF_ERROR( |
| tpu_functional_internal::CreateConcatAndSplitNodesForOutputTensor( |
| graph, cluster_name, &tpu_output_shapes, &tpu_inferred_info, |
| shape_to_output, runtime_params_.minimum_output_tensors_packing)); |
| } |
| |
| if (runtime_params_.minimum_input_tensors_packing > 1) { |
| GroupedEdges grouped_input_edges = |
| tpu_functional_internal::GroupTensorsForInputPacking( |
| tpu_input_shapes, tpu_input_dtypes, runtime_params_.input_shape_opt, |
| runtime_params_.group_tensors_for_packing); |
| TF_RETURN_IF_ERROR( |
| tpu_functional_internal::CreateConcatAndSplitNodesForInputTensor( |
| graph, cluster_name, &tpu_input_shapes, grouped_input_edges, |
| runtime_params_.minimum_input_tensors_packing, |
| xla_spmd_input_sharded, xla_sharding_ops, |
| tpu_replicated_input_ops)); |
| } |
| if (runtime_params_.input_shape_opt) { |
| TF_RETURN_IF_ERROR(tpu_functional_internal::InsertReshapeNodePairs( |
| graph, cluster_name, &tpu_input_shapes, num_cores_per_replica)); |
| } |
| VLOG(1) << DumpGraphToFile("optim_result", *graph); |
| |
| // With or without optimizations, collect the input names and shapes. |
| for (const auto& iter : tpu_input_shapes) { |
| std::string name = iter.first->src()->name(); |
| named_input_shapes[name] = iter.second; |
| } |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::GetGraphFromFunction( |
| Graph* graph, int device_ordinal, int* num_core_per_replica, |
| bool* use_spmd_for_xla_partitioning) { |
| FunctionLibraryRuntime::InstantiateOptions opts; |
| FHandle handle; |
| TF_RETURN_IF_ERROR(library_runtime_->Instantiate( |
| func_.name(), AttrSlice(&func_.attr()), opts, &handle)); |
| const FunctionBody* fbody = library_runtime_->GetFunctionBody(handle); |
| if (fbody == nullptr) { |
| return errors::Internal("Could not find handle ", handle); |
| } |
| CopyGraph(*fbody->graph, graph); |
| |
| // Pin the inputs and outputs to the local device to simplify the |
| // function-dispatching logic. |
| local_device_name_ = library_runtime_->device()->name(); |
| replaced_input_indices_.resize(fbody->arg_nodes.size(), false); |
| for (Node* node : graph->op_nodes()) { |
| if (node->IsArg() || node->IsRetval()) { |
| node->set_assigned_device_name(local_device_name_); |
| } else if (node->type_string() == "TPUReplicateMetadata") { |
| TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "num_cores_per_replica", |
| num_core_per_replica)); |
| TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), |
| "use_spmd_for_xla_partitioning", |
| use_spmd_for_xla_partitioning)); |
| VLOG(1) << "num_core_per_replica = " << *num_core_per_replica |
| << ", use_spmd_for_xla_partitioning = " |
| << *use_spmd_for_xla_partitioning; |
| |
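| // For model parallelism, read the TPU topology and device assignment from |
| // the TPUReplicateMetadata node. Each entry in device_coordinates is a |
| // 4-tuple (x, y, z, core), hence the division and multiplication by 4. |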
| if (*num_core_per_replica > 1) { |
| std::string topology_str; |
| std::vector<int> device_assignment; |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(node->attrs(), "topology", &topology_str)); |
| TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "device_assignment", |
| &device_assignment)); |
| |
| tpu::TopologyProto topology; |
| topology.ParseFromString(topology_str); |
| int num_cores = topology.device_coordinates_size() / 4; |
| |
| if (device_assignment.empty()) { |
| // The number of devices matches the cores per replica, so we can just use |
| // the device assignment from the existing topology instead of |
| // generating our own. |
| // |
| // TODO(b/179292031): Add support for non-natural orders for pods. |
| |
| // Check that the device coordinates for a donut are always in |
| // natural order. |
| std::vector<int> natural_order; |
| switch (num_cores) { |
| case 2: |
| TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder( |
| /*x_num_cores=*/1, /*y_num_cores=*/1, /*z_num_cores=*/1, |
| /*num_cores_per_chip=*/2, &natural_order)); |
| break; |
| case 4: // we assume this is a puffylite donut (2x2 w/ 1 core/chip) |
| TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder( |
| /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1, |
| /*num_cores_per_chip=*/1, &natural_order)); |
| break; |
| case 8: |
| TF_RETURN_IF_ERROR(GenerateDeviceNaturalOrder( |
| /*x_num_cores=*/2, /*y_num_cores=*/2, /*z_num_cores=*/1, |
| /*num_cores_per_chip=*/2, &natural_order)); |
| break; |
| default: |
| return errors::Unimplemented( |
| "You must specify a device assignment for all TPU " |
| "configurations other than JF/DF/PL 1x1 or 2x2."); |
| } |
| if (*num_core_per_replica != num_cores && |
| !std::equal(natural_order.begin(), natural_order.end(), |
| topology.device_coordinates().begin())) { |
| return errors::InvalidArgument( |
| "Topology device coordinates for XLA SPMD on donuts must be in " |
| "natural order."); |
| } |
| |
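| // Slice out the coordinates of the cores used by this replica, starting at |
| // device_ordinal, and install them as the node's device assignment. |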
| auto coordinates_start = |
| topology.device_coordinates().begin() + device_ordinal * 4; |
| auto coordinates_end = topology.device_coordinates().begin() + |
| (device_ordinal + *num_core_per_replica) * 4; |
| |
| node->ClearAttr("device_assignment"); |
| device_assignment.insert(device_assignment.begin(), coordinates_start, |
| coordinates_end); |
| node->AddAttr("device_assignment", device_assignment); |
| } |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::PlacementHelper( |
| const DeviceSet& device_set, |
| const GraphOptimizationPassOptions& optimization_options, |
| const string& function_name) { |
| TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); |
| Placer placer(optimization_options.graph->get(), function_name, |
| optimization_options.flib_def, &device_set); |
| TF_RETURN_IF_ERROR(placer.Run()); |
| TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| OptimizationPassRegistry::POST_PLACEMENT, optimization_options)); |
| TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::PartitionHelper( |
| const DeviceSet& device_set, |
| const GraphOptimizationPassOptions& optimization_options, Graph* graph, |
| std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs) { |
| PartitionOptions partition_options; |
| partition_options.node_to_loc = [](const Node* node) { |
| // TODO(akshayka): To better support the distributed case, first split |
| // the graph by worker (e.g., using the master session's |
| // `SplitByWorker` policy), and then recursively partition the |
| // per-worker shards at the remote worker(s). |
| return node->assigned_device_name(); |
| }; |
| int64 edge_name_counter = 0; |
| partition_options.new_name = [&edge_name_counter](const string& prefix) { |
| return strings::StrCat(prefix, "/_", ++edge_name_counter); |
| }; |
| partition_options.get_incarnation = [&device_set](const string& name) { |
| const Device* d = device_set.FindDeviceByName(name); |
| if (d == nullptr) { |
| return PartitionOptions::kIllegalIncarnation; |
| } else { |
| return d->attributes().incarnation(); |
| } |
| }; |
| partition_options.control_flow_added = false; |
| std::unordered_map<std::string, GraphDef> partitions; |
| TF_RETURN_IF_ERROR(Partition(partition_options, graph, &partitions)); |
| |
| VLOG(3) << "Partitioned function '" << func_.name() << "', yielding " |
| << partitions.size() << " shards."; |
| |
| const FunctionLibraryDefinition* flib_def = &graph->flib_def(); |
| for (auto& partition : partitions) { |
| std::unique_ptr<Graph> subgraph(new Graph(flib_def)); |
| GraphConstructorOptions opts; |
| opts.allow_internal_ops = true; |
| opts.expect_device_spec = true; |
| const string& device = partition.first; |
| GraphDef& graph_def = partition.second; |
| TF_RETURN_IF_ERROR( |
| ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get())); |
| subgraphs->emplace(device, std::move(subgraph)); |
| } |
| |
| TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); |
| |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::InstantiatePartition( |
| const Graph& graph, const string& function_name, |
| const string& target_device, FHandle* handle, |
| std::unique_ptr<FunctionLibraryDefinition>* out_flib_def) { |
| FunctionDef shard; |
| TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, function_name, &shard)); |
| TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(shard)); |
| FunctionLibraryRuntime::InstantiateOptions opts; |
| opts.target = target_device; |
| if (out_flib_def) { |
| *out_flib_def = std::make_unique<FunctionLibraryDefinition>(*flib_def_); |
| opts.lib_def = out_flib_def->get(); |
| } else { |
| opts.lib_def = flib_def_.get(); |
| } |
| return library_runtime_->Instantiate(function_name, AttrSlice(&shard.attr()), |
| opts, handle); |
| } |
| |
| Status TPUPartitionedCallOp::SetDeviceOrdinal(const DeviceSet& device_set, |
| int device_ordinal, Graph* graph, |
| bool* modified) { |
| int ordinal = -1; |
| for (Node* node : graph->op_nodes()) { |
| if (node->type_string() == kVarHandleOp) { |
| if (IsInputToTPUReplicate(node)) { |
| // If this VarHandleOp is going to a TPU computation, |
| // it refers to the TPU variable that we created when replacing the |
| // resource arguments with VarHandleOps. |
| node->set_assigned_device_name( |
| strings::StrCat(kTPUDeviceNamePrefix, device_ordinal)); |
| } |
| continue; |
| } |
| if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { |
| // Outside compilation related node. |
| TF_RETURN_IF_ERROR( |
| SetDeviceOrdinalAttributeForNode(node, device_ordinal)); |
| *modified = true; |
| continue; |
| } |
| const AttrValue* attr = node->attrs().Find(kDeviceOrdinalAttr); |
| if (attr != nullptr) { |
| if (!IsSupportedTPUOp(node->type_string())) { |
| return errors::InvalidArgument("Node ", node->type_string(), |
| " is not yet supported."); |
| } |
| if (ordinal == -1) { |
| ordinal = attr->i(); |
| } else { |
| if (ordinal != attr->i()) { |
| return errors::InvalidArgument( |
| "Can only partition graphs that use a single device ordinal."); |
| } |
| } |
| node->ClearAttr(kDeviceOrdinalAttr); |
| node->AddAttr(kDeviceOrdinalAttr, device_ordinal); |
| VLOG(3) << "Set device ordinal of " << node->type_string() << " to " |
| << device_ordinal; |
| *modified = true; |
| } |
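| // _Send/_Recv nodes carry device names and incarnations in their attrs; |
| // rewrite them to point at the selected TPU ordinal and refresh the |
| // send_device_incarnation from the current device set. |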
| if (node->IsSend() || node->IsRecv()) { |
| static const char* kSendDevice = "send_device"; |
| static const char* kSendDeviceIncarnation = "send_device_incarnation"; |
| static const char* kRecvDevice = "recv_device"; |
| const AttrValue* attr = node->attrs().Find(kSendDevice); |
| if (attr != nullptr) { |
| string device = attr->s(); |
| TF_RETURN_IF_ERROR( |
| UpdateTPUDeviceOrdinal(device_ordinal, &device, modified)); |
| node->ClearAttr(kSendDevice); |
| node->AddAttr(kSendDevice, device); |
| node->ClearAttr(kSendDeviceIncarnation); |
| const Device* d = device_set.FindDeviceByName(device); |
| int64 send_incarnation = (d == nullptr) |
| ? PartitionOptions::kIllegalIncarnation |
| : d->attributes().incarnation(); |
| node->AddAttr(kSendDeviceIncarnation, send_incarnation); |
| } |
| attr = node->attrs().Find(kRecvDevice); |
| if (attr != nullptr) { |
| string device = attr->s(); |
| TF_RETURN_IF_ERROR( |
| UpdateTPUDeviceOrdinal(device_ordinal, &device, modified)); |
| node->ClearAttr(kRecvDevice); |
| node->AddAttr(kRecvDevice, device); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status TPUPartitionedCallOp::InstantiateFunctionsFromSubgraphs( |
| const DeviceSet& device_set, int replica_id, uint64 cache_hash, |
| int num_cores_per_replica, |
| std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs) { |
| const Device* reference_device = nullptr; |
| auto entry = |
| partition_cache_.emplace(cache_hash, std::vector<DeviceAndFHandle>()); |
| |
| bool rewritten = false; |
| for (auto& pair : subgraphs) { |
| string target = pair.first; |
| int device_ordinal = replica_id; |
| if (num_cores_per_replica > 1) { |
| DeviceNameUtils::ParsedName parsed_device; |
| if (!DeviceNameUtils::ParseFullName(target, &parsed_device)) { |
| return errors::InvalidArgument("Malformed assigned device '", target, |
| "'"); |
| } |
| device_ordinal = parsed_device.id; |
| } |
| Device* device; |
| TF_RETURN_IF_ERROR( |
| library_runtime_->device_mgr()->LookupDevice(target, &device)); |
| if (reference_device == nullptr) { |
| reference_device = device; |
| } else { |
| if (!DeviceNameUtils::IsSameAddressSpace( |
| device->parsed_name(), reference_device->parsed_name())) { |
| return errors::InvalidArgument( |
| "TPUPartitionedCallOp does not yet support inter-process " |
| "execution."); |
| } |
| } |
| TF_RETURN_IF_ERROR(device->MaybeRewriteGraph(&pair.second)); |
| Graph* subgraph = pair.second.get(); |
| // For model parallelism inference, we only support num_replica == 1, so |
| // there is no need to update the device_ordinal anymore. |
| if (num_cores_per_replica == 1) { |
| TF_RETURN_IF_ERROR( |
| SetDeviceOrdinal(device_set, device_ordinal, subgraph, &rewritten)); |
| } else { |
| VLOG(1) << "Skip SetDeviceOrdinal()"; |
| } |
| string function_name = flib_def_->UniqueFunctionName( |
| strings::StrCat(func_.name(), "_hash_", cache_hash)); |
| TF_RETURN_IF_ERROR( |
| UpdateTPUDeviceOrdinal(device_ordinal, &target, &rewritten)); |
| FHandle handle; |
| // Use a copy of the current `flib_def_` to instantiate the function to |
| // avoid races. |
| std::unique_ptr<FunctionLibraryDefinition> sub_flib_def; |
| TF_RETURN_IF_ERROR(InstantiatePartition(*subgraph, function_name, target, |
| &handle, &sub_flib_def)); |
| // Add handle to the cache entry. |
| entry.first->second.push_back( |
| DeviceAndFHandle{.device = target, |
| .handle = handle, |
| .flib_def = std::move(sub_flib_def)}); |
| } |
| |
| if (!rewritten) { |
| // For regular use cases, TPUPartitionedCallOp only works when the |
| // function being called is rewritten for TPU. If we don't see any signs |
| // of this rewriting, warn the user about it. |
| // We don't raise an error because we want to support the use case of |
| // running tpu.initialize_system eagerly. In this case, we can't use |
| // tpu.rewrite because it will add compilation ops that require TPU |
| // to be initialized, i.e. there is a chicken and egg problem. |
| // We run tpu.initialize_system through TPUPartitionedCallOp because it |
| // invokes graph rewrite passes that are necessary for initialization to |
| // work. |
| LOG(INFO) << "Function body was not rewritten for TPU. " |
| << "This is probably a bug unless you are initializing " |
| << "TPUs eagerly."; |
| } |
| return Status::OK(); |
| } |
| |
| void TPUPartitionedCallOp::ExecuteRemoteFunction( |
| const FunctionLibraryRuntime::Options& opts, FHandle handle, |
| OpKernelContext* ctx, ReffedStatusCallback* done) { |
| std::vector<Tensor> dummy_args; |
| std::vector<Tensor>* dummy_rets = new std::vector<Tensor>; |
| |
| profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteRemote"); |
| absl::ReaderMutexLock l(&mu_); |
| library_runtime_->Run(opts, handle, dummy_args, dummy_rets, |
| [dummy_rets, done, ctx](const Status& status) { |
| if (!status.ok()) { |
| ctx->SetStatus(status); |
| } |
| delete dummy_rets; |
| done->Unref(); |
| }); |
| } |
| |
| void TPUPartitionedCallOp::ExecuteLocalFunction( |
| const FunctionLibraryRuntime::Options& opts, const OpInputList& arguments, |
| FHandle handle, OpKernelContext* ctx, ReffedStatusCallback* done) { |
| std::vector<Tensor> args; |
| |
| for (int i = 0; i < arguments.size(); ++i) { |
| if (!replaced_input_indices_[i]) { |
| // _Arg nodes of type DT_RESOURCE that go into a TPU node have been |
| // replaced by TPU VarHandleOp nodes. No longer need to pass them as |
| // inputs. |
| args.push_back(arguments[i]); |
| } |
| } |
| auto* rets = new std::vector<Tensor>; |
| |
| profiler::TraceMe trace_me("TPUPartitionedCallOp-ExecuteLocal"); |
| absl::ReaderMutexLock l(&mu_); |
| library_runtime_->Run(opts, handle, args, rets, |
| [rets, done, ctx](const Status& status) { |
| if (!status.ok()) { |
| ctx->SetStatus(status); |
| } else { |
| for (int i = 0; i < rets->size(); ++i) { |
| ctx->set_output(i, (*rets)[i]); |
| } |
| } |
| delete rets; |
| done->Unref(); |
| }); |
| } |
| |
| void TPUPartitionedCallOp::ExecuteFunctions( |
| const std::vector<DeviceAndFHandle>& functions, OpKernelContext* ctx, |
| int device_ordinal, int64_t ordinal_selector_req_id, DoneCallback done) { |
| FunctionLibraryRuntime::Options opts; |
| opts.step_container = ctx->step_container(); |
| opts.cancellation_manager = ctx->cancellation_manager(); |
| opts.stats_collector = ctx->stats_collector(); |
| // TODO(akshayka): Consider selecting a runner on a per-device basis, |
| // i.e., using device-specific threadpools when available. |
| opts.runner = ctx->runner(); |
| opts.source_device = local_device_name_; |
| opts.run_all_kernels_inline = ctx->run_all_kernels_inline(); |
| |
| OpInputList arguments; |
| OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); |
| |
| auto* rendez = new PrivateIntraProcessRendezvous(device_mgr_); |
| opts.rendezvous = rendez; |
| |
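| // This callback runs once all shards have completed (via the refcounted |
| // wrapper below): it deletes the rendezvous, signals completion, and |
| // releases the core-selector reservation if one was made. |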
| StatusCallback callback( |
| [rendez = rendez, done = std::move(done), device_ordinal = device_ordinal, |
| req_id = ordinal_selector_req_id, |
| ordinal_selector = ordinal_selector_](const Status& status) { |
| delete rendez; |
| done(); |
| if (req_id >= 0) { |
| ordinal_selector->DequeueFromCoreSelector(device_ordinal, req_id); |
| } |
| }); |
| |
| auto* refcounted_done = new ReffedStatusCallback(std::move(callback)); |
| for (int i = 1; i < functions.size(); ++i) { |
| refcounted_done->Ref(); |
| } |
| for (const DeviceAndFHandle& entry : functions) { |
| const string& target_device = entry.device; |
| FHandle handle = entry.handle; |
| VLOG(3) << "Running function shard on device " << target_device |
| << " with local device name " << local_device_name_; |
| if (target_device == local_device_name_) { |
| opts.remote_execution = false; |
| ExecuteLocalFunction(opts, arguments, handle, ctx, refcounted_done); |
| } else { |
| opts.remote_execution = true; |
| ExecuteRemoteFunction(opts, handle, ctx, refcounted_done); |
| } |
| } |
| } |
| |
| REGISTER_KERNEL_BUILDER(Name("TPUPartitionedCall").Device(DEVICE_CPU), |
| TPUPartitionedCallOp); |
| |
| } // end namespace tensorflow |