/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include <string>
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {

auto* mlir_bridge_gauge_v1 = monitoring::Gauge<bool, 0>::New(
    "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v1",
    "Tracks usage of the MLIR-based TF2XLA bridge among TF1 models");
auto* mlir_bridge_gauge_v2 = monitoring::Gauge<bool, 0>::New(
    "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2",
    "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models");

namespace {

constexpr char kTPUReplicateAttr[] = "_tpu_replicate";

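// Returns true if the module lists at least one TPU device among its runtime
// devices.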
bool HasTPUDevice(mlir::ModuleOp module) {
  mlir::TF::RuntimeDevices devices;
  if (failed(GetDevicesFromOp(module.getOperation(), &devices))) return false;
  return absl::c_any_of(
      devices.device_names(),
      [](const tensorflow::DeviceNameUtils::ParsedName& device) {
        return device.has_type && device.type == "TPU";
      });
}

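// Returns true if any operation in the module carries the `_tpu_replicate`
// attribute, i.e. the graph contains computations targeted at a TPU.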
bool HasTPUOp(mlir::ModuleOp module) {
  auto walk_result = module.walk([&](mlir::Operation* op) {
    auto replicate_attr =
        op->getAttrOfType<mlir::StringAttr>(kTPUReplicateAttr);
    if (replicate_attr) return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });
  return walk_result.wasInterrupted();
}

// Checks that the module contains both TPU devices in its device list and
// TPU ops (identified by the `_tpu_replicate` attribute on ops).
bool HasTPUDevicesAndOps(mlir::ModuleOp module) {
  return HasTPUDevice(module) && HasTPUOp(module);
}

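// Returns true if the given TF device set contains at least one TPU device.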
bool HasTPUDevice(const DeviceSet& device_set) {
  for (const Device* device : device_set.devices()) {
    if (!device) continue;
    const DeviceNameUtils::ParsedName& name = device->parsed_name();
    if (name.has_type && name.type == "TPU") return true;
  }
  return false;
}

}  // namespace

// Analyzes the user-requested policy as well as the contents of the graph and
// determines whether the MLIR Bridge should be run.
//
// If the user explicitly requests that the bridge be enabled or disabled, this
// function will respect the request. Otherwise, it decides whether or not to
// run the bridge based on graph analysis.
//
// The config_proto param is a required input for all TF1 graphs, but it is
// redundant for TF2 graphs.
MlirOptimizationPassState MlirBridgePass::GetPassState(
    const DeviceSet* device_set, const ConfigProto& config_proto,
    const Graph& graph) const {
  // Skip the MLIR TPU Bridge if no TPU devices are found.
  if (device_set && !HasTPUDevice(*device_set)) {
    return MlirOptimizationPassState::Disabled;
  }

  // We set `uses_uninitialized_resource_args` to false here because the first
  // phase of the bridge is not affected by uninitialized resource args.
  MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
      graph, config_proto, /*uses_uninitialized_resource_args=*/false);
  switch (policy) {
    case MlirBridgeRolloutPolicy::kEnabledByUser:
      return MlirOptimizationPassState::Enabled;
    case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis:
      return MlirOptimizationPassState::ShadowEnabled;
    case MlirBridgeRolloutPolicy::kDisabledByUser:
    case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis:
      return MlirOptimizationPassState::Disabled;
  }
}

namespace {

// Log just once by default (at the default log level), and let the user raise
// the vlog level for more detailed logging.
#define LOG_AT_LEAST_ONCE(log_message)       \
  {                                          \
    if (VLOG_IS_ON(1)) {                     \
      VLOG(1) << log_message;                \
    } else {                                 \
      LOG_FIRST_N(INFO, 1) << log_message;   \
    }                                        \
  }

}  // namespace

// This runs the first phase of the "bridge", transforming the graph into a
// form that can be executed by delegating some computations to an accelerator.
// This builds on the XLA model, where a subset of the graph is encapsulated
// and attached to a "compile" operation, whose result is fed to an "execute"
// operation. The kernels for these operations are responsible for lowering
// the encapsulated graph to a particular device.
Status MlirBridgePass::Run(const ConfigProto& config_proto,
                           mlir::ModuleOp module, const Graph& graph) {
  // Set device_set to nullptr here as the device-specific checks are performed
  // based on the devices in the module.
  if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
      MlirOptimizationPassState::Disabled) {
    LOG_AT_LEAST_ONCE("Skipping MLIR TPU Bridge, session flag not enabled");
    mlir_bridge_gauge_v2->GetCell()->Set(false);
    return Status::OK();
  }

  // Skip the MLIR TPU Bridge if no TPU devices or TPU ops are found.
  if (!HasTPUDevicesAndOps(module)) {
    LOG_AT_LEAST_ONCE(
        "Skipping MLIR TPU Bridge, no TPU devices or TPU ops found");
    return Status::OK();
  }

  LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge");
  mlir_bridge_gauge_v2->GetCell()->Set(true);
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
  return Status::OK();
}

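// Determines whether the MLIR TPU Bridge V1 Compat pass should run. Unlike
// MlirBridgePass::GetPassState above, this returns true only when the bridge
// is explicitly enabled by the user, never based on graph analysis alone.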
bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set,
                                       const ConfigProto& config_proto,
                                       const Graph& graph) const {
  // Skip the MLIR TPU Bridge if no TPU devices are found.
  if (device_set && !HasTPUDevice(*device_set)) return false;

  // Only run the bridge if it is explicitly enabled by the user; do not run
  // it when it would merely be enabled by graph analysis.
  // We set `uses_uninitialized_resource_args` to false here because the first
  // phase of the bridge is not affected by uninitialized resource args.
  MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
      graph, config_proto, /*uses_uninitialized_resource_args=*/false);
  return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
}

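// Runs the V1-compatible first phase of the bridge on whole-session graphs.
// Function graphs are skipped because they are handled by MlirBridgePass::Run
// instead.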
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
                                   mlir::ModuleOp module) {
  // Skip function graphs as MlirBridgePass will be used instead.
  if (options.is_function_graph) return Status::OK();

  // Set device_set to nullptr here as the device-specific checks are performed
  // based on the devices in the module.
  if (!IsEnabled(/*device_set=*/nullptr, options.session_options->config,
                 **options.graph)) {
    LOG_AT_LEAST_ONCE(
        "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled");
    mlir_bridge_gauge_v1->GetCell()->Set(false);
    return Status::OK();
  }

  // Skip the MLIR TPU Bridge if no TPU devices or TPU ops are found.
  if (!HasTPUDevicesAndOps(module)) {
    LOG_AT_LEAST_ONCE(
        "Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops found");
    return Status::OK();
  }

  LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge V1 Compat");
  mlir_bridge_gauge_v1->GetCell()->Set(true);
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1)));
  return Status::OK();
}

}  // namespace tensorflow