authoring: Cleanup authoring parameters
PiperOrigin-RevId: 389075797
Change-Id: I95102a57c0f1e69257a205d3e8393468dd31bfa5
diff --git a/tensorflow/lite/python/authoring/authoring.py b/tensorflow/lite/python/authoring/authoring.py
index 84f19dd..e5b4fe3 100644
--- a/tensorflow/lite/python/authoring/authoring.py
+++ b/tensorflow/lite/python/authoring/authoring.py
@@ -20,22 +20,26 @@
time.
Example:
- @lite.authoring.compatible
+ @tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
- tf.TensorSpec(shape=[], dtype=tf.float32)
+ tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
- f(1.0)
+ result = f(tf.constant([0.0]))
- > CompatibilityWarning: op 'tf.Cosh' requires "Select TF Ops" for model
- conversion for TensorFlow Lite.
+ > COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
+ > conversion for TensorFlow Lite.
+ > Op: tf.Cosh
+ > - tensorflow/python/framework/op_def_library.py:xxx
+ > - tensorflow/python/ops/gen_math_ops.py:xxx
+ > - simple_authoring.py:xxx
"""
import functools
-import sys
-# pylint: disable=g-direct-tensorflow-import
+
+# pylint: disable=g-import-not-at-top
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
@@ -54,15 +58,13 @@
class _Compatible:
- """A decorator to check TFLite compatibility."""
+ """A decorator class to check TFLite compatibility created by `lite.experimental.authoring.compatible`."""
def __init__(self,
target,
- print_logs=True,
- raise_exception=False,
converter_target_spec=None,
converter_allow_custom_ops=None,
- debug=False):
+ raise_exception=False):
"""Initialize the decorator object.
Here is the description of the object variables.
@@ -73,24 +75,20 @@
Args:
target: decorated function.
- print_logs: to print warning / error messages to stdout.
- raise_exception : to raise an exception on compatibility issues.
- User need to use get_compatibility_log() to check details.
converter_target_spec : target_spec of TFLite converter parameter.
converter_allow_custom_ops : allow_custom_ops of TFLite converter
parameter.
- debug: to dump execution details of decorated function.
+ raise_exception : to raise an exception on compatibility issues.
+ User need to use get_compatibility_log() to check details.
"""
functools.update_wrapper(self, target)
self._func = target
self._obj_func = None
self._verified = False
self._log_messages = []
- self._print_logs = print_logs
self._raise_exception = raise_exception
self._converter_target_spec = converter_target_spec
self._converter_allow_custom_ops = converter_allow_custom_ops
- self._debug = debug
def __get__(self, instance, cls):
"""A Python descriptor interface."""
@@ -115,13 +113,6 @@
Returns:
A execution result of the decorated function.
"""
- if self._debug:
- args_repr = [repr(a) for a in args]
- kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
- signature = ", ".join(args_repr + kwargs_repr)
- print(
- f"DEBUG: Calling {self._get_func().__name__}({signature})",
- file=sys.stderr)
if not self._verified:
model = self._get_func()
@@ -231,8 +222,7 @@
def _log(self, message):
"""Log and print authoring warning / error message."""
self._log_messages.append(message)
- if self._print_logs:
- print(message)
+ print(message)
def get_compatibility_log(self):
"""Returns list of compatibility log messages.
@@ -249,12 +239,21 @@
return self._log_messages
-def compatible(target=None, **kwargs):
- """Wraps _Compatible to allow for deferred calling."""
+def compatible(target=None, converter_target_spec=None, **kwargs):
+ """Wraps `tf.function` into a callable function with TFLite compatibility checking.
+
+ Args:
+ target: A `tf.function` to decorate.
+ converter_target_spec : target_spec of TFLite converter parameter.
+ **kwargs: The keyword arguments of the decorator class _Compatible.
+
+ Returns:
+ A callable object of `tf.lite.experimental.authoring._Compatible`.
+ """
if target is None:
def wrapper(target):
- return _Compatible(target, **kwargs)
+ return _Compatible(target, converter_target_spec, **kwargs)
return wrapper
else:
- return _Compatible(target, **kwargs)
+ return _Compatible(target, converter_target_spec, **kwargs)