| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" |
| |
| #include <algorithm> |
| #include <atomic> |
| #include <deque> |
| #include <limits> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| #include "absl/base/call_once.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/strings/str_join.h" |
| #include "tensorflow/compiler/jit/compilability_check_util.h" |
| #include "tensorflow/compiler/jit/deadness_analysis.h" |
| #include "tensorflow/compiler/jit/defs.h" |
| #include "tensorflow/compiler/jit/device_util.h" |
| #include "tensorflow/compiler/jit/flags.h" |
| #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" |
| #include "tensorflow/compiler/jit/xla_cluster_util.h" |
| #include "tensorflow/compiler/tf2xla/const_analysis.h" |
| #include "tensorflow/compiler/tf2xla/resource_operation_table.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
| #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/union_find.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/graph_constructor.h" |
| #include "tensorflow/core/framework/bounds_check.h" |
| #include "tensorflow/core/framework/graph_def_util.h" |
| #include "tensorflow/core/framework/memory_types.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/control_flow.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/statusor.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/public/version.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate; |
| using jit::DeviceId; |
| using jit::DeviceSet; |
| |
| // The clusters we create here are eventually lowered into an |
| // _XlaCompile/_XlaRun pair with a TF executor "fallback" that uses the |
| // PartitionedCall op to execute the cluster in the regular graph executor if |
| // need be. PartitionedCall, however, reruns the entire TF graph optimization |
| // pipeline over the cluster which includes this mark for compilation pass. To |
| // avoid endlessly recursing we tag nodes that we've already visited with this |
| // attribute so that we can bail out if we see them a second time. |
| // |
| // TODO(sanjoy): This method is not robust since it is possible that the |
| // optimizations run by PartitionedCall can mutate the cluster arbitrarily, |
| // dropping the kXlaAlreadyClustered attributes from all nodes in the process. |
| // The correct fix is to use the ConfigProto to pass in some sort of flag into |
| // the PartitionedCall kernel that tells it to not rerun auto-clustering on the |
| // cluster. |
| const char* kXlaAlreadyClustered = "_XlaAlreadyClustered"; |
| |
| class MarkForCompilationPassImpl { |
| public: |
| struct DebugOptions { |
| // If true, do not respect the results of deadness analysis. |
| bool ignore_deadness_checks; |
| |
| // If true, do not do safety checks to preserve TensorFlow's resource |
| // variable concurrency semantics. |
| bool ignore_resource_variable_checks; |
| |
| // If true, do not respect the _XlaCompile=false attribute. |
| bool ignore_xla_compile_attr; |
| |
| // If true, compute the cluster name in a deterministic way so that it is |
| // stable from run to run. |
| bool deterministic_cluster_names; |
| |
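| // Bounds on cluster sizes: edge contractions that would grow a cluster past |
| // `max_cluster_size` are rejected, and clusters whose effective size is below |
| // `min_cluster_size` are not marked for compilation. |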
| int max_cluster_size; |
| int min_cluster_size; |
| |
| // Compiler fuel for the auto-clustering algorithm. |
| // |
| // We decrement this value by one every time we choose a compilation |
| // candidate, and we stop clustering when it hits zero. This means the |
| // initial value for this variable (via --tf_xla_clustering_fuel=N) |
| // effectively acts as a "cap" for how much we cluster and we can bisect |
| // over this initial value to discover clustering decisions that cause a |
| // miscompile or a performance regression. |
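| // |
| // For example, if a model miscompiles with the default (effectively |
| // unlimited) amount of fuel, rerunning with --tf_xla_clustering_fuel=1024, |
| // then 512, 256, ... (a manual bisection) narrows down which clustering |
| // decision introduced the problem. |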
| std::atomic<int64_t>* fuel; |
| |
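| // If true, dump the graph before clustering and the post-clustering graphs |
| // for debugging (see CreateClusters and DumpDebugInfo). |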
| bool dump_graphs; |
| }; |
| |
| MarkForCompilationPassImpl(DebugOptions debug_options, Graph* graph, |
| FunctionLibraryDefinition* flib_def, Env* env, |
| OptimizerOptions::GlobalJitLevel global_jit_level, |
| bool cpu_global_jit) |
| : debug_options_(debug_options), |
| graph_(graph), |
| graph_fingerprint_(0), |
| flib_def_(flib_def), |
| env_(env), |
| global_jit_level_(global_jit_level), |
| cpu_global_jit_(cpu_global_jit) {} |
| |
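| // Runs the pass end to end; the five steps involved (Initialize, |
| // RunEdgeContractionLoop, DeclusterNodes, CreateClusters and DumpDebugInfo) |
| // are documented below. |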
| Status Run(); |
| |
| private: |
| // Represents a "cluster" or a connected subgraph of a TensorFlow graph. |
| class Cluster { |
| public: |
| // Constructs a trivial cluster representing a single TF node. |
| Cluster(int tf_graph_node_id, int effective_cluster_size, |
| bool has_functional_control_flow, DeviceSet devices, |
| std::optional<DeviceId> resource_op_device, |
| std::optional<int> resource_var_operation_node_id, |
| std::optional<DeadnessPredicate> deadness_predicate, |
| bool is_xla_compile_attr_true, std::optional<string> xla_scope) |
| : cycles_graph_node_id_(tf_graph_node_id), |
| effective_cluster_size_(effective_cluster_size), |
| has_functional_control_flow_(has_functional_control_flow), |
| devices_(std::move(devices)), |
| resource_op_device_(resource_op_device), |
| deadness_predicate_(deadness_predicate), |
| is_xla_compile_attr_true_(is_xla_compile_attr_true), |
| xla_scope_(std::move(xla_scope)) { |
| if (resource_var_operation_node_id.has_value()) { |
| resource_var_operation_node_ids_.push_back( |
| *resource_var_operation_node_id); |
| } |
| } |
| |
| // Merges `other` into this cluster, and clears `other`. This method is |
| // closely tied with the implementation of `MarkForCompilationPassImpl`. |
| void Merge(Cluster* other); |
| |
| // If this is a trivial cluster containing only one node then return the ID |
| // of that node. May not be called otherwise. |
| int GetIdOfOnlyNode() const { |
| DCHECK_EQ(cluster_size(), 1); |
| return cycles_graph_node_id(); |
| } |
| |
| // The number of TF nodes in this cluster. |
| int cluster_size() const { return cluster_size_; } |
| |
| // The ID of the cluster as represented in `cycles_graph_`. |
| int cycles_graph_node_id() const { return cycles_graph_node_id_; } |
| |
| // Sets the ID of the cluster as represented in `cycles_graph_`. |
| void set_cycles_graph_node_id(int cycles_graph_node_id) { |
| cycles_graph_node_id_ = cycles_graph_node_id; |
| } |
| |
| // The size of the cluster excluding constant and identity nodes. |
| int effective_cluster_size() const { return effective_cluster_size_; } |
| |
| // True if the cluster has functional control flow like `If` and `While`. |
| bool has_functional_control_flow() const { |
| return has_functional_control_flow_; |
| } |
| |
| // The set of devices nodes in the cluster are placed on. |
| const DeviceSet& devices() const { return devices_; } |
| |
| // If the cluster has a resource operation then the device the resource |
| // operation is placed on. A cluster may have resource ops placed only on a |
| // single device. |
| const std::optional<DeviceId>& resource_op_device() const { |
| return resource_op_device_; |
| } |
| |
| // If not nullopt then this is a predicate that is true iff the cluster is |
| // alive. Otherwise the user has (unsafely) disabled deadness analysis. If |
| // this is unset on a single Cluster instance then it is unset on all Cluster |
| // instances. |
| const std::optional<DeadnessPredicate>& deadness_predicate() const { |
| return deadness_predicate_; |
| } |
| |
| // If true then the cluster has an _XlaCompile=true attribute on one of its |
| // nodes. |
| bool is_xla_compile_attr_true() const { return is_xla_compile_attr_true_; } |
| |
| // If not nullopt then all nodes in the cluster either do not have the |
| // XlaScope attribute set or have it set to the value returned. |
| const std::optional<string>& xla_scope() const { return xla_scope_; } |
| |
| // Returns the TF graph node IDs for the resource variable operations in |
| // this cluster. |
| absl::Span<const int> resource_var_operation_node_ids() const { |
| return resource_var_operation_node_ids_; |
| } |
| |
| string DebugString(const Graph& graph) const { |
| Node* node = graph.FindNodeId(cycles_graph_node_id()); |
| if (!node) { |
| // This should never happen but we try to be resilient because this is a |
| // debugging aid. |
| return absl::StrCat("NULL NODE IN #", cycles_graph_node_id()); |
| } |
| |
| if (cluster_size() == 1) { |
| return absl::StrCat("<", node->name(), " #", cycles_graph_node_id(), |
| ">"); |
| } |
| |
| return absl::StrCat("<", node->name(), " + ", cluster_size() - 1, |
| " others #", cycles_graph_node_id(), ">"); |
| } |
| |
| private: |
| int cluster_size_ = 1; |
| int cycles_graph_node_id_; |
| int effective_cluster_size_; |
| bool has_functional_control_flow_; |
| DeviceSet devices_; |
| std::optional<DeviceId> resource_op_device_; |
| std::optional<DeadnessPredicate> deadness_predicate_; |
| bool is_xla_compile_attr_true_; |
| std::optional<string> xla_scope_; |
| std::vector<int> resource_var_operation_node_ids_; |
| |
| TF_DISALLOW_COPY_AND_ASSIGN(Cluster); |
| }; |
| |
| // If `cluster` has only a single node then returns that, otherwise returns |
| // nullptr. |
| Node* GetOnlyNodeIn(const Cluster& cluster); |
| |
| // Returns true if `cluster` is a trivial cluster containing a "sink like" |
| // node -- a NoOp node whose only outgoing edge is a control edge consumed by |
| // the Sink node. |
| bool IsSinkLike(const Cluster& cluster); |
| |
| // Returns true if `cluster` looks like an "i++" operation on an integer |
| // scalar resource variable. |
| bool IsScalarIntegerResourceOperation(const Cluster& cluster); |
| |
| // --------------------------------------------------------------------------- |
| // The pass proceeds in five steps, out of which `RunEdgeContractionLoop` and |
| // `CreateClusters` do most of the heavy lifting. |
| |
| // Initializes some internal data structures. |
| // |
| // If this returns false then Initialize exited early (either because there is |
| // nothing to do or we saw a graph that we can't handle) and not all the |
| // fields in this MarkForCompilationPassImpl instance are set up. |
| StatusOr<bool> Initialize(); |
| |
| // Runs through the entire cluster graph in post-order and calls `fn(from, |
| // to)` on each edge. `fn(from, to)` is expected to return true if it was |
| // able to contract `from`->`to`. |
| // |
| // Returns true if `fn` returned true for any edge. |
| template <typename FnTy> |
| StatusOr<bool> ForEachEdgeInPostOrder(FnTy fn); |
| |
| // Contracts as many edges as possible to create XLA clusters. After this |
| // finishes the clustering decisions made are implicitly stored in |
| // `clusters_`. |
| Status RunEdgeContractionLoop(); |
| |
| // "Fixes up" clusters by removing some modes. |
| // |
| // Autoclustering can sometimes be overeager. For example, clustering large |
| // constants (or large broadcasts of constants) can increase the live range |
| // of those constants, and increase overall memory usage. |
| // |
| // This function removes "obviously bad" cases like these. |
| Status DeclusterNodes(); |
| |
| // Manifests the clustering decisions into the TF graph by tagging nodes with |
| // an `_XlaCluster` attribute. Some basic filtering logic, like |
| // tf_xla_min_cluster_size, is also applied here. |
| Status CreateClusters(); |
| |
| Status DumpDebugInfo(); |
| |
| bool IsCompilationCandidate(Node* n) const { |
| return compilation_candidates_.find(n) != compilation_candidates_.end(); |
| } |
| |
| // Tries to contract the edge from cluster `from` to cluster `to`. Returns |
| // true if successful. |
| StatusOr<bool> TryToContractEdge(Cluster* from, Cluster* to); |
| |
| // Nodes that XLA can compile are put in `compilation_candidates_`. |
| Status FindCompilationCandidates(); |
| |
| bool CompilationDisallowedByXlaCompileAttr(Node* node); |
| |
| // Populates `clusters_`. |
| Status BuildInitialClusterSet(); |
| |
| StatusOr<bool> ShouldCompileClusterImpl(const Cluster& cluster); |
| |
| StatusOr<bool> ShouldCompileCluster(const Cluster& cluster); |
| |
| StatusOr<bool> ClusteringWillIntroduceInterDeviceDependency( |
| const Cluster& from, const Cluster& to); |
| |
| // Returns true if the devices in `cluster_a` and `cluster_b` are compatible |
| // and therefore not a hindrance for combining the two clusters into a larger |
| // cluster. |
| StatusOr<bool> AreDevicesCompatible(const Cluster& cluster_a, |
| const Cluster& cluster_b); |
| |
| void DumpPostClusteringGraphs(); |
| void VLogClusteringSummary(); |
| |
| Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size, |
| bool has_functional_control_flow, |
| const DeviceSet& device_set, |
| std::optional<DeviceId> resource_op_device, |
| std::optional<int> resource_var_operation_node_id, |
| std::optional<DeadnessPredicate> deadness_predicate, |
| bool is_xla_compile_attr_true, |
| std::optional<string> xla_scope) { |
| cluster_storage_.push_back(std::make_unique<Cluster>( |
| cycles_graph_node_id, effective_cluster_size, |
| has_functional_control_flow, device_set, resource_op_device, |
| resource_var_operation_node_id, deadness_predicate, |
| is_xla_compile_attr_true, xla_scope)); |
| return cluster_storage_.back().get(); |
| } |
| |
| std::optional<string> GetXlaScope(Node* n); |
| |
| // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in |
| // the same cluster by the clustering algorithm then this function will return |
| // the same Cluster instance for N1 and N2. |
| // |
| // Returns nullptr if `n` is not a compilation candidate. |
| Cluster* GetClusterForNode(Node* n) { |
| return cluster_for_node_[n->id()].Get(); |
| } |
| |
| // Returns the cluster for a node in `cycles_graph_`. This uses the same |
| // underlying map because of how we set things up, but we can do an additional |
| // CHECK in this accessor. |
| // |
| // Returns nullptr if `node_id` is not a compilation candidate. |
| Cluster* GetClusterForCyclesGraphNode(int node_id) { |
| // We have to check `graph_->FindNodeId(node_id) == nullptr` because we add |
| // all nodes in [0, graph_->num_node_ids()) to the cycle detection graph but |
| // the TF graph may be missing some node ids. |
| if (node_id >= graph_->num_node_ids() || |
| graph_->FindNodeId(node_id) == nullptr) { |
| return nullptr; |
| } |
| Cluster* cluster = cluster_for_node_[node_id].Get(); |
| if (cluster) { |
| DCHECK_EQ(cluster->cycles_graph_node_id(), node_id); |
| } |
| return cluster; |
| } |
| |
| bool LogNotContractableAndReturnFalse(Cluster* from, Cluster* to, |
| absl::string_view reason); |
| |
| // Finds a path in `cycles_graph_` from `from` to `to` that is not a direct |
| // edge from `from` to `to`. |
| // |
| // Tries to find a path that contains at least one unclusterable node. |
| std::vector<int> FindAlternatePathForDebugging(int from, int to); |
| |
| // Returns a string representing `cycles_graph_node_id`. If the node is |
| // unclusterable (either it is a phantom "frame" node or is not a compilation |
| // candidate) then set `*found_unclustered` to true. |
| string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered); |
| |
| // We could not contract the edge from `from` to `to`. Return a string |
| // describing an alternate path from `from` to `to` (besides the direct edge |
| // from `from` to `to`) which would have created a cycle had we contracted the |
| // edge. |
| // |
| // Tries (if possible) to find a path that contains at least one unclusterable |
| // node as it is surprising to the user if we print "A->B could not be |
| // contracted because of the path [P,Q,R]" where P, Q and R are all clusters |
| // since in that case a natural question is why we could not form a {A, P, Q, |
| // R, B} cluster. |
| string DescribePotentialCycle(int from, int to); |
| |
| // Merge the clusters `cluster_from` and `cluster_to`. After this step the |
| // larger combined cluster is represented by `cluster_from`, but can have |
| // `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on |
| // which way will require fewer operations. |
| bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { |
| int from = cluster_from->cycles_graph_node_id(); |
| int to = cluster_to->cycles_graph_node_id(); |
| |
| auto optional_merged_node = cycles_graph_.ContractEdge(from, to); |
| if (!optional_merged_node.has_value()) { |
| VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_) |
| << " -> " << cluster_to->DebugString(*graph_) |
| << " because contracting the edge would create a cycle via " |
| << DescribePotentialCycle(from, to) << "."; |
| return false; |
| } |
| |
| // Merge the clusters. |
| cluster_from->Merge(cluster_to); |
| // Update `cycles_graph_`'s ID. |
| cluster_from->set_cycles_graph_node_id(optional_merged_node.value()); |
| |
| // Merge the UnionFind<Cluster*>. |
| cluster_for_node_[from].Merge(&cluster_for_node_[to]); |
| |
| return true; |
| } |
| |
| string EdgeContractionFailureMsg(Cluster* from, Cluster* to, |
| absl::string_view reason) { |
| return absl::StrCat("Could not contract ", from->DebugString(*graph_), |
| " -> ", to->DebugString(*graph_), " because ", reason, |
| "."); |
| } |
| |
| DebugOptions debug_options_; |
| Graph* graph_; |
| uint64 graph_fingerprint_; |
| FunctionLibraryDefinition* flib_def_; |
| Env* env_; |
| OptimizerOptions::GlobalJitLevel global_jit_level_; |
| bool cpu_global_jit_; |
| absl::flat_hash_map<const Cluster*, bool> should_compile_cluster_cache_; |
| jit::DeviceInfoCache device_info_cache_; |
| |
| bool initialized_ = false; |
| bool edges_contracted_ = false; |
| bool clusters_created_ = false; |
| |
| std::vector<std::unique_ptr<Cluster>> cluster_storage_; |
| std::vector<UnionFind<Cluster*>> cluster_for_node_; |
| absl::flat_hash_set<const Node*> declustered_nodes_; |
| GraphCycles cycles_graph_; |
| OrderedNodeSet compilation_candidates_; |
| std::unique_ptr<DeadnessAnalysis> deadness_analysis_; |
| int64_t iteration_count_ = 0; |
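| // Pairs (A, B) of resource operation node ids that must not end up in the |
| // same cluster; populated in BuildInitialClusterSet and consulted in |
| // TryToContractEdge. |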
| absl::flat_hash_set<std::pair<int, int>> unsafe_resource_deps_; |
| }; |
| |
| std::vector<int> MarkForCompilationPassImpl::FindAlternatePathForDebugging( |
| int from, int to) { |
| std::vector<int> rpo = cycles_graph_.AllNodesInPostOrder(); |
| absl::c_reverse(rpo); |
| |
| // best_pred_for_node[n] contains a predecessor of `n` that has an |
| // unclusterable node in some path from `from` to itself. |
| // best_pred_for_node[n] is unpopulated for nodes that are not reachable from |
| // `from`. We build this table up inductively by traversing the cycles graph |
| // in RPO. |
| absl::flat_hash_map<int, int> best_pred_for_node; |
| best_pred_for_node[from] = -1; |
| |
| int rpo_index = 0, current_rpo_node; |
| do { |
| current_rpo_node = rpo[rpo_index++]; |
| std::optional<int> some_pred, preferred_pred; |
| for (int pred : cycles_graph_.Predecessors(current_rpo_node)) { |
| if (!best_pred_for_node.contains(pred)) { |
| continue; |
| } |
| |
| // Ignore the from->to edge since we're trying to find an alternate path. |
| if (current_rpo_node == to && pred == from) { |
| continue; |
| } |
| |
| some_pred = pred; |
| if (GetClusterForCyclesGraphNode(pred) == nullptr) { |
| preferred_pred = pred; |
| } |
| } |
| |
| if (some_pred || preferred_pred) { |
| best_pred_for_node[current_rpo_node] = |
| preferred_pred.has_value() ? *preferred_pred : *some_pred; |
| } |
| } while (current_rpo_node != to); |
| |
| auto get_best_pred = [&](int n) { |
| auto it = best_pred_for_node.find(n); |
| CHECK(it != best_pred_for_node.end()); |
| return it->second; |
| }; |
| |
| std::vector<int> path; |
| int current_path_node = get_best_pred(to); |
| while (current_path_node != from) { |
| path.push_back(current_path_node); |
| current_path_node = get_best_pred(current_path_node); |
| } |
| |
| absl::c_reverse(path); |
| return path; |
| } |
| |
| string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( |
| int cycles_graph_node_id, bool* found_unclustered) { |
| Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id); |
| if (cluster) { |
| return cluster->DebugString(*graph_); |
| } |
| |
| *found_unclustered = true; |
| if (cycles_graph_node_id >= graph_->num_node_ids()) { |
| return absl::StrCat("<oob #", cycles_graph_node_id, ">"); |
| } |
| |
| Node* node = graph_->FindNodeId(cycles_graph_node_id); |
| if (!node) { |
| return absl::StrCat("<bad #", cycles_graph_node_id, ">"); |
| } |
| |
| return node->name(); |
| } |
| |
| string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) { |
| std::vector<string> path_str; |
| bool found_unclustered = false; |
| absl::c_transform(FindAlternatePathForDebugging(from, to), |
| std::back_inserter(path_str), [&](int node_id) { |
| return DebugStringForCyclesGraphNode(node_id, |
| &found_unclustered); |
| }); |
| return absl::StrCat(!found_unclustered ? "(all clusters) " : "", "[", |
| absl::StrJoin(path_str, ","), "]"); |
| } |
| |
| void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) { |
| // We keep our own cycles_graph_node_id_ to mirror what GraphCycles does. |
| |
| // Clearing out data structures in `other` is just a memory saving |
| // optimization and not needed for correctness. |
| |
| cluster_size_ += other->cluster_size_; |
| effective_cluster_size_ += other->effective_cluster_size_; |
| has_functional_control_flow_ |= other->has_functional_control_flow_; |
| |
| devices_.UnionWith(other->devices_); |
| |
| DCHECK(!(resource_op_device_.has_value() && |
| other->resource_op_device_.has_value()) || |
| *resource_op_device_ == *other->resource_op_device_) |
| << "AreDevicesCompatible should have returned false otherwise!"; |
| |
| if (!resource_op_device_.has_value()) { |
| resource_op_device_ = other->resource_op_device_; |
| } |
| |
| is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_; |
| |
| if (!xla_scope_.has_value()) { |
| xla_scope_ = std::move(other->xla_scope_); |
| } |
| |
| resource_var_operation_node_ids_.reserve( |
| resource_var_operation_node_ids_.size() + |
| other->resource_var_operation_node_ids_.size()); |
| absl::c_copy(other->resource_var_operation_node_ids_, |
| std::back_inserter(resource_var_operation_node_ids_)); |
| other->resource_var_operation_node_ids_.clear(); |
| } |
| |
| Status IgnoreResourceOpForSafetyAnalysis( |
| jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) { |
| // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then |
| // ignore it during resource operation safety analysis. We need this hack |
| // because of two reasons: |
| // |
| // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. |
| // 2. We don't support live-out values of type DT_RESOURCE and live-in values |
| // of type DT_RESOURCE that are not resource variables. |
| // |
| // Together these imply we cannot let resource variable safety analysis |
| // constrain the endpoints of, e.g., a TensorArrayV3->TensorArrayAssignV3 edge |
| // to be in different clusters: both of them will have to be clustered because |
| // of (1) and we |
| // won't be able to keep the edge between the two as neither the input to the |
| // second XLA cluster nor the output from the first XLA cluster are supported |
| // because of (2). |
| // |
| // TODO(b/113100872): This can be fixed if the TensorFlow representation for |
| // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then |
| // (2) would no longer hold. |
| |
| if (n.assigned_device_name().empty()) { |
| *ignore = false; |
| return OkStatus(); |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| const XlaOpRegistry::DeviceRegistration* registration, |
| device_info_cache->GetCompilationDevice(n.assigned_device_name())); |
| |
| if (!registration) { |
| *ignore = true; |
| } else { |
| *ignore = registration->cluster_resource_variable_ops_unsafely; |
| } |
| return OkStatus(); |
| } |
| |
| StatusOr<bool> MarkForCompilationPassImpl::Initialize() { |
| TF_RET_CHECK(!initialized_ && !edges_contracted_ && !clusters_created_); |
| initialized_ = true; |
| |
| TF_RETURN_IF_ERROR(FindCompilationCandidates()); |
| |
| if (compilation_candidates_.empty()) { |
| VLOG(2) << "No compilable candidates"; |
| return false; |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, |
| CreateCycleDetectionGraph(graph_, &cycles_graph_)); |
| if (!cycle_detection_graph_ok) { |
| // TODO(sanjoy): This should be logged via the XLA activity listener. |
| VLOG(2) << "Could not form cycle detection graph"; |
| return false; |
| } |
| |
| if (!debug_options_.ignore_deadness_checks) { |
| XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); |
| TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(*graph_, &deadness_analysis_)); |
| } |
| |
| // If the user is requesting deterministic cluster names, compute a hash of the |
| // input graph to provide a stable but unique prefix for the name. |
| if (debug_options_.deterministic_cluster_names) { |
| TF_ASSIGN_OR_RETURN(graph_fingerprint_, FingerprintGraph(*graph_)); |
| } |
| |
| // Each compilation candidate belongs to a cluster. The cluster's |
| // representative names the node in the 'cycles' graph that represents the |
| // cluster. |
| TF_RETURN_IF_ERROR(BuildInitialClusterSet()); |
| return true; |
| } |
| |
| template <typename FnTy> |
| StatusOr<bool> MarkForCompilationPassImpl::ForEachEdgeInPostOrder(FnTy fn) { |
| bool changed = false; |
| for (int32_t node : cycles_graph_.AllNodesInPostOrder()) { |
| Cluster* cluster_from = GetClusterForCyclesGraphNode(node); |
| if (!cluster_from) { |
| continue; |
| } |
| |
| // Make a copy of the set of successors because we may modify the graph in |
| // TryToContractEdge. |
| std::vector<int32> successors_copy = |
| cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); |
| |
| for (int to : successors_copy) { |
| iteration_count_++; |
| |
| Cluster* cluster_to = GetClusterForCyclesGraphNode(to); |
| if (!cluster_to) { |
| continue; |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool contracted_edge, fn(cluster_from, cluster_to)); |
| changed |= contracted_edge; |
| } |
| } |
| |
| return changed; |
| } |
| |
| Node* MarkForCompilationPassImpl::GetOnlyNodeIn(const Cluster& cluster) { |
| return cluster.cluster_size() == 1 |
| ? graph_->FindNodeId(cluster.GetIdOfOnlyNode()) |
| : nullptr; |
| } |
| |
| bool MarkForCompilationPassImpl::IsSinkLike(const Cluster& cluster) { |
| if (Node* n = GetOnlyNodeIn(cluster)) { |
| return n->type_string() == "NoOp" && n->out_edges().size() == 1 && |
| (*n->out_edges().begin())->dst()->IsSink(); |
| } |
| |
| return false; |
| } |
| |
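| // Looks for the "i++" pattern mentioned in the declaration above: a single |
| // AssignAddVariableOp or AssignSubVariableOp node with an integer dtype whose |
| // increment comes from a scalar Const input. |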
| bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation( |
| const Cluster& cluster) { |
| Node* n = GetOnlyNodeIn(cluster); |
| if (!n) { |
| return false; |
| } |
| |
| if (n->type_string() != "AssignAddVariableOp" && |
| n->type_string() != "AssignSubVariableOp") { |
| return false; |
| } |
| |
| DataType dtype; |
| if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) { |
| return false; |
| } |
| |
| Node* const_input = nullptr; |
| for (const Edge* e : n->in_edges()) { |
| if (!e->IsControlEdge() && e->src()->IsConstant()) { |
| const_input = e->src(); |
| break; |
| } |
| } |
| |
| if (!const_input) { |
| return false; |
| } |
| |
| const TensorProto* proto = nullptr; |
| if (!TryGetNodeAttr(const_input->def(), "value", &proto)) { |
| return false; |
| } |
| |
| return TensorShapeUtils::IsScalar(proto->tensor_shape()); |
| } |
| |
| Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { |
| TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_); |
| edges_contracted_ = true; |
| |
| // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for |
| // example, from the Grappler fusion pass). |
| |
| // In general there are multiple maximal clusterings, but they are not all |
| // equally performant. Some clustering decisions are likely to improve |
| // performance much more than others, but we cannot order contractions by this |
| // cost function, nor can we look at global information while deciding on |
| // individual edges to contract. Instead, we first make decisions on these |
| // important edges and then make decisions on all other edges, which maximizes |
| // the chance that the most important edges are contracted. |
| // |
| // An example of where this might occur is with a digraph: |
| // {A -> B, B -> C, A -> X, X -> C} where B is a Size operation and X is |
| // not-compilable. In this case, the valid clusterings are {A,B} or {B,C}. B |
| // should be clustered with A because it will prevent a potentially large |
| // tensor from A being computed and copied. |
| // |
| // To choose better maximal clusterings we make multiple iterations over the |
| // graph in post-order, where each such iteration is called a "phase". |
| |
| // Phase 0: contract metadata operations with their producer. |
| |
| VLOG(4) << "Running phase 0"; |
| TF_RETURN_IF_ERROR( |
| ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> { |
| // Shape consuming operations are desirable to cluster with their |
| // operands because they return a small set of scalar values after |
| // consuming a large amount of data. For example, given a graph X -> Y |
| // -> Size -> Z, where the possible clustering is [{X, Y, Size}, {Z}] or |
| // [{X, Y}, {Size, Z}], the better clustering is Size with Y because the |
| // output of size will be a small tensor while Y is a potentially large |
| // tensor that must be computed and possibly transposed/copied before |
| // the second cluster executes. |
| Node* n = GetOnlyNodeIn(*to); |
| bool is_shape_consumer_op = n && IsShapeConsumerOp(*n); |
| if (!is_shape_consumer_op) { |
| return false; |
| } |
| |
| return TryToContractEdge(from, to); |
| }).status()); |
| |
| // Phase 1: apply a heuristic to ensure that we don't mess up clustering due |
| // to "group_deps". After this phase most edges should have been contracted. |
| |
| VLOG(4) << "Running phase 1"; |
| TF_RETURN_IF_ERROR( |
| ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) -> StatusOr<bool> { |
| // We split out this phase to get good clustering in the presence of a |
| // specific pattern seen in some graphs: |
| // |
| // digraph { |
| // ApplyWeightUpdates_0 -> "iteration++" |
| // ApplyWeightUpdates_1 -> "iteration++" |
| // ApplyWeightUpdates_2 -> "iteration++" |
| // ApplyWeightUpdates_0 -> Computation_A |
| // ApplyWeightUpdates_1 -> Computation_B |
| // ApplyWeightUpdates_2 -> Computation_C |
| // Computation_A -> NoOp |
| // Computation_B -> NoOp |
| // Computation_C -> NoOp |
| // "iteration++" -> NoOp |
| // } |
| // |
| // In the graph above we can't cluster iteration++ with any of the |
| // gradient update operations since that will break the TF resource |
| // variable memory model. Given that constraint the ideal clustering |
| // would be to put all the gradient updates and all of the Computation_* |
| // nodes in one cluster, and leave iteration++ and NoOp unclustered. |
| // |
| // A naive post-order traversal would not create this good clustering, |
| // however. Instead it will first create a cluster that puts |
| // Computation_* nodes, the NoOp and iteration++ node in a single |
| // cluster, after which it will fail to put any of the |
| // ApplyWeightUpdates_* nodes into this cluster. To avoid this fate we |
| // instead run a pass that avoids contracting edges _into_ NoOps like |
| // the above, and avoid clustering edges _from_ "iteration++" like the |
| // above. Then we run a second pass that contracts the edges we could |
| // not contract the first time around. |
| |
| if (IsSinkLike(*to)) { |
| return false; |
| } |
| |
| if (IsScalarIntegerResourceOperation(*from)) { |
| return false; |
| } |
| |
| return TryToContractEdge(from, to); |
| }).status()); |
| |
| // Phase 2: contract any remaining edges. After this phase we should have a |
| // maximal clustering: |
| // |
| // A. We visit a cluster only after maximally clustering all its children. |
| // B. By the time we're done with a node all of its children that could have |
| // been absorbed into the node have been absorbed. |
| // C. We have an invariant that making a cluster larger does not make edges |
| // leaving it more contractable. That is, if we have |
| // digraph { X->Y; Y->Z; } then collapsing X->Y does not make it possible |
| // to contract Y->Z if Y->Z was not contractible originally. |
| VLOG(4) << "Running phase 2"; |
| TF_RETURN_IF_ERROR(ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { |
| return TryToContractEdge(from, to); |
| }).status()); |
| |
| // Check that the conclusion made above (that iterating over the graph once in |
| // post order gives a maximal clustering) holds. Once the linear time |
| // post-order scheme has been battle tested we can move this to happen only in |
| // debug builds. |
| VLOG(2) << "Checking idempotence"; |
| TF_ASSIGN_OR_RETURN(bool changed, |
| ForEachEdgeInPostOrder([&](Cluster* from, Cluster* to) { |
| return TryToContractEdge(from, to); |
| })); |
| TF_RET_CHECK(!changed); |
| |
| return OkStatus(); |
| } |
| |
| Status MarkForCompilationPassImpl::DeclusterNodes() { |
| for (Node* n : compilation_candidates_) { |
| Cluster* cluster = GetClusterForNode(n); |
| if (cluster == nullptr) { |
| continue; |
| } |
| |
| // De-cluster Fill ops that are |
| // - used at least once outside the cluster, and |
| // - not used inside the cluster. |
| // |
| // In this case, using XLA for the op can only make peak memory usage worse. |
| // If we don't cluster the Fill, it can be materialized right before it's |
| // used in the TF graph. Whereas if we do cluster it, the Fill must be live |
| // starting at the end of the XLA cluster, potentially significantly |
| // increasing its live range. |
| // |
| // See b/221997940 for a real-world example of this. |
| if (n->op_def().name() == "Fill" && |
| n->out_nodes().begin() != n->out_nodes().end() && |
| absl::c_all_of(n->out_nodes(), [&](Node* user) { |
| return GetClusterForNode(user) != cluster; |
| })) { |
| declustered_nodes_.insert(n); |
| } |
| } |
| |
| return OkStatus(); |
| } |
| |
| // Tracks monotonic sequence numbers for graphs. |
| class ClusterSequenceNumberGenerator { |
| public: |
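| // Clears all per-graph sequence counters. |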
| void Reset() { |
| mutex_lock lock(mu_); |
| sequence_numbers_.clear(); |
| } |
| |
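| // Returns the current sequence number for the graph identified by `key` and |
| // then increments it. |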
| int64 GetNext(uint64 key) { |
| mutex_lock lock(mu_); |
| return sequence_numbers_[key]++; |
| } |
| |
| static ClusterSequenceNumberGenerator& Global() { |
| static ClusterSequenceNumberGenerator* gen = |
| new ClusterSequenceNumberGenerator; |
| return *gen; |
| } |
| |
| private: |
| mutex mu_; |
| absl::flat_hash_map<uint64, int64> sequence_numbers_; |
| }; |
| |
| // Gets a monotonic sequence number for a graph identified by its `fingerprint`. |
| // The sequence number is necessary to disambiguate clusters extracted from the |
| // same graph and when duplicate graphs exist within the same process. |
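| // For example, with deterministic cluster names enabled and a (hypothetical) |
| // graph fingerprint of 123, the first two clusters extracted from the graph |
| // would be named "cluster_123_0" and "cluster_123_1"; see CreateClusters |
| // below. |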
| int64_t GetNextClusterSequenceNumber(uint64 fingerprint) { |
| return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint); |
| } |
| |
| Status MarkForCompilationPassImpl::CreateClusters() { |
| TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); |
| clusters_created_ = true; |
| |
| // Names for each cluster. |
| std::unordered_map<int, string> cluster_names; |
| |
| if (debug_options_.dump_graphs) { |
| DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_); |
| } |
| |
| // Mark clusters for compilation that: |
| // * are placed on a device that requires compilation (an XlaDevice), |
| // * are explicitly marked for compilation (_XlaCompile=true), or |
| //   * have at least debug_options_.min_cluster_size elements (applicable |
| // only if compilation is enabled, otherwise there will be no such |
| // candidates). |
| for (Node* n : compilation_candidates_) { |
| Cluster* cluster = GetClusterForNode(n); |
| TF_ASSIGN_OR_RETURN(bool should_compile_cluster, |
| ShouldCompileCluster(*cluster)); |
| if (!should_compile_cluster || declustered_nodes_.contains(n)) { |
| continue; |
| } |
| |
| // We assume that functional If and While nodes have at least |
| // min_cluster_size non-trivial nodes in them. It would be more principled |
| // to (recursively) verify this fact, but that's probably not worth the |
| // trouble. |
| |
| if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || |
| cluster->has_functional_control_flow() || |
| cluster->is_xla_compile_attr_true()) { |
| string& name = cluster_names[cluster->cycles_graph_node_id()]; |
| |
| if (name.empty()) { |
| if (debug_options_.deterministic_cluster_names) { |
| name = absl::StrCat("cluster_", graph_fingerprint_, "_", |
| GetNextClusterSequenceNumber(graph_fingerprint_)); |
| } else { |
| name = absl::StrCat("cluster_", |
| GetNextClusterSequenceNumber(graph_fingerprint_)); |
| } |
| } |
| |
| n->AddAttr(kXlaClusterAttr, name); |
| n->AddAttr(kXlaAlreadyClustered, true); |
| VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; |
| } |
| } |
| |
| return OkStatus(); |
| } |
| |
| Status MarkForCompilationPassImpl::DumpDebugInfo() { |
| TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_); |
| |
| if (debug_options_.dump_graphs) { |
| DumpPostClusteringGraphs(); |
| } |
| |
| VLogClusteringSummary(); |
| |
| return OkStatus(); |
| } |
| |
| StatusOr<bool> |
| MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( |
| const Cluster& cluster_from, const Cluster& cluster_to) { |
| // If any of the consumer's producers are on a different device, do not |
| // cluster these nodes. This prevents other work on this device from being |
| // delayed by work on other devices. We consider predecessors of the entire |
| // cluster rather than just the inputs to the node to prevent the clusters |
| // from still being combined in cases where the 'to' cluster has multiple |
| // dependencies on the 'from' cluster and another dependency leads to a |
| // merging of the clusters. |
| // |
| // TODO(b/117085735): We probably want to handle the reciprocal of this case |
| // where a cluster is producing data for multiple devices. |
| for (const auto& in_id : |
| cycles_graph_.Predecessors(cluster_to.cycles_graph_node_id())) { |
| const Cluster* cluster_in = GetClusterForCyclesGraphNode(in_id); |
| if (cluster_in) { |
| TF_ASSIGN_OR_RETURN(bool devices_compatible, |
| AreDevicesCompatible(cluster_to, *cluster_in)); |
| if (!devices_compatible) { |
| return true; |
| } |
| TF_ASSIGN_OR_RETURN(devices_compatible, |
| AreDevicesCompatible(cluster_from, *cluster_in)); |
| if (!devices_compatible) { |
| return true; |
| } |
| } |
| } |
| |
| return false; |
| } |
| |
| std::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) { |
| // Look for either _XlaScope or _XlaInternalScope on both nodes to guide |
| // clustering. If both nodes have a scope and the scopes do not match, do |
| // not cluster along this edge. If even one of the nodes lacks a scope |
| // attribute, then it is treated as a "bridge" and a cluster may be created |
| // along it. |
| // |
| // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is |
| // provided by users through jit_scope APIs, while _XlaInternalScope is |
| // automatically generated by the ClusterScopingPass when auto_jit is on. As |
| // such, we respect _XlaScope only when auto_jit is off, while respecting |
| // _XlaInternalScope only when auto_jit is on. |
| // |
| // We may want to restrict the _XlaScope behavior to require all nodes marked |
| // with _XlaCompile=true to also have a _XlaScope property set (and raise an |
| // error otherwise); but for now we don't do this. |
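| // |
| // As an example (with hypothetical attribute values): if node A carries |
| // _XlaScope="scope_a" and node B carries _XlaScope="scope_b" while auto_jit |
| // is off, the A->B edge is not contracted because TryToContractEdge sees |
| // mismatching xla_scope() values; if B has no scope attribute at all, the |
| // mismatch check does not apply. |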
| |
| if (global_jit_level_ != OptimizerOptions::OFF) { |
| // If global_jit_level_ is ON, respect only _XlaInternalScope. |
| const string& scope = |
| GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr); |
| if (!scope.empty()) { |
| return scope; |
| } |
| } else { |
| // If global_jit_level_ is OFF, respect only _XlaScope. |
| const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); |
| if (!scope.empty()) { |
| return scope; |
| } |
| } |
| |
| return std::nullopt; |
| } |
| |
| // Returns true iff the attribute `attr_name` is attached to either the node or |
| // to its callee. |
| static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def, |
| const char* attr_name) { |
| bool out = false; |
| bool attr_value; |
| if (TryGetNodeAttr(node->attrs(), attr_name, &attr_value)) { |
| out |= attr_value; |
| } |
| |
| if (flib_def->GetAttr(*node, attr_name, &attr_value).ok()) { |
| out |= attr_value; |
| } |
| return out; |
| } |
| |
| Status MarkForCompilationPassImpl::BuildInitialClusterSet() { |
| auto ignore_resource_ops = [&](const Node& n, bool* ignore) { |
| return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore); |
| }; |
| |
| std::vector<std::pair<int, int>> unsafe_resource_deps_vect; |
| TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( |
| *graph_, flib_def_, ignore_resource_ops, &unsafe_resource_deps_vect)); |
| absl::c_copy( |
| unsafe_resource_deps_vect, |
| std::inserter(unsafe_resource_deps_, unsafe_resource_deps_.begin())); |
| |
| cluster_for_node_.resize(graph_->num_node_ids()); |
| for (Node* node : graph_->nodes()) { |
| if (!IsCompilationCandidate(node)) { |
| cluster_for_node_[node->id()].Get() = nullptr; |
| continue; |
| } |
| |
| // We want clusters to be big enough that the benefit from XLA's |
| // optimizations offsets XLA related overhead (for instance we add some |
| // Switch/Merge nodes into the graph to implement lazy compilation). To |
| // this end, we don't count Identity and Constant nodes because they do not |
| // enable interesting optimizations by themselves. |
| int effective_cluster_size = |
| (node->IsIdentity() || node->IsConstant()) ? 0 : 1; |
| |
| bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode(); |
| |
| std::optional<DeadnessPredicate> deadness_predicate; |
| if (deadness_analysis_) { |
| TF_ASSIGN_OR_RETURN( |
| deadness_predicate, |
| deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot)); |
| } |
| |
| const string& device_name_str = !node->assigned_device_name().empty() |
| ? node->assigned_device_name() |
| : node->requested_device(); |
| TF_ASSIGN_OR_RETURN(DeviceId device, |
| device_info_cache_.GetIdFor(device_name_str)); |
| |
| bool is_resource_op = HasResourceInputOrOutput(*node); |
| std::optional<DeviceId> resource_op_device; |
| if (is_resource_op) { |
| resource_op_device = device; |
| } |
| |
| std::optional<int> resource_var_operation_node_id; |
| if (is_resource_op || MayCallFunction(*node, flib_def_)) { |
| resource_var_operation_node_id = node->id(); |
| } |
| |
| bool is_xla_compile_attr_true = |
| GetNodeOrFuncAttr(node, flib_def_, kXlaCompileAttr) || |
| (global_jit_level_ != OptimizerOptions::OFF && |
| GetNodeOrFuncAttr(node, flib_def_, kXlaMustCompileAttr)); |
| |
| DeviceSet devices; |
| devices.Insert(device); |
| |
| Cluster* new_cluster = MakeNewCluster( |
| /*cycles_graph_node_id=*/node->id(), |
| /*effective_cluster_size=*/effective_cluster_size, |
| /*has_functional_control_flow=*/has_functional_control_flow, devices, |
| resource_op_device, resource_var_operation_node_id, deadness_predicate, |
| /*is_xla_compile_attr_true=*/is_xla_compile_attr_true, |
| GetXlaScope(node)); |
| |
| cluster_for_node_[node->id()].Get() = new_cluster; |
| } |
| |
| return OkStatus(); |
| } |
| |
| StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) { |
| if (!node->IsIdentity()) { |
| return false; |
| } |
| |
| // Check if the Identity is driven by a Switch on its true path. |
| auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) { |
| return e->src()->IsSwitch() && e->src_output() == 1; |
| }); |
| if (it == node->in_edges().end()) { |
| return false; |
| } |
| const Node* switch_node = (*it)->src(); |
| |
| // Check if the Switch is driven by LoopCond. |
| const Node* maybe_loop_cond; |
| TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond)); |
| if (!maybe_loop_cond->IsLoopCond()) { |
| return false; |
| } |
| |
| // Check if the Identity is driving any const nodes through a control edge. |
| bool driving_any_consts = |
| absl::c_any_of(node->out_edges(), [](const Edge* e) { |
| return e->dst()->IsConstant() && e->IsControlEdge(); |
| }); |
| if (!driving_any_consts) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| absl::flat_hash_set<string> GetOrCreateClusterExcludeList() { |
| MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); |
| absl::flat_hash_set<string> excludelist; |
| for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) { |
| if (!s.empty()) { |
| excludelist.insert(string(s)); |
| } |
| } |
| if (VLOG_IS_ON(2) && !excludelist.empty()) { |
| std::vector<string> vexcludelist(excludelist.begin(), excludelist.end()); |
| absl::c_sort(vexcludelist); |
| VLOG(2) << "XLA clustering will exclude following TF operations from auto " |
| "clustering: " |
| << absl::StrJoin(vexcludelist, " "); |
| } |
| return excludelist; |
| } |
| |
| absl::flat_hash_set<string> GetOrCreateAllowlist() { |
| absl::flat_hash_map<string, std::vector<string>>* allowlist_table = |
| tensorflow::GetAllowlistTable(); |
| MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); |
| absl::flat_hash_set<string> allowlist; |
| |
| for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) { |
| if (s == "FUSIBLE") { |
| for (auto pair : *allowlist_table) { |
| allowlist.insert(pair.second.begin(), pair.second.end()); |
| } |
| } else if (allowlist_table->contains(s)) { |
| auto v = allowlist_table->at(s); |
| allowlist.insert(v.begin(), v.end()); |
| } else if (!s.empty()) { |
| // Should be a user provided TF operation. |
| allowlist.insert(string(s)); |
| } |
| } |
| |
| if (VLOG_IS_ON(2) && !allowlist.empty()) { |
| std::vector<string> vallowlist(allowlist.begin(), allowlist.end()); |
| absl::c_sort(vallowlist); |
| VLOG(2) << "XLA clustering will only consider the following TF operations: " |
| << absl::StrJoin(vallowlist, " "); |
| } |
| return allowlist; |
| } |
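| |
| // Example (hypothetical flag values): running with |
| //   --tf_xla_ops_to_cluster="FUSIBLE,MatMul" --tf_xla_cluster_exclude_ops="Where" |
| // expands the FUSIBLE group from GetAllowlistTable(), additionally allows |
| // MatMul, and disables clustering of Where nodes (see the allow_where_op |
| // handling in FindCompilationCandidates below). |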
| |
| Status MarkForCompilationPassImpl::FindCompilationCandidates() { |
| OptimizerOptions opts; |
| std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( |
| new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr, |
| TF_GRAPH_DEF_VERSION, flib_def_, opts)); |
| FunctionLibraryRuntime* lib_runtime = |
| pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); |
| std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false); |
| TF_RETURN_IF_ERROR(BackwardsConstAnalysis( |
| *graph_, /*compile_time_const_arg_indices=*/nullptr, |
| &compile_time_const_nodes, lib_runtime)); |
| // Iterate over nodes in sorted order so that compiler fuel is deterministic. |
| // We can't simply pass op_nodes().begin() and op_nodes().end() to the |
| // std::vector constructor because they're not proper iterators, with |
| // iterator_traits defined and so on. |
| std::vector<Node*> sorted_nodes; |
| for (Node* node : graph_->op_nodes()) { |
| sorted_nodes.push_back(node); |
| } |
| std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); |
| |
| if (*debug_options_.fuel >= std::numeric_limits<int64_t>::max() / 2) { |
| // The assumption is that if fuel started out as INT64_MAX, it will forever |
| // stay greater than INT64_MAX / 2. |
| VLOG(2) << "Starting fuel: infinity"; |
| } else { |
| VLOG(2) << "Starting fuel: " << *debug_options_.fuel; |
| } |
| |
| VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size(); |
| |
| auto allowlist = GetOrCreateAllowlist(); |
| |
| std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps(); |
| absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end()); |
| // Check that the user-provided TF operations really exist. |
| for (const auto& s : allowlist) { |
| if (!all_ops.contains(s)) { |
| return errors::InvalidArgument( |
| "The operation '", s, |
| "' passed to --tf_xla_ops_to_cluster is not supported by XLA."); |
| } |
| } |
| |
| for (Node* node : sorted_nodes) { |
| if (*debug_options_.fuel <= 0) { |
| VLOG(1) |
| << "Hit fuel limit; not marking any remaining ops as clusterable."; |
| break; |
| } |
| |
| TF_ASSIGN_OR_RETURN( |
| const DeviceType& device_type, |
| device_info_cache_.GetDeviceTypeFor(node->assigned_device_name())); |
| VLOG(4) << "Device type for " << node->name() << ": " |
| << device_type.type_string(); |
| |
| if (CompilationDisallowedByXlaCompileAttr(node)) { |
| VLOG(2) << "Not clustering " << node->name() |
| << ": disallowed by _XlaCompile attribute"; |
| continue; |
| } |
| |
| const XlaOpRegistry::DeviceRegistration* registration; |
| if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), |
| ®istration)) { |
| VLOG(2) << "Rejecting " << node->name() |
| << ": could not find JIT device for " << device_type.type(); |
| continue; |
| } |
| |
| auto cluster_exclude_op_list = GetOrCreateClusterExcludeList(); |
| RecursiveCompilabilityChecker::OperationFilter filter = |
| CreateOperationFilter(*registration); |
| filter.require_always_compilable = true; |
| filter.allow_string_consts = false; |
| filter.allow_collective_reduce_v2 = false; |
| filter.allow_unique_op = false; |
| filter.allow_where_op = true; |
| |
| for (const auto& s : cluster_exclude_op_list) { |
| if (s == "Where") { |
| filter.allow_where_op = false; |
| } else { |
| return errors::InvalidArgument( |
| "The operation '", s, |
| "' passed to --tf_xla_cluster_exclude_ops is not supported by " |
| "XLA."); |
| } |
| } |
| |
| RecursiveCompilabilityChecker checker( |
| filter, DeviceType{registration->compilation_device_name}); |
| |
| if (!checker.IsCompilableNode(*node, lib_runtime)) { |
| continue; |
| } |
| |
| if (node->type_string() == "Const") { |
| // Skip Const op with type DT_STRING, since XLA autoclustering doesn't |
| // support it. |
| const AttrValue* attr = node->attrs().Find("dtype"); |
| if (attr != nullptr && attr->type() == DT_STRING) { |
| continue; |
| } |
| } |
| |
| if (!allowlist.empty() && !allowlist.contains(node->def().op())) { |
| VLOG(1) << "Rejecting TF operation " << node->def().op() |
| << " as it is not listed in --tf_xla_ops_to_cluster."; |
| continue; |
| } |
| |
| if (compile_time_const_nodes[node->id()]) { |
| const OpDef* op_def; |
| TF_RETURN_IF_ERROR( |
| graph_->op_registry()->LookUpOpDef(node->type_string(), &op_def)); |
| if (op_def->is_stateful()) { |
| // It is easiest to demonstrate the problem we're trying to solve with |
| // an example. Say we have this graph: |
| // |
| // shape = RandomUniformInt(); |
| // reshape = Reshape(input, shape) |
| // |
| // Both RandomUniformInt and Reshape are compilable by XLA so, absent |
| // any other reason, we will try to put both shape and reshape in the |
| // same cluster. However, since XLA only supports statically shaped |
| // values, it will expect to be able to constant fold `shape` to get a |
| // static shape for `reshape`. This is a problem because side-effecting |
| // ops like RandomUniformInt() cannot be constant folded. We fix this |
| // by putting `shape` and `reshape` in different clusters, which results |
| // in us recompiling `reshape`'s cluster for every new value of `shape`, |
| // making `reshape` statically sized within each compilation. We |
| // simplify the solution even further by disallowing operations like |
| // `shape` from being part of *any* non-trivial cluster. They're either |
| // not compiled by XLA altogether or, if assigned to an XLA_* device |
| // with "must compile" semantics, compiled into a trivial single-op |
| // cluster. This approach leaves some room for improvement, and we can |
| // consider implementing a more aggressive data-flow-analysis based |
| // solution in the future if needed. |
| // |
| // One ugly problem we have to contend with: certain sets of ops *have* |
| // to be in the same cluster because values flowing between them have |
| // types that can't be live-in or live-out of a cluster. These ops are: |
| // |
| // - TensorArray ops operating on the same TensorArray instance. |
| // - Stack ops operating on the same Stack instance. |
| // |
| // To work around this we avoid isolating these specific ops. Because |
| // of this concession it is unsound to auto-cluster them because then |
| // we'd create clusters we could not compile (because we can't constant |
| // fold, say, a TensorArrayRead or a StackPopV2). But we don't |
| // auto-cluster these operations today so we're good for now. |
| const XlaResourceOpInfo* op_info = |
| GetResourceOpInfoForOp(node->type_string()); |
| bool is_tensor_array_or_stack_op = |
| op_info && op_info->resource_kind() != XlaResourceKind::kVariable; |
| if (!is_tensor_array_or_stack_op) { |
| VLOG(2) << "Isolating " << node->name() |
| << ": must-be-constant stateful op"; |
| continue; |
| } |
| } |
| } |
| |
| // This is a heuristic to avoid creating dependency between while loop |
| // condition and body computations. Dependency between them can be created |
| // if a special Identity node in the following pattern is clustered in. |
| // That is, an Identity node in the loop cond computation is used to drive |
| // const nodes consumed by the loop body. If this Identity node goes into |
| // the same cluster with nodes from the loop body, extra dependency is |
| // created between the loop cond and body computations and it hinders the |
| // progression of the loop cond computation at runtime with significant |
| // overhead. Specifically, we look for the below pattern and do not cluster |
| // in this Identity to avoid the described issue. Since Identity has low |
| // execution cost in native TF, the fact that this heuristic gives up these |
| // special Identity nodes as candidates should not harm any performance. If |
| // other considerations emerge in the future, we can revisit the heuristic |
| // and only disallow these Identities from going into the same cluster as |
| // nodes from the loop body, while still considering them candidates. |
| // |
| // LoopCond -> |
| // Merge -> Switch -> Identity -> i++ -> ... -> NextIteration |
| // ..> Const -> LoopBody |
| // (control edge) |
| TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop, |
| IsIdentityDrivingConstsInLoop(node)); |
| if (is_identity_driving_consts_in_loop) { |
| VLOG(2) << "Rejecting " << node->name() |
| << ": including it can create dependencies between while loop " |
| "condition and body computations with runtime overhead."; |
| continue; |
| } |
| |
| compilation_candidates_.insert(node); |
| --(*debug_options_.fuel); |
| } |
| |
| VLOG(2) << "compilation_candidates_.size() = " |
| << compilation_candidates_.size(); |
| return OkStatus(); |
| } |
| |
| bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( |
| Node* node) { |
| if (debug_options_.ignore_xla_compile_attr) { |
| return false; |
| } |
| |
| // If there is a _XlaCompile annotation, use its value. |
| bool compile = false; |
| Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); |
| if (status.ok()) { |
| if (!compile) { |
| VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" |
| << kXlaCompileAttr << ") is false."; |
| } |
| return !compile; |
| } |
| |
| status = flib_def_->GetAttr(*node, kXlaCompileAttr, &compile); |
| if (status.ok()) { |
| if (!compile) { |
| VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" |
| << kXlaCompileAttr << ") on callee is false."; |
| } |
| return !compile; |
| } |
| |
| return false; |
| } |
| |
| bool MarkForCompilationPassImpl::LogNotContractableAndReturnFalse( |
| Cluster* from, Cluster* to, absl::string_view reason) { |
| VLOG(3) << EdgeContractionFailureMsg(from, to, reason); |
| return false; |
| } |
| |
| StatusOr<bool> MarkForCompilationPassImpl::TryToContractEdge(Cluster* from, |
| Cluster* to) { |
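| // Clusters can only be merged if their nodes are live on the same
| // control-flow paths. For example (illustrative), a node that only executes
| // on the true branch of a Switch has a different deadness predicate than a
| // node that executes unconditionally, so an edge between their clusters is
| // not contracted.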
| DCHECK(from->deadness_predicate().has_value() == |
| to->deadness_predicate().has_value()); |
| if (from->deadness_predicate() != to->deadness_predicate()) { |
| VLOG(3) << EdgeContractionFailureMsg( |
| from, to, |
| absl::StrCat( |
| "the two nodes have mismatching deadness: ", |
| deadness_analysis_->DebugString(*from->deadness_predicate()), |
| " and ", |
| deadness_analysis_->DebugString(*to->deadness_predicate()))); |
| return false; |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool devices_compatible, |
| AreDevicesCompatible(*from, *to)); |
| if (!devices_compatible) { |
| return LogNotContractableAndReturnFalse( |
| from, to, "the two nodes have incompatible devices"); |
| } |
| |
| if (from->xla_scope().has_value() && to->xla_scope().has_value() && |
| *from->xla_scope() != *to->xla_scope()) { |
| return LogNotContractableAndReturnFalse( |
| from, to, "the two nodes have mismatching XLA scopes"); |
| } |
| |
| // Don't exceed the maximum cluster size. |
| if (from->cluster_size() + to->cluster_size() > |
| debug_options_.max_cluster_size) { |
| return LogNotContractableAndReturnFalse( |
| from, to, "the new cluster will be larger than the max cluster size"); |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool will_introduce_cross_device_dependency, |
| ClusteringWillIntroduceInterDeviceDependency(*from, *to)); |
| |
| if (will_introduce_cross_device_dependency) { |
| return LogNotContractableAndReturnFalse( |
| from, to, "the new cluster will introduce a cross device dependency"); |
| } |
| |
| // Check if contracting this edge will break the resource variable concurrency
| // semantics. In theory this check is quadratic in the number of resource
| // operations in the two clusters, but it has not been a problem in practice
| // so far.
| if (!debug_options_.ignore_resource_variable_checks) { |
| for (int resource_var_from : from->resource_var_operation_node_ids()) { |
| for (int resource_var_to : to->resource_var_operation_node_ids()) { |
| // If unsafe_resource_deps_ contains {A, B} then |
| // |
| // a. A and B are resource operations. |
| // b. A and B cannot be placed in the same cluster. |
| // c. There is no path from B to A in the cycles graph (but there may |
| // be a path from A to B). |
| // |
| // So check the legality of the edge contraction by checking if any of |
| // the n^2 pairs of resource variable operations are forbidden. |
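| //
| // For example (hypothetical node ids): if node 7 is a variable write (say,
| // an AssignAddVariableOp) and node 9 is a read that must observe that write
| // (there is a path 7 -> 9), the resource-safety analysis records {7, 9} as
| // unsafe, roughly because a cluster hoists all resource reads before all
| // resource writes, and the contraction is rejected here.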
| if (unsafe_resource_deps_.contains( |
| {resource_var_from, resource_var_to})) { |
| return LogNotContractableAndReturnFalse( |
| from, to, |
| "the new cluster would break resource variable semantics"); |
| } |
| } |
| } |
| } |
| |
| return MergeClusters(from, to); |
| } |
| |
| Status MarkForCompilationPassImpl::Run() { |
| // Make sure that kernels have been registered on the JIT device. |
| XlaOpRegistry::RegisterCompilationKernels(); |
| |
| // Start the timer after XlaOpRegistry::RegisterCompilationKernels, which
| // does some one-time work.
| XLA_SCOPED_LOGGING_TIMER_LEVEL("MarkForCompilationPassImpl::Run", 1); |
| |
| TF_ASSIGN_OR_RETURN(bool initialized, Initialize()); |
| if (!initialized) { |
| // Initialization exited early, which means this instance of
| // MarkForCompilationPassImpl is not set up to run the subsequent phases.
| return OkStatus(); |
| } |
| |
| TF_RETURN_IF_ERROR(RunEdgeContractionLoop()); |
| TF_RETURN_IF_ERROR(DeclusterNodes()); |
| TF_RETURN_IF_ERROR(CreateClusters()); |
| TF_RETURN_IF_ERROR(DumpDebugInfo()); |
| |
| return OkStatus(); |
| } |
| |
| void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { |
| DumpGraphToFile("mark_for_compilation", *graph_, flib_def_); |
| |
| // We also dump out an annotated version of the TF graph where the node
| // names are prefixed with the cluster names. This can help visualize the
| // clustering decisions on TensorBoard.
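| // For example, a node "dense/MatMul" assigned to cluster "cluster_0" would
| // be renamed to "cluster_0/dense/MatMul", while an unclustered node of the
| // same name would become "unclustered/dense/MatMul".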
| Graph new_graph(graph_->op_registry()); |
| CopyGraph(*graph_, &new_graph); |
| |
| for (Node* n : new_graph.nodes()) { |
| if (std::optional<absl::string_view> cluster_name = |
| GetXlaClusterForNode(*n)) { |
| n->set_name(absl::StrCat(*cluster_name, "/", n->name())); |
| } else if (n->type_string() == "VarHandleOp") { |
| n->set_name(absl::StrCat("varhandle/", n->name())); |
| } else { |
| // There is room for improvement here. In particular, it may help to |
| // split these unclustered nodes into classes where every node in a |
| // specific class has edges to and from the same set of clusters. |
| n->set_name(absl::StrCat("unclustered/", n->name())); |
| } |
| } |
| |
| DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_); |
| } |
| |
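| // Formats a ratio for logging (assumes denominator != 0), e.g.
| // RatioToString(3, 12) returns "3 / 12 (25.00%)".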
| string RatioToString(int numerator, int denominator) { |
| return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, |
| (100.0 * numerator) / denominator); |
| } |
| |
| void MarkForCompilationPassImpl::VLogClusteringSummary() { |
| if (!VLOG_IS_ON(2)) { |
| return; |
| } |
| |
| XlaAutoClusteringSummary auto_clustering_info = |
| GetXlaAutoClusteringSummary(*graph_); |
| |
| VLOG(2) << "*** Clustering info for graph of size " << graph_->num_nodes(); |
| VLOG(2) << " Built " << auto_clustering_info.clusters_size() |
| << " clusters, size " |
| << RatioToString(auto_clustering_info.clustered_node_count(), |
| graph_->num_nodes()); |
| |
| for (const XlaAutoClusteringSummary::Cluster& cluster : |
| auto_clustering_info.clusters()) { |
| absl::string_view cluster_name = cluster.name(); |
| int size = cluster.size(); |
| VLOG(2) << " " << cluster_name << " " |
| << RatioToString(size, graph_->num_nodes()); |
| for (const XlaAutoClusteringSummary::OpAndCount& op_count : |
| cluster.op_histogram()) { |
| VLOG(3) << " " << op_count.op() << ": " << op_count.count() |
| << " instances"; |
| } |
| } |
| |
| if (!auto_clustering_info.unclustered_op_histogram().empty()) { |
| VLOG(2) << " Unclustered nodes: " |
| << RatioToString(auto_clustering_info.unclustered_node_count(), |
| graph_->num_nodes()); |
| for (const XlaAutoClusteringSummary::OpAndCount& op_count : |
| auto_clustering_info.unclustered_op_histogram()) { |
| VLOG(3) << " " << op_count.op() << ": " << op_count.count() |
| << " instances"; |
| } |
| } |
| |
| struct EdgeInfo { |
| absl::string_view node_name; |
| std::optional<absl::string_view> cluster_name; |
| |
| absl::string_view GetClusterName() const { |
| return cluster_name ? *cluster_name : "[none]"; |
| } |
| |
| std::pair<absl::string_view, std::optional<absl::string_view>> AsPair() |
| const { |
| return {node_name, cluster_name}; |
| } |
| |
| bool operator<(const EdgeInfo& other) const { |
| return AsPair() < other.AsPair(); |
| } |
| }; |
| |
| using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64_t>>; |
| |
| EdgeInfoMap incoming_edge_infos; |
| EdgeInfoMap outgoing_edge_infos; |
| |
| std::set<absl::string_view> cluster_names_to_print; |
| |
| for (const Edge* e : graph_->edges()) { |
| const Node* from = e->src(); |
| std::optional<absl::string_view> from_cluster_name = |
| GetXlaClusterForNode(*from); |
| |
| const Node* to = e->dst(); |
| std::optional<absl::string_view> to_cluster_name = |
| GetXlaClusterForNode(*to); |
| |
| if (to_cluster_name == from_cluster_name) { |
| continue; |
| } |
| |
| if (to_cluster_name) { |
| incoming_edge_infos[*to_cluster_name] |
| [EdgeInfo{from->name(), from_cluster_name}]++; |
| cluster_names_to_print.insert(*to_cluster_name); |
| } |
| |
| if (from_cluster_name) { |
| outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++; |
| cluster_names_to_print.insert(*from_cluster_name); |
| } |
| } |
| |
| VLOG(4) << "*** Inter-Cluster edges:"; |
| if (cluster_names_to_print.empty()) { |
| VLOG(4) << " [none]"; |
| } |
| |
| auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, |
| const EdgeInfoMap& edge_info_map, |
| absl::string_view desc) { |
| auto it = edge_info_map.find(cluster_name); |
| if (it != edge_info_map.end()) { |
| VLOG(4) << " " << it->second.size() << " " << desc << " edges"; |
| for (const auto& edge_info_count_pair : it->second) { |
| VLOG(4) << " " << edge_info_count_pair.first.GetClusterName() << " " |
| << edge_info_count_pair.first.node_name << " # " |
| << edge_info_count_pair.second; |
| } |
| } else { |
| VLOG(4) << " No " << desc << " edges."; |
| } |
| }; |
| |
| for (absl::string_view cluster_name : cluster_names_to_print) { |
| VLOG(4) << " ** Cluster " << cluster_name; |
| print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, |
| "incoming"); |
| print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos, |
| "outgoing"); |
| } |
| } |
| |
| StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible( |
| const Cluster& cluster_a, const Cluster& cluster_b) { |
| DeviceSet devices = cluster_a.devices(); |
| devices.UnionWith(cluster_b.devices()); |
| |
| TF_ASSIGN_OR_RETURN( |
| std::optional<jit::DeviceId> maybe_chosen_device, |
| MaybePickDeviceForXla(device_info_cache_, devices, |
| /*allow_mixing_unknown_and_cpu=*/false)); |
| if (!maybe_chosen_device.has_value()) { |
| return false; |
| } |
| |
| jit::DeviceId chosen_device = *maybe_chosen_device; |
| |
| // If we are able to pick a device `chosen_device` for the larger cluster, the
| // resource operations in `cluster_a` and `cluster_b` must be placed on the
| // same device as `chosen_device`. This is because the _XlaCompile and
| // _XlaRun kernels are going to run on `chosen_device`, and will therefore try
| // to access the resource variables from there, which is an error if the
| // resource variables are placed on some other device.
| auto resource_op_device_ok = [&](std::optional<DeviceId> resource_op_device) { |
| return !resource_op_device.has_value() || |
| *resource_op_device == chosen_device; |
| }; |
| |
| return resource_op_device_ok(cluster_a.resource_op_device()) && |
| resource_op_device_ok(cluster_b.resource_op_device()); |
| } |
| |
| // Returns `true` iff we should compile `cluster`. |
| StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl( |
| const Cluster& cluster) { |
| TF_ASSIGN_OR_RETURN(DeviceId chosen_device, |
| PickDeviceForXla(device_info_cache_, cluster.devices(), |
| /*allow_mixing_unknown_and_cpu=*/false)); |
| |
| const DeviceType& device_type = |
| device_info_cache_.GetDeviceTypeFor(chosen_device); |
| const XlaOpRegistry::DeviceRegistration* registration = |
| device_info_cache_.GetCompilationDevice(chosen_device); |
| TF_RET_CHECK(registration) |
| << "chosen device = " << device_info_cache_.GetNameFor(chosen_device) |
| << "; device type = " << device_type.type() << "; devices (" |
| << device_info_cache_.DebugString(cluster.devices()); |
| |
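| // In prose, the decision below is: a cluster that carries _XlaCompile=true
| // is always compiled; a cluster on a device registered with policy kAlways
| // is always compiled; a cluster on a device registered with
| // kIfEnabledGlobally is compiled only when the session's global JIT level is
| // not OFF; and a CPU cluster registered with kIfExplicitlyRequested is
| // compiled only when cpu_global_jit_ is set.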
| auto policy = registration->autoclustering_policy; |
| bool should_compile = |
| cluster.is_xla_compile_attr_true() || |
| policy == XlaOpRegistry::AutoclusteringPolicy::kAlways || |
| (policy == XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && |
| global_jit_level_ != OptimizerOptions::OFF) || |
| (device_type.type_string() == DEVICE_CPU && |
| policy == XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested && |
| cpu_global_jit_); |
| |
| if (!should_compile && device_type.type_string() == DEVICE_CPU && |
| global_jit_level_ > OptimizerOptions::OFF) { |
| static absl::once_flag once; |
| absl::call_once(once, [] { |
| LOG(WARNING) << R"((One-time warning): Not using XLA:CPU for cluster. |
| |
| If you want XLA:CPU, do one of the following: |
| |
| - set the TF_XLA_FLAGS to include "--tf_xla_cpu_global_jit", or |
| - set cpu_global_jit to true on this session's OptimizerOptions, or |
| - use experimental_jit_scope, or |
| - use tf.function(jit_compile=True). |
| |
| To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a |
| proper command-line flag, not via TF_XLA_FLAGS).)"; |
| |
| MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); |
| if (flags->tf_xla_cpu_global_jit) { |
| LOG(WARNING) |
| << "(Although the tf_xla_cpu_global_jit flag is currently enabled, " |
| "perhaps it wasn't enabled at process startup?)"; |
| } |
| }); |
| } |
| |
| VLOG(3) << (should_compile ? "Compiling" : "Not compiling") |
| << " cluster with device " |
| << device_info_cache_.GetNameFor(chosen_device); |
| |
| return should_compile; |
| } |
| |
| StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileCluster( |
| const Cluster& cluster) { |
| auto it = should_compile_cluster_cache_.find(&cluster); |
| if (it != should_compile_cluster_cache_.end()) { |
| return it->second; |
| } |
| |
| TF_ASSIGN_OR_RETURN(bool should_compile, ShouldCompileClusterImpl(cluster)); |
| should_compile_cluster_cache_.insert({&cluster, should_compile}); |
| return should_compile; |
| } |
| |
| Status MarkForCompilation( |
| const GraphOptimizationPassOptions& options, |
| const MarkForCompilationPassImpl::DebugOptions& debug_options) { |
| Graph* graph = options.graph->get(); |
| FunctionLibraryDefinition* flib_def = options.flib_def; |
| |
| // Deadness analysis expects a graph with source and sink edges properly
| // connected, but sometimes the incoming graph does not follow this invariant.
| // So fix up the source and sink edges before calling into deadness analysis.
| FixupSourceAndSinkEdges(graph); |
| |
| for (Node* n : graph->nodes()) { |
| // See explanation on `kXlaAlreadyClustered`. |
| if (n->attrs().Find(kXlaAlreadyClustered)) { |
| return OkStatus(); |
| } |
| // Skip the pass if we find TPUExecute or TPUExecuteAndUpdateVariables ops
| // in the graph, which indicates that the graph was produced by the TPU
| // TF-XLA bridge and does not require auto-clustering.
| if (n->type_string() == "TPUExecute" || |
| n->type_string() == "TPUExecuteAndUpdateVariables") { |
| return OkStatus(); |
| } |
| } |
| |
| return MarkForCompilationPassImpl{ |
| debug_options, |
| graph, |
| flib_def, |
| options.session_options != nullptr ? options.session_options->env |
| : Env::Default(), |
| GetGlobalJitLevelForGraph(options), |
| // Mirror the null check used for `env` above: session_options may be null.
| options.session_options != nullptr &&
| options.session_options->config.graph_options()
| .optimizer_options()
| .cpu_global_jit()}
| .Run(); |
| } |
| |
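| // Returns a pointer to a process-wide "fuel" counter shared by every run of
| // the pass. Each node accepted as a compilation candidate decrements the
| // counter (see the fuel decrement in the candidate-selection loop above), so
| // exhausting it stops further nodes from being admitted. This makes it
| // possible to bisect clustering decisions, e.g. (illustrative) via
| // TF_XLA_FLAGS="--tf_xla_clustering_fuel=100".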
| std::atomic<int64_t>* GetPointerToFuel(int64_t initial_value) { |
| static std::atomic<int64_t>* fuel = [&]() { |
| std::atomic<int64_t>* fuel = new std::atomic<int64_t>; |
| *fuel = initial_value; |
| return fuel; |
| }(); |
| |
| return fuel; |
| } |
| } // anonymous namespace |
| |
| Status MarkForCompilationPass::Run( |
| const GraphOptimizationPassOptions& options) { |
| MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); |
| |
| MarkForCompilationPassImpl::DebugOptions debug_options; |
| debug_options.ignore_deadness_checks = |
| flags->tf_xla_disable_deadness_safety_checks_for_debugging; |
| debug_options.ignore_resource_variable_checks = |
| flags->tf_xla_disable_resource_variable_safety_checks_for_debugging; |
| debug_options.ignore_xla_compile_attr = false; |
| debug_options.deterministic_cluster_names = |
| flags->tf_xla_deterministic_cluster_names; |
| debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; |
| debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; |
| debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); |
| debug_options.dump_graphs = flags->tf_xla_clustering_debug; |
| |
| return MarkForCompilation(options, debug_options); |
| } |
| |
| Status MarkForCompilationPass::RunForTest( |
| const GraphOptimizationPassOptions& options, bool disable_deadness_analysis, |
| bool deterministic_cluster_names) { |
| MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); |
| |
| MarkForCompilationPassImpl::DebugOptions debug_options; |
| debug_options.ignore_deadness_checks = disable_deadness_analysis; |
| debug_options.ignore_resource_variable_checks = |
| flags->tf_xla_disable_resource_variable_safety_checks_for_debugging; |
| debug_options.ignore_xla_compile_attr = true; |
| debug_options.deterministic_cluster_names = deterministic_cluster_names; |
| debug_options.max_cluster_size = flags->tf_xla_max_cluster_size; |
| debug_options.min_cluster_size = flags->tf_xla_min_cluster_size; |
| debug_options.fuel = GetPointerToFuel(flags->tf_xla_clustering_fuel); |
| debug_options.dump_graphs = flags->tf_xla_clustering_debug; |
| |
| return MarkForCompilation(options, debug_options); |
| } |
| |
| absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() { |
| // Table format: category name: {list of TF operations in that category} |
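| // For illustration: the category names in this table can be passed through
| // the tf_xla_ops_to_cluster flag to restrict which ops the pass considers,
| // e.g. (hypothetical invocation)
| // TF_XLA_FLAGS="--tf_xla_ops_to_cluster=PW,RED" limits candidates to the
| // pointwise and reduction categories below.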
| static absl::flat_hash_map<string, std::vector<string>>* result = |
| new absl::flat_hash_map<string, std::vector<string>>{ |
| // Unary |
| {"PW", |
| {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", |
| "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1", |
| "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log", |
| "Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round", |
| "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", |
| "Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv", |
| "Lgamma", "Digamma", |
| // Binary |
| "Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan", |
| "MulNoNan", "FloorDiv", "Xlogy", "Xlog1py", "Xdivy", "FloorMod", |
| "BitwiseAnd", "BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", |
| "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", |
| "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv", |
| "TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual", |
| "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad", |
| "TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual", |
| // Others |
| "AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty", |
| "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad", |
| "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select", |
| "SelectV2", "Transpose", "ConjugateTranspose", |
| "_UnaryOpsComposition", "CollectiveReduceV2", |
| "CollectiveAssignGroupV2", |
| // The following 5 operations are converted to identity |
| "PlaceholderWithDefault", "PreventGradient", "StopGradient", |
| "Snapshot", "_EagerConst"}}, |
| // clang-format off |
| {"RED", |
| {"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}}, |
| // clang-format on |
| {"PWRED", |
| {"ArgMax", "ArgMin", "DiagPart", "Softmax", |
| "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}}, |
| {"REDUCEWINDOW", |
| {"ArgMax", "ArgMin", "DiagPart", "Softmax", |
| "SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}}, |
| {"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}}, |
| {"BN", |
| {"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3", |
| "_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2", |
| "FusedBatchNormGradV3"}}, |
| {"SORT", {"TopKV2"}}, // XLA version much faster then TF version. |
| {"MISC", |
| // clang-format off |
| {"ApproxTopK", "BroadcastTo", "ExpandDims", "Fill", "NoOp", |
| "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze", |
| "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/, |
| "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2", |
| "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad", |
| "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", |
| "SplitV", "StridedSlice", "StridedSliceGrad", |
| "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation", |
| "Unpack", "DeviceIndex", "TensorStridedSliceUpdate", "XlaConcatND", |
| "XlaSplitND", |
| }}}; |
| // clang-format on |
| return result; |
| } |
| |
| namespace testing { |
| void ResetClusterSequenceNumber() { |
| ClusterSequenceNumberGenerator::Global().Reset(); |
| } |
| |
| absl::flat_hash_set<string> GetKnownXLAAllowlistOp() { |
| absl::flat_hash_set<string> result{ |
| "AdjustContrastv2", |
| "AdjustHue", |
| "AdjustSaturation", |
| "Asinh", |
| "Assert", |
| "AssignAddVariableOp", |
| "AssignSubVariableOp", |
| "AssignVariableOp", |
| "AssignVariableXlaConcatND", |
| "AvgPool", |
| "AvgPool3D", |
| "AvgPool3DGrad", |
| "AvgPoolGrad", |
| "BatchMatMul", |
| "BatchMatMulV2", |
| "BatchMatMulV3", |
| "BatchToSpace", |
| "BatchToSpaceND", |
| "BesselI0e", |
| "BesselI1e", |
| "Betainc", |
| "BiasAddV1", |
| "Bincount", |
| "Bucketize", |
| "Case", |
| "CheckNumerics", |
| "Cholesky", |
| "ControlTrigger", |
| "Conv2D", |
| "Conv2DBackpropFilter", |
| "Conv2DBackpropInput", |
| "Conv3D", |
| "Conv3DBackpropFilterV2", |
| "Conv3DBackpropInputV2", |
| "Cross", |
| "Cumprod", |
| "Cumsum", |
| "DenseBincount", |
| "DataFormatDimMap", |
| "DataFormatVecPermute", |
| "DepthToSpace", |
| "DepthwiseConv2dNative", |
| "DepthwiseConv2dNativeBackpropFilter", |
| "DepthwiseConv2dNativeBackpropInput", |
| "Dequantize", |
| "Diag", |
| "DynamicStitch", |
| "DynamicPartition", |
| "Einsum", |
| "EmptyTensorList", |
| "EnsureShape", |
| "ExtractImagePatches", |
| "Igamma", |
| "IgammaGradA", |
| "RandomGammaGrad", |
| "Igammac", |
| "FFT", |
| "FFT2D", |
| "FFT3D", |
| "FakeParam", |
| "FakeQuantWithMinMaxArgs", |
| "FakeQuantWithMinMaxArgsGradient", |
| "FakeQuantWithMinMaxVars", |
| "FakeQuantWithMinMaxVarsGradient", |
| "FakeQuantWithMinMaxVarsPerChannel", |
| "FakeQuantWithMinMaxVarsPerChannelGradient", |
| "Gather", |
| "GatherNd", |
| "GatherV2", |
| "HSVToRGB", |
| "IFFT", |
| "IFFT2D", |
| "IFFT3D", |
| "IRFFT", |
| "IRFFT2D", |
| "IRFFT3D", |
| "If", |
| "InTopKV2", |
| "L2Loss", |
| "LeakyRelu", |
| "LinSpace", |
| "ListDiff", |
| "LogMatrixDeterminant", |
| "LowerBound", |
| "MatMul", |
| "MatrixBandPart", |
| "MatrixDiag", |
| "MatrixDiagPart", |
| "MatrixDiagPartV2", |
| "MatrixDiagPartV3", |
| "MatrixDiagV2", |
| "MatrixDiagV3", |
| "MatrixInverse", |
| "MatrixSetDiag", |
| "MatrixSetDiagV2", |
| "MatrixSetDiagV3", |
| "MatrixSolve", |
| "MatrixTriangularSolve", |
| "MaxPool", |
| "MaxPool3D", |
| "MaxPool3DGrad", |
| "MaxPool3DGradGrad", |
| "MaxPoolGrad", |
| "MaxPoolGradGrad", |
| "MaxPoolGradGradV2", |
| "MaxPoolGradV2", |
| "MaxPoolV2", |
| "Multinomial", |
| "NextAfter", |
| "NonMaxSuppressionV3", |
| "NonMaxSuppressionV4", |
| "ParallelDynamicStitch", |
| "ParameterizedTruncatedNormal", |
| "PartitionedCall", |
| "Polygamma", |
| "PopulationCount", |
| "Qr", |
| "QuantizeAndDequantizeV2", |
| "QuantizeAndDequantizeV3", |
| "QuantizeAndDequantizeV4", |
| "RFFT", |
| "RFFT2D", |
| "RFFT3D", |
| "RGBToHSV", |
| "RandomShuffle", |
| "RandomStandardNormal", |
| "RandomUniform", |
| "RandomUniformInt", |
| "ReadVariableOp", |
| "ReadVariableXlaSplitND", |
| "ResizeBilinear", |
| "ResizeBilinearGrad", |
| "ResizeNearestNeighbor", |
| "ResourceApplyAdaMax", |
| "ResourceApplyAdadelta", |
| "ResourceApplyAdagrad", |
| "ResourceApplyAdagradDA", |
| "ResourceApplyAdagradV2", |
| "ResourceApplyAdam", |
| "ResourceApplyAddSign", |
| "ResourceApplyCenteredRMSProp", |
| "ResourceApplyFtrl", |
| "ResourceApplyFtrlV2", |
| "ResourceApplyGradientDescent", |
| "ResourceApplyKerasMomentum", |
| "ResourceApplyMomentum", |
| "ResourceApplyPowerSign", |
| "ResourceApplyProximalAdagrad", |
| "ResourceApplyProximalGradientDescent", |
| "ResourceApplyRMSProp", |
| "ResourceGather", |
| "ResourceScatterAdd", |
| "ResourceScatterDiv", |
| "ResourceScatterMax", |
| "ResourceScatterMin", |
| "ResourceScatterMul", |
| "ResourceScatterNdAdd", |
| "ResourceScatterNdSub", |
| "ResourceScatterNdUpdate", |
| "ResourceScatterSub", |
| "ResourceScatterUpdate", |
| "RngReadAndSkip", |
| "RngSkip", |
| "Roll", |
| "ScatterNd", |
| "SelfAdjointEigV2", |
| "SoftmaxCrossEntropyWithLogits", |
| "SpaceToBatch", |
| "SpaceToBatchND", |
| "SpaceToDepth", |
| "SparseMatMul", |
| "SparseToDense", |
| "StackCloseV2", |
| "StackPopV2", |
| "StackPushV2", |
| "StackV2", |
| "StatefulPartitionedCall", |
| "StatefulStandardNormalV2", |
| "StatefulTruncatedNormal", |
| "StatefulUniform", |
| "StatefulUniformFullInt", |
| "StatefulUniformInt", |
| "StatelessCase", |
| "StatelessIf", |
| "StatelessMultinomial", |
| "StatelessParameterizedTruncatedNormal", |
| "StatelessRandomGetAlg", |
| "StatelessRandomGetKeyCounter", |
| "StatelessRandomGetKeyCounterAlg", |
| "StatelessRandomNormal", |
| "StatelessRandomNormalV2", |
| "StatelessRandomUniform", |
| "StatelessRandomUniformV2", |
| "StatelessRandomUniformInt", |
| "StatelessRandomUniformIntV2", |
| "StatelessRandomUniformFullInt", |
| "StatelessRandomUniformFullIntV2", |
| "StatelessTruncatedNormal", |
| "StatelessTruncatedNormalV2", |
| "StatelessWhile", |
| "Svd", |
| "SymbolicGradient", |
| "TensorArrayCloseV3", |
| "TensorArrayConcatV3", |
| "TensorArrayGatherV3", |
| "TensorArrayGradV3", |
| "TensorArrayReadV3", |
| "TensorArrayScatterV3", |
| "TensorArraySizeV3", |
| "TensorArraySplitV3", |
| "TensorArrayV3", |
| "TensorArrayWriteV3", |
| "TensorListConcatV2", |
| "TensorListElementShape", |
| "TensorListFromTensor", |
| "TensorListGather", |
| "TensorListGetItem", |
| "TensorListLength", |
| "TensorListPopBack", |
| "TensorListPushBack", |
| "TensorListReserve", |
| "TensorListSetItem", |
| "TensorListSplit", |
| "TensorListStack", |
| "TensorScatterAdd", |
| "TensorScatterMax", |
| "TensorScatterMin", |
| "TensorScatterSub", |
| "TensorScatterUpdate", |
| "ToBool", |
| "TridiagonalSolve", |
| "TridiagonalMatMul", |
| "TruncatedNormal", |
| "Unique", |
| "UniqueV2", |
| "UpperBound", |
| "UnsortedSegmentMax", |
| "UnsortedSegmentMin", |
| "UnsortedSegmentProd", |
| "UnsortedSegmentSum", |
| "VarIsInitializedOp", |
| "VariableShape", |
| "Where", |
| "While", |
| "XlaBroadcastHelper", |
| "XlaCallModule", |
| "XlaConcatND", |
| "XlaConv", |
| "XlaConvV2", |
| "XlaCustomCall", |
| "XlaDequantize", |
| "XlaDot", |
| "XlaDotV2", |
| "XlaDynamicSlice", |
| "XlaDynamicUpdateSlice", |
| "XlaEinsum", |
| "XlaGather", |
| "XlaIf", |
| "XlaKeyValueSort", |
| "XlaOptimizationBarrier", |
| "XlaPad", |
| "XlaRecv", |
| "XlaReduce", |
| "XlaReducePrecision", |
| "XlaReduceWindow", |
| "XlaRemoveDynamicDimensionSize", |
| "XlaReplicaId", |
| "XlaRngBitGenerator", |
| "XlaScatter", |
| "XlaSelectAndScatter", |
| "XlaSelfAdjointEig", |
| "XlaSend", |
| "XlaSetBound", |
| "XlaSetDynamicDimensionSize", |
| "XlaSharding", |
| "XlaSort", |
| "XlaSplitND", |
| "XlaSpmdFullToShardShape", |
| "XlaSpmdShardToFullShape", |
| "XlaSvd", |
| "XlaVariadicReduce", |
| "XlaVariadicReduceV2", |
| "XlaVariadicSort", |
| "XlaWhile", |
| "Zeta", |
| "_Arg", |
| "_ArrayToList", |
| "_ListToArray", |
| "_Retval"}; |
| return result; |
| } |
| |
| } // namespace testing |
| } // namespace tensorflow |