// blob: e5a2038a10fd0d929fc325e8ec430822fa645317
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
namespace mlir {
namespace TFTPU {
// A pass that finds TPU clusters with write only resource access and adds an
// associated resource read, so the resource can later be fused into TPUExecute.
namespace {
// Pass that adds a resource read for resources a TPU cluster writes to but
// does not read from. It is handed out as an OperationPass<ModuleOp> by
// CreateTPUResourceReadForWritePass(); the transformation itself is
// implemented in runOnOperation().
struct TPUResourceReadForWritePass
: public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
// Pass entry point; defined out of line later in this file.
void runOnOperation() override;
};
// Pairs a resource handle with the type of the value written to it. A
// default-constructed instance (null `resource`) is what the helpers in this
// file return to signal that no qualifying resource write was found.
struct ResourceValueAndSubtype {
// Resource handle being written; null when there is no qualifying write.
Value resource;
// Type of the value stored into the resource (taken from the value operand
// of the AssignVariableOp that performs the write).
Type subtype;
};
// Returns the resource handle that `result` is written to, together with the
// type of the written value.
//
// `result` counts as a resource write only when its single use is the operand
// of a TF::AssignVariableOp. A default-constructed (null) pair is returned
// when that is not the case, and also when a different AssignVariableOp on the
// same handle stores a value defined by `cluster_func`, i.e. the cluster
// writes to the same variable through more than one result.
ResourceValueAndSubtype GetResourceWriteResult(
    tf_device::ClusterFuncOp cluster_func, Value result) {
  ResourceValueAndSubtype no_write;

  if (!result.hasOneUse()) return no_write;

  Operation* sole_user = *result.getUsers().begin();
  auto assign = dyn_cast<TF::AssignVariableOp>(sole_user);
  if (!assign) return no_write;

  auto handle = assign.resource();

  // Look at every other assignment to the same handle; bail out if any of
  // them is also fed by this cluster.
  for (Operation* user : handle.getUsers()) {
    if (user == assign) continue;

    auto other_assign = dyn_cast<TF::AssignVariableOp>(user);
    if (other_assign &&
        other_assign.value().getDefiningOp() == cluster_func) {
      return no_write;
    }
  }

  ResourceValueAndSubtype write;
  write.resource = handle;
  write.subtype = assign.value().getType();
  return write;
}
// Returns true when some TF::ReadVariableOp on `resource` produces a value
// that has `cluster_func` among its users; returns false otherwise.
bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
                                Value resource) {
  for (Operation* user : resource.getUsers()) {
    auto read_op = dyn_cast<TF::ReadVariableOp>(user);
    if (!read_op) continue;

    for (Operation* consumer : read_op.value().getUsers()) {
      if (consumer == cluster_func) return true;
    }
  }
  return false;
}
void TPUResourceReadForWritePass::runOnOperation() {
SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
cluster_funcs.push_back(cluster_func);
});
OpBuilder builder(&getContext());
// Add resource reads for resource writes from TPU cluster where for such
// resources the TPU cluster does not read from.
for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
builder.setInsertionPoint(cluster_func);
SmallVector<Value, 4> read_operands;
for (Value result : cluster_func.getResults()) {
// TODO(lyandy): Update pass to use resource alias analysis.
auto resource_and_type = GetResourceWriteResult(cluster_func, result);
if (!resource_and_type.resource) continue;
if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
continue;
auto new_read = builder.create<TF::ReadVariableOp>(
resource_and_type.resource.getLoc(), resource_and_type.subtype,
resource_and_type.resource);
read_operands.push_back(new_read.value());
}
if (read_operands.empty()) continue;
// Update caller and function types with new read operands.
auto operands = llvm::to_vector<4>(cluster_func.getOperands());
operands.append(read_operands.begin(), read_operands.end());
auto loc = cluster_func.getLoc();
auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
loc, cluster_func.getResultTypes(), operands, cluster_func->getAttrs());
cluster_func.replaceAllUsesWith(new_cluster_func);
func::FuncOp func = cluster_func.getFunc();
Block& block = func.front();
for (Value read_operand : read_operands)
block.addArgument(read_operand.getType(), loc);
func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
func.getCallableResults()));
cluster_func.erase();
}
}
} // namespace
// Factory for the pass defined in this file: returns a newly allocated
// TPUResourceReadForWritePass owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<TPUResourceReadForWritePass>();
  return pass;
}
} // namespace TFTPU
} // namespace mlir