Avoid depending on the implementation of jit:flags in pywrap_tfe.
It was causing the IsXlaEnabled function to return false erroneously.
PiperOrigin-RevId: 296368921
Change-Id: I22507c7fa4bcf8804a333f4eafe38d4c009b76d2
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8126e99..63593f1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5822,6 +5822,7 @@
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
"//tensorflow/c:python_api", # tf_session
"//tensorflow/c:tf_status_helper", # tfe
+ "//tensorflow/compiler/jit:flags", #tfe
"//tensorflow/compiler/mlir/python:mlir", # mlir
"//tensorflow/core:core_cpu_base_no_ops", # tf_session
"//tensorflow/core:core_cpu_impl", # device_lib
@@ -8046,6 +8047,7 @@
"@com_google_absl//absl/types:optional",
"@pybind11",
"//third_party/python_runtime:headers",
+ "//tensorflow/compiler/jit:flags_headers_only",
"//tensorflow/core:core_cpu_headers_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -8054,13 +8056,11 @@
"//tensorflow/core/platform:platform",
] + if_static(
extra_deps = [
- "//tensorflow/compiler/jit:flags",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:worker_proto_cc",
],
otherwise = [
- "//tensorflow/compiler/jit:flags_headers_only",
"//tensorflow/core:eager_service_proto_cc_headers_only",
"//tensorflow/core:master_proto_cc_headers_only",
"//tensorflow/core:worker_proto_cc_headers_only",
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 7bf9f56..1298479 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -340,3 +340,6 @@
[cost_analyzer_lib] # cost_analyzer
tensorflow::grappler::CostAnalyzer::CostAnalyzer
tensorflow::grappler::CostAnalyzer::GenerateReport
+
+[flags] # tfe
+tensorflow::IsXlaEnabled