Expose nest related function to tf.__internal__ API.
PiperOrigin-RevId: 340722550
Change-Id: Ic5f7451346dbd0543e0cc25c90bcd16edfff212a
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
index 0297446..9d716d0 100644
--- a/tensorflow/python/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -128,7 +128,7 @@
raise ValueError(
'Received non-atomic and non-sequence element: {}'.format(nested))
if nest._is_mapping(nested):
- values = [nested[k] for k in nest._sorted(nested)]
+ values = [nested[k] for k in sorted(nested.keys())]
elif nest._is_attrs(nested):
values = _astuple(nested)
else:
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 3daf9ee..8dac33c 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -10,6 +10,7 @@
"__internal__/distribute/__init__.py",
"__internal__/distribute/combinations/__init__.py",
"__internal__/distribute/multi_process_runner/__init__.py",
+ "__internal__/nest/__init__.py",
"__internal__/test/__init__.py",
"__internal__/test/combinations/__init__.py",
"__internal__/tf2/__init__.py",
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 1a547b8..cdd6d0c 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -126,6 +126,19 @@
_is_mapping = _pywrap_utils.IsMapping
+@tf_export("__internal__.nest.is_attrs", v1=[])
+def is_attrs(obj):
+ """Returns a true if its input is an instance of an attr.s decorated class."""
+ return _is_attrs(obj)
+
+
+@tf_export("__internal__.nest.is_mapping", v1=[])
+def is_mapping(obj):
+ """Returns a true if its input is a collections.Mapping."""
+ return is_mapping(obj)
+
+
+@tf_export("__internal__.nest.sequence_like", v1=[])
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@@ -894,6 +907,7 @@
expand_composites=expand_composites)
+@tf_export("__internal__.nest.flatten_up_to", v1=[])
def flatten_up_to(shallow_tree, input_tree, check_types=True,
expand_composites=False):
"""Flattens `input_tree` up to `shallow_tree`.
@@ -1082,6 +1096,7 @@
return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
+@tf_export("__internal__.nest.map_structure_up_to", v1=[])
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
"""Applies a function or op to a number of partially flattened inputs.
@@ -1261,6 +1276,7 @@
expand_composites=expand_composites)
+@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[])
def get_traverse_shallow_structure(traverse_fn, structure,
expand_composites=False):
"""Generates a shallow structure from a `traverse_fn` and `structure`.
@@ -1331,6 +1347,7 @@
return _sequence_like(structure, level_traverse)
+@tf_export("__internal__.nest.yield_flat_paths", v1=[])
def yield_flat_paths(nest, expand_composites=False):
"""Yields paths for some nested structure.
@@ -1425,6 +1442,7 @@
flatten(structure, expand_composites=expand_composites)))
+@tf_export("__internal__.nest.list_to_tuple", v1=[])
def list_to_tuple(structure):
"""Replace all lists with tuples.
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt
new file mode 100644
index 0000000..feca4a0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.nest.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.__internal__.nest"
+tf_module {
+ member_method {
+ name: "flatten_up_to"
+ argspec: "args=[\'shallow_tree\', \'input_tree\', \'check_types\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
+ }
+ member_method {
+ name: "get_traverse_shallow_structure"
+ argspec: "args=[\'traverse_fn\', \'structure\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "is_attrs"
+ argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_mapping"
+ argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "list_to_tuple"
+ argspec: "args=[\'structure\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map_structure_up_to"
+ argspec: "args=[\'shallow_tree\', \'func\'], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "sequence_like"
+ argspec: "args=[\'instance\', \'args\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "yield_flat_paths"
+ argspec: "args=[\'nest\', \'expand_composites\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
index b6a3857..35d23f0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt
@@ -17,6 +17,10 @@
mtype: "<type \'module\'>"
}
member {
+ name: "nest"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "test"
mtype: "<type \'module\'>"
}