Removes back link from external graph transform...
PiperOrigin-RevId: 360207504
Change-Id: I0ec32094baada0daf4a200387aca7470c482a981
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 904e734..4cfc389 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -239,6 +239,7 @@
"//tensorflow/python/util:_pywrap_nest",
"//tensorflow/python/util:_pywrap_stat_summarizer",
"//tensorflow/python/util:_pywrap_tfprof",
+ "//tensorflow/python/util:_pywrap_transform_graph",
"//tensorflow/python/util:_pywrap_util_port",
"//third_party/py/numpy",
],
@@ -763,6 +764,7 @@
"//tensorflow/python/util:_pywrap_checkpoint_reader",
"//tensorflow/python/util:_pywrap_stat_summarizer",
"//tensorflow/python/util:_pywrap_tfprof",
+ "//tensorflow/python/util:_pywrap_transform_graph",
"//tensorflow/python/util:_pywrap_util_port",
":_pywrap_utils",
":_errors_test_helper",
@@ -5272,6 +5274,7 @@
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/core/profiler/internal/cpu:python_tracer",
+ "//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/lite/toco/python:toco_python_api",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/core/util/tensor_bundle",
@@ -5354,6 +5357,7 @@
":tf_session_helper", # tf_session
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
+ "//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph
] + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib", # tfcompile
"//tensorflow/compiler/xla:status_macros", # tfcompile
diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD
index 5b6fa5a..e0eb8d0 100644
--- a/tensorflow/python/util/BUILD
+++ b/tensorflow/python/util/BUILD
@@ -170,6 +170,21 @@
)
tf_python_pybind_extension(
+ name = "_pywrap_transform_graph",
+ srcs = ["transform_graph_wrapper.cc"],
+ hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
+ module_name = "_pywrap_transform_graph",
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_headers_for_pybind",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/python/lib/core:pybind11_status",
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
+tf_python_pybind_extension(
name = "_pywrap_checkpoint_reader",
srcs = ["py_checkpoint_reader_wrapper.cc"],
hdrs = [
diff --git a/tensorflow/tools/graph_transforms/transform_graph_wrapper.cc b/tensorflow/python/util/transform_graph_wrapper.cc
similarity index 100%
rename from tensorflow/tools/graph_transforms/transform_graph_wrapper.cc
rename to tensorflow/python/util/transform_graph_wrapper.cc
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index dac80cb..621dd44 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -112,6 +112,10 @@
toco::MlirSparsifyModel
toco::RegisterCustomOpdefs
+[//tensorflow/tools/graph_transforms:transform_graph_lib] # transform_graph
+tensorflow::graph_transforms::TransformGraph
+tensorflow::graph_transforms::ParseTransformParameters
+
[//tensorflow/c:checkpoint_reader] # py_checkpoint_reader
tensorflow::checkpoint::CheckpointReader
tensorflow::checkpoint::CheckpointReader::Init
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 5e7c21b..fb2915b 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -11,9 +11,6 @@
)
load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
-# buildifier: disable=same-origin-load
-load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
-
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
@@ -333,30 +330,15 @@
],
)
-tf_python_pybind_extension(
- name = "_pywrap_transform_graph",
- srcs = ["transform_graph_wrapper.cc"],
- hdrs = ["//tensorflow/tools/graph_transforms:transform_graph_hdrs"],
- module_name = "_pywrap_transform_graph",
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/python/lib/core:pybind11_status",
- "//tensorflow/tools/graph_transforms:transform_graph_lib",
- "//third_party/python_runtime:headers",
- "@pybind11",
- ],
-)
-
py_library(
name = "transform_graph_py",
srcs = ["__init__.py"],
srcs_version = "PY3",
deps = [
- ":_pywrap_transform_graph",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:errors",
"//tensorflow/python:util",
+ "//tensorflow/python/util:_pywrap_transform_graph",
],
)
diff --git a/tensorflow/tools/graph_transforms/__init__.py b/tensorflow/tools/graph_transforms/__init__.py
index 111b214..84f7ea0 100644
--- a/tensorflow/tools/graph_transforms/__init__.py
+++ b/tensorflow/tools/graph_transforms/__init__.py
@@ -20,7 +20,7 @@
# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.core.framework import graph_pb2
from tensorflow.python.util import compat
-from tensorflow.tools.graph_transforms._pywrap_transform_graph import TransformGraphWithStringInputs
+from tensorflow.python.util._pywrap_transform_graph import TransformGraphWithStringInputs
def TransformGraph(input_graph_def, inputs, outputs, transforms):