blob: 82e893e748b0233c76e4a9ab558e4007003cfada [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass converts each tfrt_dist.remote_execute_func op into a combination
// of tfrt_dist.register_tfrt_function op and tfrt_dist.remote_execute op. The
// function to be executed in the remote host will be serialized as a string
// attribute of the tfrt_dist.register_tfrt_function op.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime
#include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime
namespace tensorflow {
namespace {
constexpr const char* kHost = "host";
constexpr const char* kTFRTDevice = "tfrt.device";
struct DistRemoteRunEncapsulatePass
: public PassWrapper<DistRemoteRunEncapsulatePass,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DistRemoteRunEncapsulatePass)
llvm::StringRef getArgument() const final {
return "tfrt-dist-remote-run-encapsulate";
}
llvm::StringRef getDescription() const final {
return "This pass looks for a remote_run_func and serialize the callee to "
"a string attribute attached to a remote_register operation, "
"followed by a remote_execute invocation.";
}
void runOnOperation() override;
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tfrt::dist::DistributedDialect>();
}
};
LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func,
std::string* serialized_func_module) {
ModuleOp module = entry_func->getParentOfType<ModuleOp>();
SymbolTable entry_module_table(module);
SmallVector<func::FuncOp, 4> referenced({entry_func});
// Create a new module to hold func and all referenced functions.
OwningOpRef<mlir::ModuleOp> module_for_func =
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
SymbolTable symbol_table(module_for_func.get());
while (!referenced.empty()) {
func::FuncOp func = referenced.pop_back_val();
// Skip functions that have already been cloned into new module.
if (symbol_table.lookup<func::FuncOp>(func.getName())) continue;
// Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
// all found FuncOps to new_module to make sure new_module is
// self-contained.
Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
assert(uses && "expected to be able to collect symbol uses");
for (SymbolTable::SymbolUse use : *uses) {
func::FuncOp referenced_func = entry_module_table.lookup<func::FuncOp>(
use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
// Skip Symbols that do not map to a function.
if (!referenced_func) continue;
referenced.emplace_back(referenced_func);
}
func::FuncOp clone = func.clone();
if (clone.getName() == entry_func.getName()) {
clone.setPublic();
} else {
clone.setPrivate();
}
symbol_table.insert(clone);
}
*serialized_func_module =
tensorflow::SerializeMlirModule(module_for_func.get());
return success();
}
void DistRemoteRunEncapsulatePass::runOnOperation() {
mlir::TF::RuntimeDevices devices;
ModuleOp module = getOperation();
SymbolTable symtab(module);
Type chain_type = tfrt::compiler::ChainType::get(&getContext());
Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext());
Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext());
module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) {
FlatSymbolRefAttr callee_sym = remote_exec_op.calleeAttr();
func::FuncOp callee = symtab.lookup<func::FuncOp>(callee_sym.getValue());
if (!callee) {
remote_exec_op.emitOpError("callee function ")
<< callee_sym.getValue() << " is not found";
signalPassFailure();
return WalkResult::interrupt();
}
std::string txt_module;
if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) {
remote_exec_op.emitOpError("failed to serialize the callee function ")
<< callee.getName();
signalPassFailure();
return WalkResult::interrupt();
}
Location loc = remote_exec_op.getLoc();
StringAttr callee_name =
StringAttr::get(&getContext(), callee_sym.getValue());
OpBuilder builder(remote_exec_op);
auto register_op = builder.create<tfrt::dist::RegisterTFRTFunctionOp>(
loc, chain_type, remote_exec_op.in_op_chain(), remote_exec_op.context(),
remote_exec_op.remote_task(),
StringAttr::get(&getContext(), txt_module), callee_name);
// Build the device assignment for the results
// TODO(tfrt-devs): Define properly MLIR types and operations
SmallVector<Attribute, 8> result_devices;
for (const auto& result : llvm::enumerate(remote_exec_op.results())) {
StringAttr device =
callee.getResultAttrOfType<StringAttr>(result.index(), kTFRTDevice);
if (!device) {
// The result might not have the device attribute if it is added by
// the tf-to-tfrt pass. Use the first CPU on the remote host as the
// device of this result.
DeviceNameUtils::ParsedName parsed_name;
if (StringAttr host_attr = callee->getAttrOfType<StringAttr>(kHost)) {
auto host = host_attr.getValue();
DeviceNameUtils::ParseFullName({host.data(), host.size()},
&parsed_name);
}
parsed_name.has_type = true;
parsed_name.type = "CPU";
parsed_name.has_id = true;
parsed_name.id = 0;
device = StringAttr::get(
&getContext(), DeviceNameUtils::ParsedNameToString(parsed_name));
}
result_devices.push_back(std::move(device));
}
// IDEA(donglin): Update the create_remote_execute_spec kernel to use Device
// object instead of Device string.
Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext());
auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices);
auto remote_spec = builder.create<tfrt::dist::CreateRemoteExecuteSpecOp>(
loc, remote_spec_ty, remote_exec_op.context(), result_devices_attr);
// If original argument is already tfrt_dist.remote_object_id, use it
// directly. If it is TensorHandle, insert an op to extract the
// tfrt_dist.remote_object_id from it. Otherwise, emit an error.
SmallVector<Value, 4> arguments;
for (Value value : remote_exec_op.callee_args()) {
if (value.getType().isa<tfrt::dist::RemoteObjectIdType>()) {
arguments.push_back(value);
} else if (value.getType().isa<tfrt::corert::TensorHandleType>()) {
auto new_op = builder.create<tfrt::dist::GetRemoteObjectIdFromTHOp>(
loc, remote_object_id_ty, value);
arguments.push_back(new_op.result());
} else {
remote_exec_op.emitOpError(
"callee argument type should be either "
"TensorHandle or RemoteObjectId");
signalPassFailure();
return WalkResult::interrupt();
}
}
// Result types are 1 chain, followed by `num_th_results + 1`
// tfrt_dist.remote_object_id results, followed by `num_th_results`
// corert.tensorhandle results.
int32_t num_th_results = remote_exec_op.results().size() - 1;
SmallVector<Type, 8> result_types;
result_types.push_back(chain_type);
for (int count : llvm::seq<int>(0, num_th_results + 1)) {
(void)count;
result_types.push_back(remote_object_id_ty);
}
for (int count : llvm::seq<int>(0, num_th_results)) {
(void)count;
result_types.push_back(tensor_handle_ty);
}
auto new_remote_exec_th_op = builder.create<tfrt::dist::RemoteExecuteTHOp>(
loc, result_types, register_op.out_op_chain(), remote_exec_op.context(),
remote_exec_op.remote_task(), remote_spec, num_th_results,
callee_name.getValue(), std::move(arguments));
// The part of the new results to replace the original results are 2 chains,
// followed `num_th_results` corert.tesnorhandle results from the callee
// function.
SmallVector<Value, 4> new_results;
new_results.push_back(new_remote_exec_th_op.getResult(0));
new_results.push_back(new_remote_exec_th_op.getResult(1));
for (int i : llvm::seq<int>(0, num_th_results)) {
new_results.push_back(
new_remote_exec_th_op.getResult(i + 2 + num_th_results));
}
remote_exec_op.replaceAllUsesWith(new_results);
remote_exec_op.erase();
return WalkResult::advance();
});
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateDistRemoteRunEncapsulatePass() {
return std::make_unique<DistRemoteRunEncapsulatePass>();
}
static PassRegistration<DistRemoteRunEncapsulatePass> pass;
} // namespace tensorflow