Internal change

PiperOrigin-RevId: 335680049
Change-Id: I91e6edc767caf596d3cf1a28c075cc87388043e2
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 76b411b..da3db17 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -283,7 +283,6 @@
         "//tensorflow/compiler/xla:parse_flags_from_env",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 01e43b0..ee7daf0 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -167,6 +167,9 @@
   jitter_flags = new IntroduceFloatingPointJitterPassFlags;
   jitter_flags->jitter_amount = 1e-5;
 
+  mlir_flags = new MlirCommonFlags;
+  mlir_flags->tf_mlir_enable_mlir_bridge = false;
+
   auto setter_for_jitter_tensor_names = [](string sequence) {
     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
     return true;
@@ -212,28 +215,14 @@
        Flag("tf_introduce_floating_point_jitter_amount",
             &jitter_flags->jitter_amount,
             "The amount of jitter to introduce.  This amount is added to each "
-            "element in the tensors named in `tensor_names.")});
+            "element in the tensors named in `tensor_names."),
 
-  bool enable_mlir_bridge = false;
-  flag_list->emplace_back(
-      "tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
-      "Enables experimental MLIR-Based TensorFlow Compiler Bridge.");
-  const Flag& enable_mlir_bridge_flag = flag_list->back();
+       Flag("tf_mlir_enable_mlir_bridge",
+            &mlir_flags->tf_mlir_enable_mlir_bridge,
+            "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
 
   AppendMarkForCompilationPassFlagsInternal(flag_list);
   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
-
-  mlir_flags = new MlirCommonFlags;
-  if (enable_mlir_bridge_flag.is_default_initialized()) {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
-  } else if (enable_mlir_bridge) {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
-  } else {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
-  }
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index a0860da..5612b3b 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -19,7 +19,6 @@
 #include <vector>
 
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/util/command_line_flags.h"
 
 namespace tensorflow {
@@ -136,7 +135,7 @@
 
 // Flags for common MLIR configurations.
 struct MlirCommonFlags {
-  ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
+  bool tf_mlir_enable_mlir_bridge;
 };
 
 // Return a pointer to the DumpGraphFlags struct;
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index 0ccdacb..d4a69da 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -89,8 +89,7 @@
   XlaOpRegistry::RegisterCompilationKernels();
 
   // Only check for compilability if the MLIR bridge is not enabled.
-  if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
index 8efe2d6..f7541e6 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
@@ -47,9 +47,7 @@
 
   bool IsEnabled(const ConfigProto& config_proto) const override {
     return config_proto.experimental().enable_mlir_bridge() ||
-           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-               tensorflow::ConfigProto::Experimental::
-                   MLIR_BRIDGE_ROLLOUT_ENABLED;
+           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
   }
 
   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 549d631..c62b828 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -734,15 +734,13 @@
 
   VLOG(1) << "====================================================";
 #ifdef LIBTPU_ON_GCE
-  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     VLOG(1) << "MLIR is not supported in this environment.";
   }
   TF_RETURN_IF_ERROR(
       CompileGraph(options, function_id, std::move(graph), args, result));
 #else
-  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     VLOG(1) << "Using MLIR bridge";
     GraphDebugInfo debug_info;
     TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index 83bb300..00a9cba 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -135,9 +135,8 @@
 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_INT32),
-      int32_hook_([this, dst](int32 value) {
+      int32_hook_([dst](int32 value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       int32_default_for_display_(*dst),
@@ -146,9 +145,8 @@
 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_INT64),
-      int64_hook_([this, dst](int64 value) {
+      int64_hook_([dst](int64 value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       int64_default_for_display_(*dst),
@@ -157,9 +155,8 @@
 Flag::Flag(const char* name, float* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_FLOAT),
-      float_hook_([this, dst](float value) {
+      float_hook_([dst](float value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       float_default_for_display_(*dst),
@@ -168,9 +165,8 @@
 Flag::Flag(const char* name, bool* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_BOOL),
-      bool_hook_([this, dst](bool value) {
+      bool_hook_([dst](bool value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       bool_default_for_display_(*dst),
@@ -179,9 +175,8 @@
 Flag::Flag(const char* name, string* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_STRING),
-      string_hook_([this, dst](string value) {
+      string_hook_([dst](string value) {
         *dst = std::move(value);
-        this->default_initialized_ = false;
         return true;
       }),
       string_default_for_display_(*dst),
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 9011ee1..928ae8a 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -85,8 +85,6 @@
   Flag(const char* name, std::function<bool(string)> string_hook,
        string default_value_for_display, const string& usage_text);
 
-  bool is_default_initialized() const { return default_initialized_; }
-
  private:
   friend class Flags;
 
@@ -117,7 +115,6 @@
   string string_default_for_display_;
 
   string usage_text_;
-  bool default_initialized_ = true;
 };
 
 class Flags {
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 1dd046f..36165de 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -580,15 +580,10 @@
 
   // MLIR Logic
   m.def("TF_IsMlirBridgeEnabled", [] {
-    return (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-            tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
+    return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
   });
   m.def("TF_EnableMlirBridge", [](bool enabled) {
-    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
-        enabled
-            ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
-            : tensorflow::ConfigProto::Experimental::
-                  MLIR_BRIDGE_ROLLOUT_DISABLED;
+    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
   });
   m.def("TF_EnableXlaDevices", [] {
     tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;