| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This pass lifts resource variable operations outside of device computation. |
| |
| #include <cstddef> |
| #include <cstdint> |
| |
| #include "llvm/ADT/BitVector.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Block.h" // from @llvm-project |
| #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Diagnostics.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/Region.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/IR/Verifier.h" // from @llvm-project |
| #include "mlir/IR/Visitors.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| |
| namespace mlir { |
| |
| namespace { |
| |
// Name of the attribute that is removed from the hoisted reads and writes
// created by this pass.
constexpr char kDeviceAttr[] = "device";
| |
| // Lift resource operations out of device computation. |
| struct ResourceOpLiftingPass |
| : public TFDevice::ResourceOpLiftingPassBase<ResourceOpLiftingPass> { |
| void runOnOperation() override; |
| }; |
| |
| bool IsResource(Value value) { |
| return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>(); |
| } |
| |
| // Get the type of the data contained in a resource. Returns null if there is |
| // no single type in the resource. |
| Type GetResourceSubtype(Value value) { |
| auto resource_type = |
| getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>(); |
| auto subtypes = resource_type.getSubtypes(); |
| if (subtypes.size() == 1) return subtypes[0]; |
| return nullptr; |
| } |
| |
| // Replaces all `tf.VarIsInitializedOp` in a block with a constant true. |
| // TODO(b/171039585): Replace this with proper analysis of |
| // `tf.VarIsInitializedOp` in regards to resource writes and control flow. |
| void SetAllVarIsInitializedToTrue(Block* block) { |
| auto builder = OpBuilder::atBlockBegin(block); |
| TF::ConstOp const_true = nullptr; |
| for (auto op : |
| llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) { |
| builder.setInsertionPoint(op); |
| if (!const_true) |
| const_true = builder.create<TF::ConstOp>( |
| op.getLoc(), |
| DenseIntElementsAttr::get( |
| RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true)); |
| |
| op.is_initialized().replaceAllUsesWith(const_true); |
| op.erase(); |
| } |
| } |
| |
| // Performs store-load forwarding. This effectively removes |
| // 1) Any resource loads after a store to that same resource is done |
| // 2) Any resource stores except the last one. |
| // TODO(ycao): Store-load forwarding implemented here is only correct when |
| // computation is purely sequential (no concurrency). Need to support concurrent |
| // computation as well. |
| void ForwardStoreToLoad(Block* block) { |
| // resource_handle_to_last_store_op keeps track of the most recent (last) |
| // store to each resource. Non-existent entry indicates that a resource has |
| // not been stored to yet. |
| llvm::SmallDenseMap<Value, TF::AssignVariableOp> |
| resource_handle_to_last_store_op; |
| |
| // Only iterate through ops directly in the block as we can't handle ops |
| // nested deeper in regions. |
| for (Operation& op : llvm::make_early_inc_range(*block)) { |
| if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) { |
| Value resource = read_variable_op.resource(); |
| auto last_store = resource_handle_to_last_store_op[resource]; |
| if (!last_store) continue; |
| |
| // Use stored value in last_store to replace all uses of current resource |
| // load's result, then erase this resource load. Add an intermediate |
| // CastOp if the shape of types doesn't exactly match. |
| Type read_type = read_variable_op.value().getType(); |
| if (read_type != last_store.value().getType()) { |
| OpBuilder builder(last_store); |
| builder.setInsertionPointAfter(last_store); |
| auto cast = builder.create<TF::CastOp>( |
| last_store.getLoc(), read_type, last_store.value(), |
| /*Truncate=*/builder.getBoolAttr(false)); |
| read_variable_op.value().replaceAllUsesWith(cast); |
| } else { |
| read_variable_op.value().replaceAllUsesWith(last_store.value()); |
| } |
| |
| read_variable_op.erase(); |
| continue; |
| } |
| |
| if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) { |
| Value resource = assign_variable_op.resource(); |
| auto last_store = resource_handle_to_last_store_op[resource]; |
| // Previous store ops to same resource can be erased. |
| if (last_store) last_store.erase(); |
| |
| resource_handle_to_last_store_op[resource] = assign_variable_op; |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RegionResourceHoister |
| //===----------------------------------------------------------------------===// |
| |
| // Helper class to hoist resource ops out of regions attached to an op. |
| class RegionResourceHoister { |
| public: |
| explicit RegionResourceHoister(Operation* op) : op_(op) {} |
| |
| // Analyzes attached regions to record resources read and written. |
| LogicalResult Analyze(); |
| |
| // Returns all resources accessed by the regions attached the op. |
| auto& GetResources() { return resources_; } |
| |
| // Returns if the given value is a resource that needs lifting. |
| bool Contains(Value resource) const { |
| return resources_.find(resource) != resources_.end(); |
| } |
| |
| // Drops the given resource from lifting. |
| void DropResource(Value resource) { |
| resources_.erase(resource); |
| written_resources_.remove(resource); |
| } |
| |
| // Replaces all resource loads in all regions attached to the op. |
| void ReplaceResourceLoads(bool read_only) { |
| llvm::for_each(op_->getRegions(), [&](Region& region) { |
| ReplaceResourceLoads(region, read_only); |
| }); |
| } |
| |
| static LogicalResult ReplaceOpWithNewOp(Operation* op); |
| |
| private: |
| // Returns if any resources need lifting. |
| bool NeedsLifting() const { return !resources_.empty(); } |
| |
| // Returns the number of results generated by the lifted op. |
| int GetLiftedNumResults() const { return num_new_results_; } |
| |
| // Generates hoisted reads for resources that need them before the op. |
| void GenerateHoistedReads(); |
| |
| // Replaces all resource loads in the given region with hoisted loads. If |
| // `read_only` is true, limit this replacement to read only resources. |
| void ReplaceResourceLoads(Region& region, bool read_only); |
| |
| // Appends final values writte to resources to the region returns for the |
| // given set of regions. |
| void AppendResourceStoreValueToReturn(RegionRange regions); |
| |
| // Performs the final replacement of the op. |
| void ReplaceOpWithNewOp(); |
| |
| // Returns is this resource was written to in any of the regions. |
| bool IsWritten(Value resource) const { |
| return written_resources_.contains(resource); |
| } |
| |
| static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op); |
| static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op); |
| |
| Operation* op_; |
| |
| // Per resource information about accesses to that resource. |
| struct ResourceInfo { |
| // Is this resource read in any of the regions? |
| bool is_read; |
| // Is this resource written in any of the regions? |
| bool is_written; |
| // Is this resource written in all of the regions? |
| bool is_written_all; |
| // The hoisted read used to replace region reads. |
| Value hoisted_read; |
| // the type of the data held by the resource. |
| Type data_type; |
| // For written resources, the result # of the lifted op which will hold the |
| // value of the resource. This result will be used to generates writes to |
| // the resource after the lifted op. |
| int result_index; |
| // Attributes on the read operation. |
| DictionaryAttr read_attrs; |
| // Attributes on the write operation. |
| DictionaryAttr write_attrs; |
| |
| ResourceInfo() |
| : is_read(false), |
| is_written(false), |
| is_written_all(false), |
| hoisted_read(nullptr), |
| data_type(nullptr), |
| result_index(-1) {} |
| |
| bool IsResultIndexAssigned() { return result_index != -1; } |
| |
| // Refine the resource type using the given type `type`. |
| void RefineType(Type type) { |
| if (!data_type) { |
| data_type = type; |
| } else { |
| data_type = TF::GetCastCompatibleType(data_type, type, |
| /*may_ignore_ref_type_a=*/false); |
| assert(data_type != nullptr && "Resource used with incompatible types"); |
| } |
| } |
| }; |
| llvm::MapVector<Value, ResourceInfo> resources_; |
| llvm::SetVector<Value> written_resources_; |
| // number of new results after lifting. |
| int num_new_results_; |
| }; |
| |
| // Analyzes resources that are read or written within attached regions. |
| LogicalResult RegionResourceHoister::Analyze() { |
| // Hoisting of child regions might have created opportunity for store-load |
| // forwarding. |
| for (Region& region : op_->getRegions()) { |
| ForwardStoreToLoad(®ion.front()); |
| } |
| |
| llvm::SetVector<Value> all_resources; |
| bool is_func = false; |
| // For functions, the resources to analyze are the function arguments. |
| // Otherwise, its the region captures. |
| if (func::FuncOp func = dyn_cast<func::FuncOp>(op_)) { |
| is_func = true; |
| Region& body = func.getBody(); |
| for (BlockArgument arg : body.getArguments()) { |
| if (IsResource(arg)) all_resources.insert(arg); |
| } |
| } else { |
| getUsedValuesDefinedAbove(op_->getRegions(), all_resources); |
| all_resources.remove_if([](Value value) { return !IsResource(value); }); |
| } |
| |
| num_new_results_ = op_->getNumResults(); |
| |
| for (auto resource : all_resources) { |
| ResourceInfo info; |
| info.data_type = GetResourceSubtype(resource); |
| llvm::BitVector written_regions(op_->getNumRegions()); |
| bool unsupported_use = false; |
| for (OpOperand& use : resource.getUses()) { |
| Operation* user = use.getOwner(); |
| // If the user is not in one of the regions, we are not interested in it. |
| // Since all the sub-regions within this region (i.e., regions attached to |
| // op's in this region) have themselves gone through lifting, all resource |
| // users are expected to be operations in this region and not embedded |
| // within other sub-regions attached to op's in this region. So the check |
| // for whether a user is in one of the regions attached to this op is |
| // straightforward. |
| if (user->getParentRegion()->getParentOp() != op_) continue; |
| |
| // For functions, if the resource is used as a return operand, use that |
| // as its result index. |
| if (is_func && isa<func::ReturnOp>(user)) { |
| assert(!info.IsResultIndexAssigned() && |
| "Expect resource argument to returned no more than once"); |
| info.result_index = use.getOperandNumber(); |
| continue; |
| } |
| |
| auto read = dyn_cast<TF::ReadVariableOp>(user); |
| auto write = dyn_cast<TF::AssignVariableOp>(user); |
| if (!read && !write) { |
| unsupported_use = true; |
| break; |
| } |
| |
| if (read && !info.is_read) { |
| info.is_read = true; |
| info.RefineType(read.value().getType()); |
| info.read_attrs = user->getAttrDictionary(); |
| } |
| |
| if (write) { |
| info.is_written = true; |
| info.RefineType(write.value().getType()); |
| info.write_attrs = user->getAttrDictionary(); |
| written_regions.set(user->getParentRegion()->getRegionNumber()); |
| } |
| } |
| |
| // If the resource is used in an op that we do not understand, skip |
| // lifting for that resource. |
| if (unsupported_use) continue; |
| |
| info.is_written_all = written_regions.count() == op_->getNumRegions(); |
| |
| // If the resource is written in some but not all regions, we would need |
| // a read for the value before these regions. Note that this is applicable |
| // only to multi-region ops: |
| // If/Case: If not all regions write to the resource, post hoisting the read |
| // value need to be routed through all paths that don't write. |
| // While: since while condition cannot write, any resource written in the |
| // while body will need to be read as well in case the while body is never |
| // executed. |
| // Both cases are handled by the condition below. |
| if (info.is_written && !info.is_written_all) info.is_read = true; |
| |
| // Allocate a result index for written resources that don't have one. |
| if (info.is_written) { |
| written_resources_.insert(resource); |
| if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++; |
| } |
| |
| resources_.insert({resource, info}); |
| } |
| return success(); |
| } |
| |
| // Generates hoisted reads for all resources that need them just before the op. |
| void RegionResourceHoister::GenerateHoistedReads() { |
| OpBuilder builder(op_); |
| DictionaryAttr empty_attrs = builder.getDictionaryAttr({}); |
| for (auto& resource_it : GetResources()) { |
| Value resource = resource_it.first; |
| auto& info = resource_it.second; |
| |
| if (info.is_read) { |
| Operation* read = builder.create<TF::ReadVariableOp>( |
| op_->getLoc(), info.data_type, resource); |
| read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs); |
| read->removeAttr(kDeviceAttr); |
| info.hoisted_read = read->getResult(0); |
| } |
| } |
| } |
| |
| // Replaces all resource reads with the hoisted read. |
| void RegionResourceHoister::ReplaceResourceLoads(Region& region, |
| bool read_only) { |
| assert(llvm::hasSingleElement(region) && "Expected single block region"); |
| // Only iterate through ops directly in the body as we can't handle |
| // ops nested deeper in regions. |
| auto all_reads = region.front().getOps<TF::ReadVariableOp>(); |
| for (auto read_op : llvm::make_early_inc_range(all_reads)) { |
| Value resource = read_op.resource(); |
| if (!Contains(resource)) continue; |
| |
| ResourceInfo& info = resources_[resource]; |
| // If replacing loads for read only resources, skip if the resource |
| // was written to. |
| if (read_only && info.is_written) continue; |
| |
| read_op.replaceAllUsesWith(info.hoisted_read); |
| read_op.erase(); |
| } |
| } |
| |
| // For written resources, add its value at the end of each region to that |
| // regions return value. For a region, its value at the end may be a value |
| // written to that resource in that region, or its hoisted read value if the |
| // resource is not written in that region. The return value can be vended out |
| // either as an existing return value, or a newly allocated return value. |
| void RegionResourceHoister::AppendResourceStoreValueToReturn( |
| RegionRange regions) { |
| for (Region* region : regions) { |
| assert(llvm::hasSingleElement(*region) && "Expected single block region"); |
| Block& front = region->front(); |
| auto old_return = front.getTerminator(); |
| assert(old_return->getNumOperands() == op_->getNumResults()); |
| auto new_return_operands = llvm::to_vector<4>(old_return->getOperands()); |
| new_return_operands.resize(num_new_results_); |
| |
| // initialize return values for written resources to be the hoisted reads. |
| for (Value resource : written_resources_) { |
| const ResourceInfo& info = resources_[resource]; |
| new_return_operands[info.result_index] = info.hoisted_read; |
| } |
| |
| // Only iterate through ops directly in the body as op's embedded in child |
| // regions should have been lifted out. |
| auto assign_ops = front.getOps<TF::AssignVariableOp>(); |
| for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { |
| Value resource = assign_variable_op.resource(); |
| if (!IsWritten(resource)) continue; |
| |
| // TODO(ycao): Prevent same value from being returned multiple times. |
| // TODO(ycao): Do not return resource store value if it is defined outside |
| // of cluster. Both of these can be post-resource-op-lifting cleanup |
| // passes. |
| int result_index = resources_[resource].result_index; |
| new_return_operands[result_index] = assign_variable_op.value(); |
| assign_variable_op.erase(); |
| } |
| old_return->setOperands(new_return_operands); |
| } |
| } |
| |
| // Replace the old op with a new op (with potentially additional results), and |
| // add stores to written resources after the new op. |
| void RegionResourceHoister::ReplaceOpWithNewOp() { |
| auto new_result_types = llvm::to_vector<4>(op_->getResultTypes()); |
| int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0; |
| Operation* terminator = op_->getRegion(result_region).front().getTerminator(); |
| auto extra_result_types = |
| terminator->getOperands().drop_front(op_->getNumResults()).getTypes(); |
| new_result_types.insert(new_result_types.end(), extra_result_types.begin(), |
| extra_result_types.end()); |
| OpBuilder builder(op_); |
| // Clone this old operation but with new result types. |
| Operation* new_op = Operation::create( |
| op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(), |
| op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions()); |
| builder.insert(new_op); |
| |
| // Move regions to the new op. |
| for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) { |
| Region& old_region = std::get<0>(it); |
| Region& new_region = std::get<1>(it); |
| new_region.takeBody(old_region); |
| } |
| |
| // Insert stores to all written resources. |
| for (Value resource : written_resources_) { |
| ResourceInfo& info = resources_[resource]; |
| Value value_to_write = new_op->getResult(info.result_index); |
| Operation* write = builder.create<TF::AssignVariableOp>( |
| op_->getLoc(), resource, value_to_write); |
| write->setAttrs(info.write_attrs); |
| write->removeAttr(kDeviceAttr); |
| } |
| |
| // As a part of lifting, we either reuse an existing slot for resource type |
| // results or add a new slot. Resource type results should not have any uses |
| // to begin with. So we can safely replace each old op result with the |
| // corresponding new op result. |
| int old_num_results = op_->getNumResults(); |
| op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results)); |
| op_->erase(); |
| op_ = nullptr; |
| } |
| |
| // Lift resource load and stores out of regions attached to `op`, where op is |
| // an If/case/cluster op. |
| LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster( |
| Operation* op) { |
| RegionResourceHoister hoister(op); |
| if (failed(hoister.Analyze())) return failure(); |
| |
| // If there are no resource region captures, then nothing to do. |
| if (!hoister.NeedsLifting()) return success(); |
| |
| // Start the transformation. For each region, replace the resource read with |
| // the value read before the op. |
| hoister.GenerateHoistedReads(); |
| hoister.ReplaceResourceLoads(/*read_only=*/false); |
| hoister.AppendResourceStoreValueToReturn(op->getRegions()); |
| hoister.ReplaceOpWithNewOp(); |
| return success(); |
| } |
| |
| // Lift resource loads and stores out of WhileRegion |
| LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion( |
| TF::WhileRegionOp op) { |
| // For WhileRegion, post canonicalization all resource used within the |
| // body and condition regions are replaced with captured values, so we do not |
| // need to take into account the body and condition region arguments. |
| RegionResourceHoister hoister(op); |
| |
| if (failed(hoister.Analyze())) return failure(); |
| |
| // If there are no resource region captures, then nothing to do. |
| if (!hoister.NeedsLifting()) return success(); |
| |
| // The resources captured for While loop fall into two categories: |
| // (a) read-only. These reads can be replaced by a hoisted read created |
| // before the WhileOp (similar to if and case). |
| // (b) written: since the value is written in the loop (which can only in |
| // loop body, all these will become loop variables. Since all resource |
| // variables are removed from the loop variabled during |
| // canonicalizationW, we need to create new operand/result slots. The |
| // input operands for these slots are the read values |
| // prior to the op, and all references to these are replaced by the |
| // corresponding slot argument. We need to generate writes following |
| // the while for these resources. |
| // |
| // Note that for WhileRegion ops, if a resource is written, it will be written |
| // only in the body and not the condition, so the hoister analysis will infer |
| // it as needing a read as well. |
| |
| // Generate hoisted reads before the while. |
| hoister.GenerateHoistedReads(); |
| |
| // Replace just the read-only resources with the hoisted reads. |
| hoister.ReplaceResourceLoads(/*read_only=*/true); |
| |
| // For written resources, add additional operands to the while op. |
| int num_old_results = op.getNumResults(); |
| int num_new_results = hoister.GetLiftedNumResults(); |
| int num_extra_results = num_new_results - num_old_results; |
| |
| SmallVector<Type, 4> new_result_types; |
| SmallVector<Value, 4> new_while_operands; |
| new_result_types.resize(num_extra_results); |
| new_while_operands.resize(num_extra_results); |
| |
| for (auto& it : hoister.GetResources()) { |
| if (!it.second.is_written) continue; |
| int index = it.second.result_index - num_old_results; |
| new_result_types[index] = it.second.data_type; |
| new_while_operands[index] = it.second.hoisted_read; |
| } |
| op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands); |
| |
| // Patch the cond and body regions to have additional arguments, and replace |
| // the remaining resource reads (which will be resource reads for written |
| // resources) with these arguments. |
| Location loc = op.getLoc(); |
| for (Region* region : op.getRegions()) { |
| region->addArguments(new_result_types, |
| SmallVector<Location>(new_result_types.size(), loc)); |
| // Point hoisted read for written resources to the region's arguments. |
| for (auto& it : hoister.GetResources()) { |
| if (!it.second.is_written) continue; |
| it.second.hoisted_read = region->getArgument(it.second.result_index); |
| } |
| hoister.ReplaceResourceLoads(*region, /*read_only=*/false); |
| } |
| |
| // Add additional return values to body return. These correspond to values |
| // written to resources in the body region. |
| hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front()); |
| |
| // Finally, create a new while with additional return values. |
| hoister.ReplaceOpWithNewOp(); |
| return success(); |
| } |
| |
| // Lift resources out of the regions attached to `op` |
| LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) { |
| if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) |
| return HoistResourcesOutOfWhileRegion(while_op); |
| return HoistResourcesOutOfIfCaseCluster(op); |
| } |
| |
| // Holds information about a function's use of a resource argument. |
| struct ResourceArgUseInfo { |
| // Data type of the data contained in the resource. |
| Type data_type; |
| // Is the resource argument used in an assign op? |
| bool updated; |
| // Is the resource argument used in a read or assign op? |
| bool used; |
| }; |
| |
| // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the |
| // output (i.e., the argument is an operand of the return op) is not considered |
| // as a use. This doesn't support nesting of ops, so before calling this, nested |
| // ops/functions need to be already resource-lifted. |
| LogicalResult FindResourceArgUseInfo( |
| func::FuncOp func_op, |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) { |
| auto return_op = func_op.front().getTerminator(); |
| for (auto arg : TF::filter_resources(func_op.getArguments())) { |
| ResourceArgUseInfo info; |
| info.used = false; |
| info.updated = false; |
| bool read_or_assigned = false; |
| bool used_in_unsupported_op = false; |
| for (auto user : arg.getUsers()) { |
| if (user == return_op) continue; |
| info.used = true; |
| if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) { |
| read_or_assigned = true; |
| info.data_type = read.getType(); |
| continue; |
| } |
| |
| if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) { |
| read_or_assigned = true; |
| info.updated = true; |
| info.data_type = assign.value().getType(); |
| continue; |
| } |
| |
| used_in_unsupported_op = true; |
| break; |
| } |
| |
| // If the arg is used in an unsupported op, skip lifting it. |
| if (used_in_unsupported_op) continue; |
| (*result)[arg.getArgNumber()] = info; |
| } |
| return success(); |
| } |
| |
| // Merges two sets of resource arg use infos. An argument is considered used in |
| // the merged result as long as either set marks it as used. This is used to |
| // merge results from functions that have aliasing inputs, e.g., a while loop's |
| // body and condition. The sets of keys of the two maps must be the same. |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo( |
| const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0, |
| const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) { |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result; |
| for (const auto& entry : infos0) { |
| auto info1_it = infos1.find(entry.getFirst()); |
| // If the entry is missing in any input, we should not touch this entry. |
| if (info1_it == infos1.end()) continue; |
| auto& info = result[entry.getFirst()]; |
| info = entry.getSecond(); |
| if (info.updated) continue; |
| if (info1_it->getSecond().used) { |
| info.used = true; |
| info.updated = info1_it->getSecond().updated; |
| info.data_type = info1_it->getSecond().data_type; |
| } |
| } |
| return result; |
| } |
| |
| // Removes the unused resource arguments, and the return values that forward the |
| // removed arguments. If old_to_new_arg_indices is provided, it will store the |
| // new argument index that corresponds to each original index (-1 means it is |
| // removed). If remaining_resource_data_types is provided, it will store the |
| // data types of the remaining resource arguments, where the indices are after |
| // removing unused ones. |
| void RemoveUnusedResourceArgumentsAndForwardedRetvals( |
| const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos, |
| func::FuncOp func_op, |
| llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr, |
| llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types = |
| nullptr) { |
| // Remove return values forwarded from unused arguments. |
| auto return_op = func_op.front().getTerminator(); |
| auto old_return_vals = llvm::to_vector<8>(return_op->getOperands()); |
| int64_t skipped_retvals = 0; |
| for (auto entry : llvm::enumerate(old_return_vals)) { |
| auto return_val = entry.value(); |
| if (auto arg = return_val.dyn_cast<BlockArgument>()) { |
| auto it = infos.find(arg.getArgNumber()); |
| if (it != infos.end() && !it->getSecond().used) { |
| return_op->eraseOperand(entry.index() - skipped_retvals++); |
| } |
| } |
| } |
| llvm::BitVector indices_to_erase(func_op.getNumArguments()); |
| llvm::SmallVector<Type, 4> new_types; |
| int64_t skipped_args = 0; |
| for (auto arg : func_op.getArguments()) { |
| auto it = infos.find(arg.getArgNumber()); |
| if (it != infos.end() && !it->getSecond().used) { |
| indices_to_erase.set(arg.getArgNumber()); |
| skipped_args++; |
| if (old_to_new_arg_indices != nullptr) { |
| old_to_new_arg_indices->push_back(-1); |
| } |
| } else { |
| new_types.push_back(arg.getType()); |
| if (old_to_new_arg_indices != nullptr) { |
| old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args); |
| } |
| if (it != infos.end() && remaining_resource_data_types != nullptr) { |
| (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] = |
| it->second.data_type; |
| } |
| } |
| } |
| func_op.eraseArguments(indices_to_erase); |
| func_op.setType( |
| FunctionType::get(func_op.getContext(), new_types, |
| llvm::to_vector<4>(return_op->getOperandTypes()))); |
| } |
| |
| // Lifts reads/writes of resource arguments from func_op and changes its |
| // signature. resource_data_types is the (index, data type) pair for each |
| // resource argument. handle_updated_arg_value is a caller-provided function |
| // that handles the updated value for an resource argument. |
| LogicalResult LiftArgRetResourcesForFunction( |
| func::FuncOp func_op, |
| const llvm::SmallDenseMap<int64_t, Type>& resource_data_types, |
| llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) { |
| RegionResourceHoister hoister(func_op); |
| if (failed(hoister.Analyze())) return failure(); |
| |
| // Each of these resources could be read or written in the function. If its |
| // read, we need to replace the resource arg with a value arg to get the |
| // read value. If its written, we need to replace the write with an additional |
| // value to be written. |
| |
| // Now create read values that will be used to replace each resource that |
| // is read in the function body. These read values are just the same argument |
| // with type replaced. |
| llvm::SmallVector<Value, 4> skipped_args; |
| for (auto& it : hoister.GetResources()) { |
| BlockArgument arg = it.first.dyn_cast<BlockArgument>(); |
| assert(arg && "Expect resources for FuncOp to be its arguments"); |
| auto type_iter = resource_data_types.find(arg.getArgNumber()); |
| if (type_iter == resource_data_types.end()) { |
| // Skip lifting the resource if it's not present in the data type map. |
| // This indicates that the resource is not to be lifted because it is used |
| // in an unsupported op in some other function. |
| skipped_args.push_back(arg); |
| } else { |
| arg.setType(type_iter->second); |
| it.second.hoisted_read = arg; |
| } |
| } |
| |
| // Drop all the args that have to be skipped. |
| for (Value arg : skipped_args) hoister.DropResource(arg); |
| |
| hoister.ReplaceResourceLoads(/*read_only=*/false); |
| |
| // For writes, invoke the callback and then erase the write. |
| auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>(); |
| for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { |
| Value resource = assign_variable_op.resource(); |
| if (!hoister.Contains(resource)) continue; |
| |
| auto arg = resource.dyn_cast<BlockArgument>(); |
| handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value()); |
| assign_variable_op.erase(); |
| } |
| |
| func_op.setType(FunctionType::get( |
| func_op.getContext(), func_op.front().getArgumentTypes(), |
| func_op.front().getTerminator()->getOperandTypes())); |
| |
| return success(); |
| } |
| |
| // Returns a vector filtered from range where the unused elements (specified by |
| // resource_arg_uses) are removed. |
| template <typename T, typename Range> |
| llvm::SmallVector<T, 4> FilterRange( |
| Range range, |
| const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) { |
| llvm::SmallVector<T, 4> filtered; |
| for (auto entry : llvm::enumerate(range)) { |
| auto it = resource_arg_uses.find(entry.index()); |
| if (it == resource_arg_uses.end() || it->getSecond().used) |
| filtered.push_back(entry.value()); |
| } |
| return filtered; |
| } |
| |
| // Changes the types of the control flow op (e.g., while, if) and adds loads and |
| // stores around it. arg_data_type_and_updated_output_index maps an operand (to |
| // be changed) index to its data type and the updated value index in the output |
| // (-1 means not updated.) |
| void AddLoadsStoresOutsideControlFlowOp( |
| Operation* caller, |
| const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>& |
| arg_data_type_and_updated_output_index) { |
| OpBuilder builder(caller); |
| auto new_operands = llvm::to_vector<8>(caller->getOperands()); |
| llvm::SmallVector<int64_t, 8> changed_indices; |
| // Find the operands to change, and create the loads. |
| for (auto& entry : arg_data_type_and_updated_output_index) { |
| int64_t index = entry.getFirst(); |
| Type new_type = entry.getSecond().first; |
| int64_t updated_index = entry.getSecond().second; |
| auto operand = caller->getOperand(index); |
| builder.setInsertionPoint(caller); |
| new_operands[index] = builder.create<TF::ReadVariableOp>( |
| caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand}); |
| caller->setOperand(index, new_operands[index]); |
| if (updated_index < 0) continue; |
| builder.setInsertionPointAfter(caller); |
| builder.create<TF::AssignVariableOp>( |
| caller->getLoc(), ArrayRef<Type>{}, |
| ArrayRef<Value>{operand, caller->getResult(updated_index)}); |
| } |
| } |
| |
| // Lifts loads/stores from while loop's body and cond functions. |
| LogicalResult HandleWhileLoop(TF::WhileOp while_op, func::FuncOp body, |
| func::FuncOp cond) { |
| auto return_op = body.front().getTerminator(); |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info; |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info; |
| if (failed(FindResourceArgUseInfo(body, &body_use_info)) || |
| failed(FindResourceArgUseInfo(cond, &cond_use_info))) { |
| return failure(); |
| } |
| // A resource is considered used as long as it is used in either body or cond. |
| auto resource_arg_uses = |
| MergeArgResourceUseInfo(body_use_info, cond_use_info); |
| if (resource_arg_uses.empty()) return success(); |
| |
| // Remove unused resources in functions. |
| llvm::SmallVector<int64_t, 4> old_to_new_indices; |
| llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types; |
| RemoveUnusedResourceArgumentsAndForwardedRetvals( |
| resource_arg_uses, body, &old_to_new_indices, |
| &remaining_resource_data_types); |
| RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond); |
| (void)LiftArgRetResourcesForFunction( |
| body, remaining_resource_data_types, |
| [&](int64_t index, Value value) { return_op->setOperand(index, value); }); |
| (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types, |
| [&](int64_t index, Value value) { |
| // We already checked that cond should |
| // not have variable writes. |
| assert(false && "Should not happen"); |
| }); |
| // Recreate the while op. |
| OpBuilder builder(while_op); |
| // Now use the filtered original operands, which will be replaced by |
| // AddLoadsStoresOutsideControlFlowOp(). |
| auto new_while = builder.create<TF::WhileOp>( |
| while_op.getLoc(), body.getFunctionType().getResults(), |
| FilterRange<Value, OperandRange>(while_op.getOperands(), |
| resource_arg_uses), |
| while_op->getAttrs()); |
| // Prepare for AddLoadsStoresOutsideControlFlowOp(). |
| llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>> |
| arg_data_type_and_updated_output_index; |
| for (const auto& entry : remaining_resource_data_types) { |
| int64_t update_index = return_op->getOperand(entry.getFirst()) == |
| body.getArgument(entry.getFirst()) |
| ? -1 |
| : entry.getFirst(); |
| arg_data_type_and_updated_output_index[entry.getFirst()] = { |
| entry.getSecond(), update_index}; |
| } |
| AddLoadsStoresOutsideControlFlowOp(new_while, |
| arg_data_type_and_updated_output_index); |
| // Replace uses. |
| for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) { |
| if (old_to_new_indices[i] >= 0) { |
| while_op.getResult(i).replaceAllUsesWith( |
| new_while.getResult(old_to_new_indices[i])); |
| } |
| } |
| while_op.erase(); |
| return success(); |
| } |
| |
| // Lifts loads/stores from an IfOp or CaseOp's branches. |
| template <class CaseOrIfOp> |
| LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<func::FuncOp> branches) { |
| // For canonicalized If/Case, there should not be any resource outputs |
| int64_t non_resource_results = op.getNumResults(); |
| |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses; |
| if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses))) |
| return failure(); |
| |
| for (auto func : branches.drop_front()) { |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info; |
| if (failed(FindResourceArgUseInfo(func, &branch_use_info))) |
| return failure(); |
| // A resource is considered used as long as it is used in either branch. |
| resource_arg_uses = |
| MergeArgResourceUseInfo(resource_arg_uses, branch_use_info); |
| } |
| |
| if (resource_arg_uses.empty()) return success(); |
| // Remove unused resources in functions. |
| llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types; |
| RemoveUnusedResourceArgumentsAndForwardedRetvals( |
| resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr, |
| &remaining_resource_data_types); |
| for (auto func : branches.drop_front()) |
| RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func); |
| |
| // Forward resource inputs updated in any branch to the outputs of both |
| // branches. First prepare the mapping from arg to new update output. |
| llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output; |
| { |
| int64_t removed_args = 0; |
| for (const auto& entry : resource_arg_uses) { |
| if (!entry.getSecond().used) { |
| removed_args++; |
| continue; |
| } |
| if (!entry.getSecond().updated) continue; |
| int64_t new_output_index = |
| non_resource_results + resource_arg_to_new_output.size(); |
| resource_arg_to_new_output[entry.getFirst() - removed_args] = |
| new_output_index; |
| } |
| } |
| |
| // Append resource updates to the return ops: now they are just forwarded |
| // input resources, but will be replaced by the data value in |
| // LiftArgRetResourcesForFunction(). |
| for (auto branch : branches) { |
| auto new_retvals = |
| llvm::to_vector<4>(branch.front().getTerminator()->getOperands()); |
| new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size()); |
| for (const auto& entry : resource_arg_to_new_output) { |
| int64_t resource_arg_index = entry.getFirst(); |
| int64_t output_index = entry.getSecond(); |
| new_retvals[output_index] = branch.getArgument(resource_arg_index); |
| } |
| auto old_return = branch.front().getTerminator(); |
| OpBuilder builder(old_return); |
| auto new_return = |
| builder.create<func::ReturnOp>(old_return->getLoc(), new_retvals); |
| old_return->erase(); |
| (void)LiftArgRetResourcesForFunction( |
| branch, remaining_resource_data_types, [&](int64_t index, Value value) { |
| new_return.setOperand(resource_arg_to_new_output[index], value); |
| }); |
| } |
| |
| // Recreate the op without resource operands. |
| OpBuilder builder(op); |
| // Now use the filtered original operands, which will be replaced by |
| // AddLoadsStoresOutsideControlFlowOp(). |
| auto new_operands = |
| FilterRange<Value, OperandRange>(op.input(), resource_arg_uses); |
| new_operands.insert(new_operands.begin(), op.getOperand(0)); |
| func::FuncOp first_func = branches.front(); |
| auto new_op = builder.create<CaseOrIfOp>( |
| op.getLoc(), first_func.getFunctionType().getResults(), new_operands, |
| op->getAttrs()); |
| // Prepare for AddLoadsStoresOutsideControlFlowOp() |
| llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>> |
| arg_data_type_and_updated_output_index; |
| for (const auto& entry : remaining_resource_data_types) { |
| auto new_output_it = resource_arg_to_new_output.find(entry.getFirst()); |
| int64_t update_index = new_output_it == resource_arg_to_new_output.end() |
| ? -1 |
| : new_output_it->getSecond(); |
| arg_data_type_and_updated_output_index[entry.getFirst() + 1] = { |
| entry.getSecond(), update_index}; |
| } |
| AddLoadsStoresOutsideControlFlowOp(new_op, |
| arg_data_type_and_updated_output_index); |
| // Replace uses. |
| op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults())); |
| op.erase(); |
| return success(); |
| } |
| |
| // A resource-lifted function for (potentially multiple) PartitionedCallOps and |
| // information about the lifting changes. |
| struct PartitionedCallLiftingInfo { |
| // Function with resources lifted. Can be nullptr if nothing needs to change. |
| func::FuncOp lifted_callee; |
| // Mapping from old resource outputs to their aliasing output inputs. |
| llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs; |
| // Mapping from old to new output indices in case any output is removed. |
| llvm::SmallVector<int64_t, 4> old_to_new_output_indices; |
| // ResourceArgUseInfo for each old resource argument. |
| llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info; |
| // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment. |
| llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>> |
| arg_data_type_and_updated_output_index; |
| }; |
| |
| // Lifts loads/stores from a PartitionedCallOp's callee function. If anything |
| // needs to be changed, the original function will be preserved, and the lifting |
| // happens on a clone, which will be stored in `result`. |
| LogicalResult HandlePartitionedCallOpCallee( |
| func::FuncOp callee, PartitionedCallLiftingInfo* result) { |
| // Sanity check: return of resources should be aliases of inputs. Such outputs |
| // will be removed later. |
| int64_t non_resource_results = 0; |
| for (auto entry : |
| llvm::enumerate(callee.front().getTerminator()->getOperands())) { |
| auto retval = entry.value(); |
| if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) { |
| result->old_to_new_output_indices.push_back(non_resource_results++); |
| continue; |
| } |
| auto aliasing_arg = retval.dyn_cast<BlockArgument>(); |
| if (!aliasing_arg) { |
| return callee.emitOpError("unsupported function call: ") |
| << "resource return value does not alias an input."; |
| } |
| result->old_outputs_aliasing_old_inputs[entry.index()] = |
| aliasing_arg.getArgNumber(); |
| result->old_to_new_output_indices.push_back(-1); |
| } |
| |
| if (failed(FindResourceArgUseInfo(callee, &result->use_info))) { |
| return failure(); |
| } |
| if (result->use_info.empty()) { |
| result->lifted_callee = nullptr; |
| return success(); |
| } |
| |
| // Clone the callee before making changes. |
| SmallString<64> name_base = callee.getName(); |
| auto module = callee->getParentOfType<ModuleOp>(); |
| name_base += "_resource_lifted"; |
| auto name = name_base; |
| callee = callee.clone(); |
| callee.setPrivate(); |
| callee.setName(mlir::StringAttr::get(callee->getContext(), name)); |
| SymbolTable(module).insert(callee); |
| result->lifted_callee = callee; |
| |
| // Remove unused resources in functions. |
| llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types; |
| RemoveUnusedResourceArgumentsAndForwardedRetvals( |
| result->use_info, callee, /*old_to_new_arg_indices=*/nullptr, |
| &remaining_resource_data_types); |
| for (const auto& entry : remaining_resource_data_types) { |
| result->arg_data_type_and_updated_output_index[entry.getFirst()] = { |
| entry.getSecond(), -1}; |
| } |
| llvm::SmallVector<int64_t, 4> retval_indices_to_preserve; |
| for (auto& val : callee.front().getTerminator()->getOpOperands()) { |
| // Store indices of results that are not resources. |
| if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>()) |
| retval_indices_to_preserve.push_back(val.getOperandNumber()); |
| } |
| int64_t num_retvals = retval_indices_to_preserve.size(); |
| llvm::SmallVector<Value, 4> new_retvals; |
| // Lift resources. |
| (void)LiftArgRetResourcesForFunction( |
| callee, remaining_resource_data_types, [&](int64_t index, Value value) { |
| result->arg_data_type_and_updated_output_index[index].second = |
| num_retvals++; |
| new_retvals.push_back(value); |
| }); |
| |
| auto old_return = callee.front().getTerminator(); |
| llvm::SmallVector<Value, 4> old_and_new_retvals; |
| old_and_new_retvals.reserve(retval_indices_to_preserve.size() + |
| new_retvals.size()); |
| for (int64_t retval_index : retval_indices_to_preserve) |
| old_and_new_retvals.push_back(old_return->getOperand(retval_index)); |
| |
| old_and_new_retvals.append(new_retvals.begin(), new_retvals.end()); |
| // Replace old return with the new ones with update values. |
| OpBuilder builder(old_return); |
| auto new_return = |
| builder.create<func::ReturnOp>(old_return->getLoc(), old_and_new_retvals); |
| old_return->erase(); |
| callee.setType(FunctionType::get( |
| callee.getContext(), callee.getFunctionType().getInputs(), |
| llvm::to_vector<4>(new_return.getOperandTypes()))); |
| return success(); |
| } |
| |
| // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the |
| // resource-lifted new callee function in lifting_info. |
| template <typename CallOpType> |
| void UpdatePartitionedCallOpWithNewCallee( |
| CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) { |
| if (!lifting_info.lifted_callee) return; |
| // Replace output resource uses with the aliasing input, so that we can remove |
| // this output. |
| for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) { |
| call_op.getResult(entry.getFirst()) |
| .replaceAllUsesWith(call_op.getOperand(entry.getSecond())); |
| } |
| // Recreate the call op. |
| OpBuilder builder(call_op); |
| // Now use the filtered original operands, which will be replaced by |
| // AddLoadsStoresOutsideControlFlowOp(). |
| auto new_operands = |
| FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info); |
| auto new_call = builder.create<CallOpType>( |
| call_op.getLoc(), |
| lifting_info.lifted_callee.getFunctionType().getResults(), new_operands, |
| call_op->getAttrs()); |
| new_call->setAttr("f", |
| SymbolRefAttr::get(builder.getContext(), |
| lifting_info.lifted_callee.getName())); |
| AddLoadsStoresOutsideControlFlowOp( |
| new_call, lifting_info.arg_data_type_and_updated_output_index); |
| // Replace uses. |
| for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size(); |
| i < end; ++i) { |
| if (lifting_info.old_to_new_output_indices[i] >= 0) { |
| call_op.getResult(i).replaceAllUsesWith( |
| new_call.getResult(lifting_info.old_to_new_output_indices[i])); |
| } |
| } |
| call_op.erase(); |
| } |
| |
| LogicalResult HoistForControlFlow( |
| Block*, ModuleOp, bool, |
| llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*); |
| |
| // A templated routine for handling both PartitionedCallOp and |
| // StatefulPartitionedCallOp. If the callee is already lifted, it just updates |
| // the caller op itself; otherwise, it first recursively handles nested control |
| // flow, then performs lifting on the callee. |
| template <typename CallOpType> |
| LogicalResult HandlePartitionedCallOp( |
| CallOpType call_op, func::FuncOp callee, ModuleOp module, |
| bool vars_initialized, |
| llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>* |
| lifted_callees) { |
| auto emplace_res = lifted_callees->try_emplace(callee.getName(), |
| PartitionedCallLiftingInfo()); |
| if (emplace_res.second) { |
| // Unseen callee. Perform resource lifting on it. |
| if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized, |
| lifted_callees))) |
| return failure(); |
| |
| if (failed(HandlePartitionedCallOpCallee( |
| callee, &emplace_res.first->getSecond()))) { |
| return failure(); |
| } |
| } |
| UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond()); |
| return success(); |
| } |
| |
| // Hoists resource loads/stores from control flow ops in `block` outside the |
| // body/cond/branch/callee functions. |
| LogicalResult HoistForControlFlow( |
| Block* block, ModuleOp module, bool vars_initialized, |
| llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>* |
| lifted_partitioned_call_callees) { |
| if (vars_initialized) SetAllVarIsInitializedToTrue(block); |
| |
| for (Operation& op : llvm::make_early_inc_range(*block)) { |
| if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) { |
| auto body = while_op.body_function(); |
| auto cond = while_op.cond_function(); |
| // Recursively handle the nested control flow. |
| (void)HoistForControlFlow(&body.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| (void)HoistForControlFlow(&cond.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); |
| } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) { |
| auto then_branch = if_op.then_function(); |
| auto else_branch = if_op.else_function(); |
| // Recursively handle the nested control flow. |
| (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}))) |
| return failure(); |
| } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) { |
| SmallVector<func::FuncOp, 4> branch_functions; |
| case_op.get_branch_functions(branch_functions); |
| for (func::FuncOp func : branch_functions) { |
| // Recursively handle the nested control flow. |
| (void)HoistForControlFlow(&func.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| } |
| if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); |
| } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) { |
| auto callee = call_op.func(); |
| if (!callee) { |
| return call_op.emitOpError( |
| "resource lifting does not support call with nested references."); |
| } |
| if (failed(HandlePartitionedCallOp(call_op, callee, module, |
| vars_initialized, |
| lifted_partitioned_call_callees))) { |
| // Nested control flow handling is done in HandlePartitionedCallOp(). |
| return failure(); |
| } |
| } else if (auto call_op = |
| llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) { |
| if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module, |
| vars_initialized, |
| lifted_partitioned_call_callees))) { |
| return failure(); |
| } |
| } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) { |
| for (Region& region : op.getRegions()) |
| (void)HoistForControlFlow(®ion.front(), module, vars_initialized, |
| lifted_partitioned_call_callees); |
| LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op); |
| if (failed(result)) return failure(); |
| } |
| } |
| |
| // After we have hoisted operations in the block, we may have added new read |
| // and writes of resources to this block. Clean them up by doing store-load |
| // forwarding. |
| ForwardStoreToLoad(block); |
| return success(); |
| } |
| |
| // Lifts resource operation from tf_device.cluster ops nested in `op` outside. |
| // Returns failure if there are remaining resource-type values that can not be |
| // lifted. |
| void ResourceOpLiftingPass::runOnOperation() { |
| llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo> |
| lifted_partitioned_call_callees; |
| ModuleOp module = getOperation(); |
| |
| if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module))) |
| return signalPassFailure(); |
| |
| auto walk_result = module.walk([&](func::FuncOp func_op) { |
| return func_op.walk([&](tf_device::ClusterOp cluster) { |
| LogicalResult result = HoistForControlFlow( |
| &cluster.GetBody(), module, /*vars_initialized=*/true, |
| &lifted_partitioned_call_callees); |
| if (failed(result)) return WalkResult::interrupt(); |
| result = RegionResourceHoister::ReplaceOpWithNewOp(cluster); |
| if (failed(result)) return WalkResult::interrupt(); |
| return WalkResult::advance(); |
| }); |
| }); |
| |
| if (walk_result.wasInterrupted()) return signalPassFailure(); |
| } |
| |
| struct ResourceOpLiftingForMainFunctionPass |
| : public TFDevice::ResourceOpLiftingForMainFunctionPassBase< |
| ResourceOpLiftingForMainFunctionPass> { |
| void runOnOperation() override; |
| }; |
| |
| void ResourceOpLiftingForMainFunctionPass::runOnOperation() { |
| ModuleOp module = getOperation(); |
| func::FuncOp main_func = module.lookupSymbol<func::FuncOp>("main"); |
| if (!main_func) { |
| return; |
| } |
| |
| if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| } // namespace |
| |
| namespace TFDevice { |
| std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() { |
| return std::make_unique<ResourceOpLiftingPass>(); |
| } |
| |
| std::unique_ptr<OperationPass<ModuleOp>> |
| CreateResourceOpLiftingForMainFunctionPass() { |
| return std::make_unique<ResourceOpLiftingForMainFunctionPass>(); |
| } |
| |
| } // namespace TFDevice |
| |
| namespace TF { |
| LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function) { |
| // This routine should only be called when control flow operations are still |
| // represented with TF IfOp and WhileOp operations. In this case, there should |
| // be only one basic blocks in the MLIR representation. |
| if (!llvm::hasSingleElement(function)) { |
| return function.emitError() |
| << "expect the function to have 1 block while it has " |
| << function.getBlocks().size(); |
| } |
| |
| if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function))) |
| return failure(); |
| |
| llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo> |
| lifted_partitioned_call_callees; |
| if (failed(HoistForControlFlow( |
| &function.front(), cast<ModuleOp>(function->getParentOp()), |
| /*vars_initialized=*/false, &lifted_partitioned_call_callees))) |
| return failure(); |
| |
| // Clean up and canonicalize to remove dead local variables as some local |
| // variables might be dead after hoisting resource loads/stores from control |
| // flow ops. |
| return TF::CleanupAndCanonicalizeForResourceOpLifting(function); |
| } |
| } // namespace TF |
| |
| } // namespace mlir |