blob: 2b1e29c53470963a3dc03ddd9d4e14bb9ea94a0f [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass inserts corert.transfer op to make sure any argument of any op is
// on the same device of the op itself.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
namespace tensorflow {
namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
constexpr const char *kDeviceAttr = "device";
constexpr const char *kTFRTDeviceAttr = "tfrt.device";
// TODO(b/175480458): Do not assign default device once every op in the TF
// dialect has the device attribute.
constexpr const char *kDefaultDevice =
"/job:localhost/replica:0/task:0/device:CPU:0";
// This method canonicalizes the device name so that we can use string
// comparison to see if two devices are the same. It does the following
// transformations:
// 1) Set device ID to 0 if device ID is not already specified.
// 2) Change the device type to uppercase string.
static std::string CanonicalizeDeviceName(const std::string &device) {
if (device.empty()) return kDefaultDevice;
DeviceNameUtils::ParsedName parsed_name;
if (!device.empty() && device.at(0) == '/') {
DeviceNameUtils::ParseFullName(device, &parsed_name);
} else {
DeviceNameUtils::ParseFullName("/device:" + device, &parsed_name);
}
if (!parsed_name.has_id) {
parsed_name.has_id = true;
parsed_name.id = 0;
}
if (parsed_name.type == "cpu")
parsed_name.type = "CPU";
else if (parsed_name.type == "gpu")
parsed_name.type = "GPU";
else if (parsed_name.type == "tpu")
parsed_name.type = "TPU";
return DeviceNameUtils::ParsedNameToString(parsed_name);
}
// Return the device of the given operation.
static std::string GetDevice(Operation *op) {
std::string device = "";
if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
device = device_attr.getValue().str();
} else if (auto execute_op = llvm::dyn_cast<tfrt::corert::ExecuteOp>(op)) {
SmallVector<std::pair<StringRef, Attribute>, 4> attrs;
execute_op.getOpAttrs(&attrs);
for (std::pair<StringRef, Attribute> entry : attrs) {
if (entry.first == kDeviceAttr && entry.second.isa<StringAttr>()) {
device = entry.second.cast<StringAttr>().getValue().str();
break;
}
}
}
return CanonicalizeDeviceName(device);
}
// Return the device of the given value.
static std::string GetDevice(mlir::Value value, func::FuncOp parent_func_op) {
std::string device = "";
if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
if (StringAttr device_attr = parent_func_op.getArgAttrOfType<StringAttr>(
block_arg.getArgNumber(), kTFRTDeviceAttr)) {
device = device_attr.getValue().str();
}
} else {
device = GetDevice(value.getDefiningOp());
}
return CanonicalizeDeviceName(device);
}
struct CrossDeviceTransferPass
: public PassWrapper<CrossDeviceTransferPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CrossDeviceTransferPass)
void runOnOperation() override;
llvm::StringRef getArgument() const final {
return "tfrt-cross-device-transfer";
}
llvm::StringRef getDescription() const final {
return "This pass inserts corert.transfer op to make sure any argument of "
"any op is on the same device of the op itself.";
}
};
void CrossDeviceTransferPass::runOnOperation() {
func::FuncOp func_op = getOperation();
llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
transferred_value_by_value_and_device;
func_op.getBody().walk([&](Operation *op) {
if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
// Do not transfer the argument of corert.transfer op.
if (llvm::isa<tfrt::corert::TransferOp>(op)) return WalkResult::advance();
OpBuilder builder(op);
std::string dst_device = GetDevice(op);
mlir::Type tensor_type_type =
builder.getType<::tfrt::compiler::TensorTypeType>();
mlir::Type device_type = builder.getType<::tfrt::compiler::DeviceType>();
for (mlir::Value arg : op->getOperands()) {
// Do not transfer non-TensorHandle values.
if (!arg.getType().isa<tfrt::corert::TensorHandleType>()) continue;
// Do not transfer the result of corert.transfer op.
if (OpResult op_result = arg.dyn_cast<OpResult>()) {
Operation *defining_op = arg.getDefiningOp();
if (llvm::isa<tfrt::corert::TransferOp>(defining_op)) continue;
}
std::string src_device = GetDevice(arg, func_op);
if (DeviceNameUtils::LocalName(src_device) ==
DeviceNameUtils::LocalName(dst_device))
continue;
// Re-use the value already transferred to the given device.
llvm::StringMap<mlir::Value> &transferred_value_by_device =
transferred_value_by_value_and_device[arg];
auto iter = transferred_value_by_device.find(dst_device);
if (iter != transferred_value_by_device.end()) {
op->replaceUsesOfWith(arg, iter->second);
continue;
}
mlir::Value chain_in = func_op.getArgument(0);
auto get_device_op = builder.create<tfrt::compiler::GetDeviceOp>(
op->getLoc(), device_type, chain_in, dst_device);
auto get_tensor_type_op =
builder.create<tfrt::corert::GetDstTensorTypeOp>(
op->getLoc(), tensor_type_type, arg, get_device_op.getResult());
auto transfer_op = builder.create<tfrt::corert::TransferOp>(
op->getLoc(), arg.getType(), arg, get_device_op.getResult(),
get_tensor_type_op.getResult());
mlir::Value new_arg = transfer_op.getResult();
transferred_value_by_device[dst_device] = new_arg;
op->replaceUsesOfWith(arg, new_arg);
}
return WalkResult::advance();
});
}
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> CreateCrossDeviceTransferPass() {
return std::make_unique<CrossDeviceTransferPass>();
}
static PassRegistration<CrossDeviceTransferPass> pass;
} // namespace tensorflow