Internal change
PiperOrigin-RevId: 335680049
Change-Id: I91e6edc767caf596d3cf1a28c075cc87388043e2
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 76b411b..da3db17 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -283,7 +283,6 @@
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 01e43b0..ee7daf0 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -167,6 +167,9 @@
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
+ mlir_flags = new MlirCommonFlags;
+ mlir_flags->tf_mlir_enable_mlir_bridge = false;
+
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@@ -212,28 +215,14 @@
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
- "element in the tensors named in `tensor_names.")});
+ "element in the tensors named in `tensor_names`."),
- bool enable_mlir_bridge = false;
- flag_list->emplace_back(
- "tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
- "Enables experimental MLIR-Based TensorFlow Compiler Bridge.");
- const Flag& enable_mlir_bridge_flag = flag_list->back();
+ Flag("tf_mlir_enable_mlir_bridge",
+ &mlir_flags->tf_mlir_enable_mlir_bridge,
+ "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
-
- mlir_flags = new MlirCommonFlags;
- if (enable_mlir_bridge_flag.is_default_initialized()) {
- mlir_flags->tf_mlir_enable_mlir_bridge =
- ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
- } else if (enable_mlir_bridge) {
- mlir_flags->tf_mlir_enable_mlir_bridge =
- ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
- } else {
- mlir_flags->tf_mlir_enable_mlir_bridge =
- ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
- }
}
} // namespace
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index a0860da..5612b3b 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -19,7 +19,6 @@
#include <vector>
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@@ -136,7 +135,7 @@
// Flags for common MLIR configurations.
struct MlirCommonFlags {
- ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
+ bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index 0ccdacb..d4a69da 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -89,8 +89,7 @@
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
- if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
- tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+ if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
index 8efe2d6..f7541e6 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
@@ -47,9 +47,7 @@
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge() ||
- tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
- tensorflow::ConfigProto::Experimental::
- MLIR_BRIDGE_ROLLOUT_ENABLED;
+ tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 549d631..c62b828 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -734,15 +734,13 @@
VLOG(1) << "====================================================";
#ifdef LIBTPU_ON_GCE
- if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
- tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+ if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
VLOG(1) << "MLIR is not supported in this environment.";
}
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, result));
#else
- if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
- tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+ if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 83bb300..00a9cba 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -135,9 +135,8 @@
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT32),
- int32_hook_([this, dst](int32 value) {
+ int32_hook_([dst](int32 value) {
*dst = value;
- this->default_initialized_ = false;
return true;
}),
int32_default_for_display_(*dst),
@@ -146,9 +145,8 @@
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT64),
- int64_hook_([this, dst](int64 value) {
+ int64_hook_([dst](int64 value) {
*dst = value;
- this->default_initialized_ = false;
return true;
}),
int64_default_for_display_(*dst),
@@ -157,9 +155,8 @@
Flag::Flag(const char* name, float* dst, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- float_hook_([this, dst](float value) {
+ float_hook_([dst](float value) {
*dst = value;
- this->default_initialized_ = false;
return true;
}),
float_default_for_display_(*dst),
@@ -168,9 +165,8 @@
Flag::Flag(const char* name, bool* dst, const string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- bool_hook_([this, dst](bool value) {
+ bool_hook_([dst](bool value) {
*dst = value;
- this->default_initialized_ = false;
return true;
}),
bool_default_for_display_(*dst),
@@ -179,9 +175,8 @@
Flag::Flag(const char* name, string* dst, const string& usage_text)
: name_(name),
type_(TYPE_STRING),
- string_hook_([this, dst](string value) {
+ string_hook_([dst](string value) {
*dst = std::move(value);
- this->default_initialized_ = false;
return true;
}),
string_default_for_display_(*dst),
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 9011ee1..928ae8a 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -85,8 +85,6 @@
Flag(const char* name, std::function<bool(string)> string_hook,
string default_value_for_display, const string& usage_text);
- bool is_default_initialized() const { return default_initialized_; }
-
private:
friend class Flags;
@@ -117,7 +115,6 @@
string string_default_for_display_;
string usage_text_;
- bool default_initialized_ = true;
};
class Flags {
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 1dd046f..36165de 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -580,15 +580,10 @@
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {
- return (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
- tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
+ return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
- tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
- enabled
- ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
- : tensorflow::ConfigProto::Experimental::
- MLIR_BRIDGE_ROLLOUT_DISABLED;
+ tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
});
m.def("TF_EnableXlaDevices", [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;