// blob: 379293f2f68ef8f132b3016e2949530d92cc5d77 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iterator>
#include <memory>
#include <tuple>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#define DEBUG_TYPE "tf-resource-device-inference"
namespace mlir {
namespace TF {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";
// Module pass that infers and propagates the device placement of resources.
//
// Propagation happens in two ways: within a single function, and across
// function boundaries from a caller's operands to the callee's arguments.
//
// The module is mutated in place: resource function arguments receive a
// "tf.device" argument attribute and TF ops receive a "device" attribute.
struct ResourceDeviceInference
    : ResourceDeviceInferencePassBase<ResourceDeviceInference> {
  // Pass entry point; defined out of line below.
  void runOnOperation() override;
};
// Per-function bookkeeping of resource device assignments.
//
// Devices are stored per resource unique ID, as reported by the function's
// resource alias analysis info. The analysis info is held by reference, so it
// must outlive this object.
class PerFunctionResult {
 public:
  // `func_op` is accepted for interface purposes but not used; only
  // `alias_analysis` is retained.
  explicit PerFunctionResult(
      func::FuncOp func_op,
      const TF::ResourceAliasAnalysis::Info& alias_analysis)
      : alias_analysis_(alias_analysis) {}

  // Looks up the device recorded for `resource`.
  //
  // Returns llvm::None when the resource is unknown to the alias analysis,
  // when none of its unique IDs has a recorded device, or when its IDs map to
  // differing devices. Otherwise returns the single agreed-upon device.
  Optional<StringRef> DeviceForResource(Value resource) const {
    if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;

    Optional<StringRef> device;
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto entry = resource_id_to_device_.find(id);
      if (entry == resource_id_to_device_.end()) continue;

      const StringRef candidate = entry->getSecond();
      // A second, different device means the assignments conflict.
      if (device && *device != candidate) return llvm::None;
      device = candidate;
    }
    return device;
  }

  // Records `device` for every unique ID of `resource`.
  //
  // Unknown resources are ignored and reported as success. Returns failure as
  // soon as an ID is found that already maps to a different device; IDs
  // visited before that point keep whatever was recorded for them.
  //
  // When `changed` is non-null, `*changed` is set to true if at least one new
  // entry was inserted. It is never reset to false here.
  LogicalResult AddResourceDevice(Value resource, StringRef device,
                                  bool* changed = nullptr) {
    if (alias_analysis_.IsUnknownResource(resource)) return success();

    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto insertion = resource_id_to_device_.try_emplace(id, device);
      if (!insertion.second) {
        // The ID was already recorded: it must agree with the new assignment.
        if (insertion.first->getSecond() != device) return failure();
        continue;
      }
      if (changed) *changed = true;
    }
    return success();
  }

 private:
  // Maps a resource unique ID to the device recorded for it.
  llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
  // Alias analysis info for the function; not owned.
  const TF::ResourceAliasAnalysis::Info& alias_analysis_;
};
// Records `device` for `resource` in `result`, forwarding `changed` to
// PerFunctionResult::AddResourceDevice. If the recording fails because of a
// conflicting assignment, an error is emitted on `error_reporting_op` and
// failure is returned; otherwise returns success.
LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
                                            Operation* error_reporting_op,
                                            PerFunctionResult* result,
                                            bool* changed = nullptr) {
  if (succeeded(result->AddResourceDevice(resource, device, changed)))
    return success();

  error_reporting_op->emitError()
      << "Conflicting device assignment for resource";
  return failure();
}
// Returns the string value of the "tf.device" attribute on argument `arg_no`
// of `func`, or an empty string if the argument has no such string attribute.
inline StringRef GetDeviceAttr(func::FuncOp func, int arg_no) {
  if (auto attr =
          func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr))
    return attr.getValue();
  return "";
}
// Returns the string value of the "device" attribute on `op`, or an empty
// string if `op` has no such string attribute.
inline StringRef GetDeviceAttr(Operation* op) {
  if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr))
    return attr.getValue();
  return "";
}
// Writes `message` followed by `op` to the LLVM debug stream, terminated by a
// newline. The op is printed with debug info enabled so that location
// information shows up in the output.
void dump(StringRef message, Operation* op) {
  auto& os = llvm::dbgs();
  os << message;
  op->print(os, OpPrintingFlags().enableDebugInfo(true));
  os << "\n";
}
// Runs device propagation within a single function.
//
// Reads existing device annotations (function argument "tf.device" attributes
// and VarHandleOp "device" attributes) into `result`, and writes inferred
// devices back onto un-annotated resource arguments and IdentityOps.
// WhileRegion results with a known device have that device recorded for the
// matching arguments of each attached region.
//
// Returns failure if a conflicting device assignment is found; the error has
// already been emitted on the offending op in that case.
LogicalResult ComputeResourceDevicesInComputation(func::FuncOp func_op,
                                                  PerFunctionResult* result) {
  OpBuilder builder(func_op);

  // Step 1: resource-typed function arguments.
  for (auto arg : filter_resources(func_op.getArguments())) {
    const unsigned arg_index = arg.getArgNumber();
    const StringRef annotated_device = GetDeviceAttr(func_op, arg_index);

    if (!annotated_device.empty()) {
      // The argument carries an explicit annotation: record it.
      if (failed(AddResourceDeviceAndEmitError(arg, annotated_device, func_op,
                                               result)))
        return failure();
      continue;
    }

    // No annotation yet: materialize one from a previously recorded
    // assignment, if there is any.
    if (auto inferred_device = result->DeviceForResource(arg)) {
      func_op.setArgAttr(arg_index, kFuncDeviceAttr,
                         builder.getStringAttr(*inferred_device));
    }
  }

  // Step 2: ops in the body. A pre-order walk is required so that a
  // WhileRegion's device annotations reach its cond/body region arguments
  // *before* the ops inside those regions are visited.
  auto visit_op = [&](Operation* op) {
    // VarHandleOp: its device attribute is a source of assignments.
    if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
      const StringRef device = GetDeviceAttr(op);
      if (device.empty()) return WalkResult::advance();
      if (failed(AddResourceDeviceAndEmitError(var_handle.resource(), device,
                                               op, result)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    }

    // IdentityOp: a sink; gets its device attribute from recorded assignments
    // unless it already has one.
    if (auto identity = dyn_cast<IdentityOp>(op)) {
      LLVM_DEBUG(dump("Visiting ", identity));
      if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
      for (auto output : filter_resources(op->getResults())) {
        LLVM_DEBUG(llvm::dbgs() << " Processing output #"
                                << output.getResultNumber() << "\n");
        auto device = result->DeviceForResource(output);
        if (!device) continue;
        LLVM_DEBUG(llvm::dbgs() << " Setting device = " << *device << "\n");
        identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
      }
      return WalkResult::advance();
    }

    // WhileRegionOp: local analysis ahead of visiting the attached regions.
    // The annotations are the union of those on the inputs and the results.
    // Resource alias analysis already propagates resource IDs from the inputs
    // to the results of a while, so looking at the results is sufficient.
    if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
      LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
      for (auto output : filter_resources(while_region.getResults())) {
        auto device = result->DeviceForResource(output);
        int output_index = output.getResultNumber();
        if (!device) {
          LLVM_DEBUG(llvm::dbgs()
                     << " No device for output #" << output_index << "\n");
          continue;
        }
        // Hand the annotation to the matching argument of every region.
        for (Region* region : while_region.getRegions()) {
          BlockArgument region_arg = region->getArgument(output_index);
          LLVM_DEBUG(llvm::dbgs()
                     << " Propagating device = '" << *device << "' to arg #"
                     << output_index << " of region #"
                     << region->getRegionNumber() << "\n");
          if (failed(AddResourceDeviceAndEmitError(region_arg, *device,
                                                   while_region, result)))
            return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    }

    // Every other op is left untouched.
    return WalkResult::advance();
  };

  WalkResult walk_status = func_op.walk<WalkOrder::PreOrder>(visit_op);
  return failure(walk_status.wasInterrupted());
}
void ResourceDeviceInference::runOnOperation() {
auto module = getOperation();
const auto& resource_alias_analysis =
getAnalysis<TF::ResourceAliasAnalysis>();
llvm::SmallDenseMap<func::FuncOp, PerFunctionResult, 4> per_function_results;
llvm::SetVector<func::FuncOp> worklist;
for (auto func_op : module.getOps<func::FuncOp>()) {
worklist.insert(func_op);
per_function_results.try_emplace(
func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
}
// Helper that propagates an op's recorded operand device assignments to its
// called function's arguments.
auto propagate_operands_to_callee_arguments =
[&](Operation* caller, Operation::operand_range caller_operands,
ArrayRef<func::FuncOp> callees, const PerFunctionResult& caller_res) {
for (func::FuncOp callee : callees) {
assert(callee);
auto& callee_res = per_function_results.find(callee)->getSecond();
bool callee_needs_recompute = false;
for (BlockArgument arg : filter_resources(callee.getArguments())) {
Value arg_operand = caller_operands[arg.getArgNumber()];
auto device = caller_res.DeviceForResource(arg_operand);
if (!device) continue;
LLVM_DEBUG(llvm::dbgs()
<< "Propagating '" << *device << "' to arg #"
<< arg.getArgNumber() << " of function @"
<< callee.getName() << "\n");
if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
&callee_res,
&callee_needs_recompute)))
return failure();
}
// If the callee recording is modified, make sure that it will be
// reprocessed.
if (callee_needs_recompute) worklist.insert(callee);
}
return success();
};
while (!worklist.empty()) {
auto func_op = worklist.pop_back_val();
auto& func_res = per_function_results.find(func_op)->getSecond();
// In-function propagation.
if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
return signalPassFailure();
// Propagation to callees.
auto walk_res = func_op.walk([&](Operation* op) {
if (auto while_op = dyn_cast<WhileOp>(op)) {
if (failed(propagate_operands_to_callee_arguments(
while_op, while_op.getOperands(),
{while_op.body_function(), while_op.cond_function()},
func_res)))
return WalkResult::interrupt();
} else if (auto if_op = dyn_cast<IfOp>(op)) {
if (failed(propagate_operands_to_callee_arguments(
if_op, if_op.input(),
{if_op.then_function(), if_op.else_function()}, func_res)))
return WalkResult::interrupt();
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
auto func = dyn_cast<func::FuncOp>(call.resolveCallable());
if (!func) {
op->emitError(
"Cannot propagate device attribute to callee: Unable to resolve "
"call");
return WalkResult::interrupt();
}
LLVM_DEBUG(llvm::dbgs()
<< "Visiting call to function @" << func.getName() << "\n");
if (failed(propagate_operands_to_callee_arguments(
call, call.getArgOperands(), {func}, func_res)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walk_res.wasInterrupted()) return signalPassFailure();
}
}
} // namespace
// Factory for the resource device inference pass; returns a newly allocated
// pass instance that operates on a ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ResourceDeviceInference>();
  return pass;
}
} // namespace TF
} // namespace mlir