Added test for SavedModelCLI. Sorted the functions by names and concrete function ID. Skipped TF 1.X models while showing Polymorphic function.
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index eb5e4a1..fa31904 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -170,23 +170,31 @@
Args:
saved_model_dir: Directory containing the SavedModel to inspect.
"""
+ meta_graphs = saved_model_utils.read_saved_model(saved_model_dir).meta_graphs
+ has_object_graph_def = False
+
+ for meta_graph_def in meta_graphs:
+ has_object_graph_def |= meta_graph_def.HasField("object_graph_def")
+ if not has_object_graph_def:
+ return
with ops_lib.Graph().as_default():
trackable_object = load.load(saved_model_dir)
- print('\nDefined Functions:')
+ print('\nDefined Functions:', end="")
functions = save._AugmentedGraphView(
trackable_object).list_functions(trackable_object)
- for name, function in functions.items():
- print(' Function Name: \'%s\'' % name)
- for index, concrete_functions in enumerate(
- function._list_all_concrete_functions_for_serialization(), 1):
- args, kwargs = concrete_functions.structured_input_signature
+ functions = sorted(functions.items(), key=lambda x: x[0])
+ for name, function in functions:
+ print('\n Function Name: \'%s\'' % name)
+ concrete_functions = function._list_all_concrete_functions_for_serialization()
+ concrete_functions = sorted(concrete_functions, key=lambda x: x.name)
+ for index, concrete_function in enumerate(concrete_functions, 1):
+ args, kwargs = concrete_function.structured_input_signature
print(' Option #%d' % index)
print(' Callable with:')
_print_args(args, indent=4)
if kwargs:
_print_args(kwargs, "Named Argument", indent=4)
- print()
def _print_args(arguments, argument_type="Argument", indent=0):
@@ -199,7 +207,7 @@
"""
indent_str = ' ' * indent
- def _may_be_add_quotes(value):
+ def _maybe_add_quotes(value):
is_quotes = '\'' * isinstance(value, str)
return is_quotes + str(value) + is_quotes
@@ -215,13 +223,13 @@
in_print(' DType: %s' % type(element).__name__)
in_print(' Value: [', end='')
for value in element:
- print('%s' % _may_be_add_quotes(value), end=', ')
+ print('%s' % _maybe_add_quotes(value), end=', ')
print('\b\b]')
elif isinstance(element, dict):
in_print(' DType: %s' % type(element).__name__)
in_print(' Value: {', end='')
for (key, value) in element.items():
- print('\'%s\': %s' % (str(key), _may_be_add_quotes(value)), end=', ')
+ print('\'%s\': %s' % (str(key), _maybe_add_quotes(value)), end=', ')
print('\b\b}')
else:
in_print(' DType: %s' % type(element).__name__)
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index eedc893..3d34edc 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -31,12 +31,16 @@
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.debug.wrappers import local_cli_wrapper
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import test
-from tensorflow.python.tools import saved_model_cli
-
+from tensorflow.python.saved_model import save
+import saved_model_cli
+from tensorflow.python.training.tracking import util
SAVED_MODEL_PATH = ('cc/saved_model/testdata/half_plus_two/00000123')
-
@contextlib.contextmanager
def captured_output():
new_out, new_err = StringIO(), StringIO()
@@ -47,6 +51,22 @@
finally:
sys.stdout, sys.stderr = old_out, old_err
+class DummyModel(util.Checkpoint):
+ @def_function.function
+ def func1(self, a, b, c):
+ if c:
+ return a + b
+ else:
+ return a * b
+ @def_function.function(
+ input_signature=[
+ tensor_spec.TensorSpec(shape=(2, 2),
+ dtype=dtypes.float32)])
+ def func2(self, x):
+ return x + 2
+ @def_function.function
+ def __call__(self, y, c=7):
+ return y + 2 * c
class SavedModelCLITestCase(test.TestCase):
@@ -57,6 +77,8 @@
with captured_output() as (out, err):
saved_model_cli.show(args)
output = out.getvalue().strip()
+ with open("out.txt", "w") as f:
+ f.write(output)
# pylint: disable=line-too-long
exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
@@ -141,7 +163,85 @@
self.maxDiff = None # Produce a useful error msg if the comparison fails
self.assertMultiLineEqual(output, exp_out)
self.assertEqual(err.getvalue().strip(), '')
+ def testShowAllWithConcreteFunctions(self):
+
+ temp_dir = self.get_temp_dir()
+ trackable_object = DummyModel()
+ trackable_object.func1(
+ constant_op.constant(5),
+ constant_op.constant(9),
+ True)
+ trackable_object.func1(constant_op.constant(5), constant_op.constant(9), False)
+ trackable_object(constant_op.constant(5))
+ save.save(trackable_object, temp_dir)
+ self.parser = saved_model_cli.create_parser()
+ args = self.parser.parse_args(['show', '--dir', temp_dir, '--all'])
+ with captured_output() as (out, err):
+ saved_model_cli.show(args)
+ output = out.getvalue().strip()
+ exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
+signature_def['__saved_model_init_op']:
+ The given SavedModel SignatureDef contains the following input(s):
+ The given SavedModel SignatureDef contains the following output(s):
+ outputs['__saved_model_init_op'] tensor_info:
+ dtype: DT_INVALID
+ shape: unknown_rank
+ name: NoOp
+ Method name is:
+
+signature_def['serving_default']:
+ The given SavedModel SignatureDef contains the following input(s):
+ inputs['x'] tensor_info:
+ dtype: DT_FLOAT
+ shape: (2, 2)
+ name: serving_default_x:0
+ The given SavedModel SignatureDef contains the following output(s):
+ outputs['output_0'] tensor_info:
+ dtype: DT_FLOAT
+ shape: (2, 2)
+ name: PartitionedCall:0
+ Method name is: tensorflow/serving/predict
+
+Defined Functions:
+ Function Name: '__call__'
+ Option #1
+ Callable with:
+ Argument #1
+ y: TensorSpec(shape=(), dtype=tf.int32, name='y')
+ Argument #2
+ DType: int
+ Value: 7
+
+ Function Name: 'func1'
+ Option #1
+ Callable with:
+ Argument #1
+ a: TensorSpec(shape=(), dtype=tf.int32, name='a')
+ Argument #2
+ b: TensorSpec(shape=(), dtype=tf.int32, name='b')
+ Argument #3
+ DType: bool
+ Value: False
+ Option #2
+ Callable with:
+ Argument #1
+ a: TensorSpec(shape=(), dtype=tf.int32, name='a')
+ Argument #2
+ b: TensorSpec(shape=(), dtype=tf.int32, name='b')
+ Argument #3
+ DType: bool
+ Value: True
+
+ Function Name: 'func2'
+ Option #1
+ Callable with:
+ Argument #1
+ x: TensorSpec(shape=(2, 2), dtype=tf.float32, name='x')
+""".strip() # pylint: enable=line-too-long
+ self.maxDiff = None # Produce a useful error msg if the comparison fails
+ self.assertMultiLineEqual(output, exp_out)
+ self.assertEqual(err.getvalue().strip(), '')
def testShowCommandTags(self):
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
self.parser = saved_model_cli.create_parser()