| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallSet.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| |
| namespace mlir { |
| namespace TFTPU { |
| |
| namespace { |
| |
// Name of the string attribute that tags an op as outside compiled. Within
// this file the attribute is both read (to find tagged ops) and written (its
// value is set to the name of the cluster an op has been assigned to).
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
| |
| struct TPUOutsideCompilationCluster |
| : public TF::PerFunctionAggregateAnalysisConsumerPass< |
| TPUOutsideCompilationCluster, TF::SideEffectAnalysis> { |
| void runOnFunction(FuncOp func, |
| const TF::SideEffectAnalysis::Info& side_effect_analysis); |
| }; |
| |
| bool IsVariant(Value value) { |
| return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>(); |
| } |
| |
| bool HasOutsideCompiledAncestor(Operation* op) { |
| Operation* parent = op->getParentOp(); |
| while (parent) { |
| if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) |
| return true; |
| parent = parent->getParentOp(); |
| } |
| return false; |
| } |
| |
| // Represents an outside compiled cluster. All ops that are added to the same |
| // cluster will be extracted together in a later pass. |
| class OutsideCompiledCluster { |
| public: |
| explicit OutsideCompiledCluster(int number) |
| : cluster_name_(llvm::formatv("cluster{0}", number).str()) {} |
| |
| // Attempts to add an op to this cluster. Ops can be grouped to the same |
| // cluster if they have data dependency and are inside the same block. |
| bool AddOp(Operation* op, |
| const TF::SideEffectAnalysis::Info& side_effect_analysis) { |
| // Check if the op is safe to add before adding it. |
| if (IsSafeToAdd(op, side_effect_analysis)) { |
| op->setAttr(kXlaOutsideCompilationAttr, |
| StringAttr::get(cluster_name_, op->getContext())); |
| host_cluster_ops_.insert(op); |
| return true; |
| } |
| return false; |
| } |
| |
| // If any tf.variants are inputs/outputs to the cluster, add them to the |
| // cluster unless they are already marks with outside compilation attribute. |
| bool AddVariantInputsOutputs() { |
| bool added_op = false; |
| llvm::SmallPtrSet<Operation*, 8> expanded_cluster_ops(host_cluster_ops_); |
| for (Operation* cluster_op : host_cluster_ops_) { |
| // Walk the clustered operations to handle nested ops. |
| cluster_op->walk([&](Operation* op) { |
| // Add any operations that provide variant inputs to the cluster. |
| for (auto value : op->getOperands()) { |
| auto input_defining_op = value.getDefiningOp(); |
| if (IsVariant(value) && input_defining_op && |
| !HasOutsideCompiledAncestor(input_defining_op) && |
| !input_defining_op->getAttrOfType<StringAttr>( |
| kXlaOutsideCompilationAttr)) { |
| expanded_cluster_ops.insert(input_defining_op); |
| input_defining_op->setAttr( |
| kXlaOutsideCompilationAttr, |
| StringAttr::get(cluster_name_, |
| input_defining_op->getContext())); |
| added_op = true; |
| } |
| } |
| // Add any operations that consume variant outputs to the cluster. |
| for (auto value : op->getResults()) { |
| if (IsVariant(value)) { |
| for (auto user : value.getUsers()) { |
| if (!host_cluster_ops_.contains(user) && |
| !HasOutsideCompiledAncestor(user) && |
| !user->getAttrOfType<StringAttr>( |
| kXlaOutsideCompilationAttr)) { |
| expanded_cluster_ops.insert(user); |
| user->setAttr( |
| kXlaOutsideCompilationAttr, |
| StringAttr::get(cluster_name_, user->getContext())); |
| added_op = true; |
| } |
| } |
| } |
| } |
| }); |
| } |
| host_cluster_ops_.swap(expanded_cluster_ops); |
| |
| return added_op; |
| } |
| |
| private: |
| // TODO(hinsu): Consider using GraphCycles data structure available in xla |
| // directory to avoid potentially full traversal for each new op and cluster |
| // pair. |
| // Checks if it is safe for `op` to be merged into this cluster. |
| bool IsSafeToAdd(Operation* op, |
| const TF::SideEffectAnalysis::Info& side_effect_analysis) { |
| if (host_cluster_ops_.empty()) return true; |
| |
| // If there is an intermediate data or side effect dependency between the op |
| // and ops in the cluster, it's not safe to add. |
| std::vector<Operation*> dependencies; |
| |
| // Materialize data dependencies as the llvm::concat doesn't support |
| // non-materialized iteration. |
| auto data_deps = llvm::to_vector<4>(op->getUsers()); |
| llvm::SmallVector<Operation*, 4> control_deps = |
| side_effect_analysis.DirectControlSuccessors(op); |
| for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) { |
| if (!host_cluster_ops_.contains(dep)) dependencies.push_back(dep); |
| } |
| |
| llvm::SmallPtrSet<Operation*, 4> visited; |
| while (!dependencies.empty()) { |
| Operation* next_op = dependencies.back(); |
| dependencies.pop_back(); |
| if (visited.count(next_op)) continue; |
| visited.insert(next_op); |
| |
| auto data_deps = llvm::to_vector<4>(next_op->getUsers()); |
| llvm::SmallVector<Operation*, 4> control_deps = |
| side_effect_analysis.DirectControlSuccessors(next_op); |
| for (auto* dep : llvm::concat<Operation*>(data_deps, control_deps)) { |
| if (host_cluster_ops_.contains(dep)) return false; |
| dependencies.push_back(dep); |
| } |
| } |
| |
| return true; |
| } |
| |
| // `host_cluster_op_` stores a set of ops that will be grouped and computed |
| // on host as single XlaHostCompute op. An outside compiled op can be grouped |
| // to a single cluster if it has data dependency to another op already in the |
| // cluster. |
| llvm::SmallPtrSet<Operation*, 8> host_cluster_ops_; |
| std::string cluster_name_; |
| }; |
| |
| void TPUOutsideCompilationCluster::runOnFunction( |
| FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) { |
| llvm::SmallVector<OutsideCompiledCluster, 8> clusters; |
| int cluster_counter = 0; |
| |
| func.walk([&](tf_device::ClusterOp tpu_cluster) { |
| llvm::SmallVector<Operation*, 4> outside_ops; |
| tpu_cluster.walk([&](Operation* op) { |
| if (op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) |
| outside_ops.emplace_back(op); |
| }); |
| |
| // In order to cluster ops feeding results to the same operation, traverse |
| // the ops in reverse order. |
| for (Operation* op : llvm::reverse(outside_ops)) { |
| // Try to add the op to existing clusters. |
| bool added = false; |
| for (auto& cluster : clusters) |
| if ((added = cluster.AddOp(op, side_effect_analysis))) break; |
| |
| // If the op cannot be added to existing clusters, create a new cluster. |
| if (!added) { |
| OutsideCompiledCluster new_cluster(cluster_counter++); |
| new_cluster.AddOp(op, side_effect_analysis); |
| clusters.push_back(new_cluster); |
| } |
| } |
| }); |
| for (auto& cluster : clusters) { |
| bool variants_to_add = true; |
| while (variants_to_add) variants_to_add = cluster.AddVariantInputsOutputs(); |
| } |
| } |
| |
| } // anonymous namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> |
| CreateTPUOutsideCompilationClusterPass() { |
| return std::make_unique<TPUOutsideCompilationCluster>(); |
| } |
| |
// Registers the pass under the argument "tf-tpu-outside-compilation-cluster"
// together with its human-readable description.
static PassRegistration<TPUOutsideCompilationCluster> pass(
    "tf-tpu-outside-compilation-cluster",
    "Identifies clusters of operations assigned to outside compilation");
| |
| } // namespace TFTPU |
| } // namespace mlir |