| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // ALGORITHM OVERVIEW |
| // ================== |
| // |
| // An XLA cluster hoists all resource reads to the beginning of the cluster |
| // execution and all the resource writes to the end. This means it cannot |
| // enforce arbitrary ordering dependencies (via control or data edges) between |
| // resource operations. Since all resource reads happen before all resource |
| // writes, edges constraining resource reads to happen before resource writes |
| // are fine, but all other kinds of edges are problematic. This analysis |
| // computes the set of pairs of resource operations that cannot be put in the |
| // same cluster because XLA cannot respect the dependencies between them in the |
| // TensorFlow program. |
| // |
| // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write |
| // dependencies. |
| // |
| // Specifically the result computed by this analysis contains the edge {W, R} |
| // iff all of these hold true: |
| // |
| // - In the graph (g - {edges from NextIteration to Merge}) there is a path |
| // from W to R. |
| // - IsEdgeSafe(W, R) == False [defined below] |
| // - W != R (note: some resource operations both read from and write to |
| // resource variables). |
| // |
| // The result is incorrect around loops because we ignore edges from |
| // NextIteration to Merge. For instance, in: |
| // |
| // Init -----> Merge <-------+ |
| // | | |
| // v | |
| // Read | |
| // | | |
| // v | |
| // Write | |
| // | | |
| // v | |
| // NextIteration --+ |
| // |
| // we won't put (Read, Write) in the returned set. This is fine if |
| // auto-clustering can only cluster the Read->Write edge, but it is a problem if |
| // it clusters the Write->NextIteration->Merge->Read edges instead. So we rely |
| // on auto-clustering to not cluster NextIteration->Merge edges. The same |
| // problem is present for the functional version of the loop above and we also |
| // rely on auto-clustering not clustering functional while loops containing |
| // resource operations. |
| // |
| // One way to think about this is that we only care about cases where two nodes, |
| // A and B, would normally have been put in the same cluster but cannot legally |
| // be in the same cluster because of resourcevar-dependencies. If A and B would |
| // normally have been put in the same cluster then all paths between A and B |
| // would have to be clusterable (otherwise we'd have introduced a cycle). Ergo |
| // there could not have been a NextIteration->Merge edge between A and B since |
| // we don't cluster these edges. |
| // |
| // IMPLEMENTATION |
| // -------------- |
| // |
| // We traverse the graph minus backedges in reverse post order, mapping each |
| // node to the set of resource operations reaching that node. Since we visit |
| // producers before consumers, we can construct the set of reaching operations |
| // by taking the union of the operations reaching the input nodes. These |
| // "reaching resource operations" can then be used to create the pairs of |
| // incompatible nodes using `IsEdgeSafe`. |
| |
| #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" |
| |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/types/optional.h" |
| #include "tensorflow/compiler/jit/xla_cluster_util.h" |
| #include "tensorflow/compiler/tf2xla/resource_operation_table.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/tensor_id.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/util/ptr_util.h" |
| |
| namespace tensorflow { |
| namespace { |
| // Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is |
| // not a resource operation recognized by XLA then sets `out_resource_op_kind` |
| // to nullopt. |
| Status XlaResourceOpKindForNode( |
| const Node& n, const FunctionLibraryDefinition* flib_def, |
| const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, |
| absl::optional<XlaResourceOpKind>* out_resource_op_kind) { |
| bool should_ignore = false; |
| if (resource_ops_to_ignore) { |
| TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); |
| } |
| if (should_ignore) { |
| *out_resource_op_kind = absl::nullopt; |
| return Status::OK(); |
| } |
| |
| const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); |
| if (op_info) { |
| *out_resource_op_kind = op_info->kind(); |
| return Status::OK(); |
| } |
| |
| // We conservatively assume that functions will both read and write resource |
| // variables. In the future we may consider doing some form of |
| // inter-procedural analysis. |
| if (MayCallFunction(n, flib_def)) { |
| *out_resource_op_kind = XlaResourceOpKind::kReadWrite; |
| } else { |
| *out_resource_op_kind = absl::nullopt; |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Returns true if a control or data dependence from a TensorFlow operation of |
| // resource op kind `from` to a TensorFlow operation of resource op kind `to` |
| // can be represented by an XLA cluster and needs no special handling around |
| // auto-jit. |
| bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { |
| // XLA clusters force all reads to happen before all writes. Moreover the set |
| // of reads are executed as one atomic operation, and the set of writes are as |
| // another atomic operation. This means we can faithfully represent the |
| // following edges: Read->*, *->Write. |
| |
| return from == XlaResourceOpKind::kRead || to == XlaResourceOpKind::kWrite; |
| } |
| |
| using ResourceOp = std::pair<int, XlaResourceOpKind>; |
| |
| string ResourceOpToString(const ResourceOp& resource_op) { |
| return absl::StrCat( |
| resource_op.first, ": ", |
| XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); |
| } |
| |
| // A copy-on-write set used to store the set of ResourceOps reaching a node in a |
| // TensorFlow graph. |
| // |
| // TODO(sanjoy): It may be useful to pull this out into its own header at some |
| // point. |
| class ResourceOpSet { |
| private: |
| using Impl = absl::flat_hash_set<ResourceOp>; |
| |
| public: |
| ResourceOpSet() = default; |
| |
| // Adds all ResourceOp s in `other` to this set. |
| void Add(const ResourceOpSet& other) { |
| CHECK(!frozen_); |
| if (other.impl_ == impl_) { |
| other.frozen_ = true; |
| return; |
| } |
| |
| if (!impl_) { |
| other.frozen_ = true; |
| impl_ = other.impl_; |
| return; |
| } |
| |
| for (ResourceOp resource_op : other) { |
| Add(resource_op); |
| } |
| } |
| |
| void Add(const ResourceOp& resource_op) { |
| CHECK(!frozen_); |
| if (!IsCopy() && Contains(resource_op)) { |
| // We can avoid the copy if the item we want to insert already exists. |
| return; |
| } |
| |
| EnsureIsCopied(); |
| impl_->insert(resource_op); |
| } |
| |
| Impl::const_iterator begin() const { |
| return impl_ ? impl_->begin() : GetEmptyImpl()->begin(); |
| } |
| |
| Impl::const_iterator end() const { |
| return impl_ ? impl_->end() : GetEmptyImpl()->end(); |
| } |
| |
| bool Contains(const ResourceOp& resource_op) const { |
| return impl_ != nullptr && impl_->count(resource_op); |
| } |
| |
| private: |
| bool IsCopy() const { return storage_ != nullptr; } |
| |
| void EnsureIsCopied() { |
| if (storage_ == nullptr) { |
| storage_ = absl::make_unique<Impl>(); |
| for (ResourceOp op : *this) { |
| storage_->insert(op); |
| } |
| impl_ = storage_.get(); |
| } |
| } |
| |
| static Impl* GetEmptyImpl() { |
| static Impl* empty_impl = new Impl; |
| return empty_impl; |
| } |
| |
| Impl* impl_ = nullptr; |
| std::unique_ptr<Impl> storage_; |
| |
| // frozen_ is true if there is another set pointing to this set's impl_. We |
| // can no longer add elements to this set in that case since the sets pointing |
| // to this set expect the contents of this set to be stable. |
| mutable bool frozen_ = false; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet); |
| }; |
| |
| string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { |
| std::vector<string> elements_debug_string; |
| std::transform(resource_op_set.begin(), resource_op_set.end(), |
| std::back_inserter(elements_debug_string), ResourceOpToString); |
| return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); |
| } |
| |
| string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { |
| return absl::StrCat( |
| "[", n.name(), ": ", n.type_string(), "(", |
| XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); |
| } |
| } // namespace |
| |
| Status ComputeIncompatibleResourceOperationPairs( |
| const Graph& g, const FunctionLibraryDefinition* flib_def, |
| const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, |
| std::vector<std::pair<int, int>>* result) { |
| CHECK(result->empty()); |
| |
| std::vector<Node*> rpo; |
| GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(), |
| /*edge_filter=*/[](const Edge& edge) { |
| return !edge.src()->IsNextIteration(); |
| }); |
| |
| auto resource_op_set_for_node = |
| absl::make_unique<ResourceOpSet[]>(g.num_node_ids()); |
| |
| const bool vlog = VLOG_IS_ON(2); |
| |
| for (Node* n : rpo) { |
| absl::optional<XlaResourceOpKind> op_kind; |
| TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( |
| *n, flib_def, resource_ops_to_ignore, &op_kind)); |
| |
| ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; |
| |
| // Merge the reaching resource operations for all the incoming edges to |
| // create the set of all possible resource ops reaching `n`. |
| for (const Edge* e : n->in_edges()) { |
| if (n->IsMerge() && e->src()->IsNextIteration()) { |
| // Ignore back-edges (see file comment). |
| continue; |
| } |
| |
| const ResourceOpSet& incoming_op_set = |
| resource_op_set_for_node[e->src()->id()]; |
| resource_op_set->Add(incoming_op_set); |
| } |
| |
| // Add to the "incompatible resource ops" set if necessary. |
| if (op_kind) { |
| for (ResourceOp incoming_op : *resource_op_set) { |
| if (IsEdgeSafe(incoming_op.second, *op_kind)) { |
| continue; |
| } |
| |
| if (vlog) { |
| VLOG(2) << "Unsafe edge: " |
| << NodeToString(*g.FindNodeId(incoming_op.first), |
| incoming_op.second) |
| << " -> " << NodeToString(*n, *op_kind); |
| } |
| result->push_back({incoming_op.first, n->id()}); |
| } |
| |
| resource_op_set->Add({n->id(), *op_kind}); |
| } |
| |
| if (vlog) { |
| VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set); |
| } |
| } |
| |
| std::sort(result->begin(), result->end()); |
| CHECK(std::unique(result->begin(), result->end()) == result->end()); |
| |
| return Status::OK(); |
| } |
| } // namespace tensorflow |