blob: 3e43ab92d044a87149e49392cfa053b8130c96dd [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Authoring tool package for TFLite compatibility.

WARNING: The package is experimental and subject to change.

This package provides a way to check TFLite compatibility at model authoring
time.

Example:
  @lite.authoring.compatible
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[], dtype=tf.float32)
  ])
  def f(x):
    return tf.cosh(x)

  f(1.0)

  > CompatibilityWarning: op 'tf.Cosh' require(s) "Select TF Ops" for model
  conversion for TensorFlow Lite.
"""
import functools
import sys
# pylint: disable=g-direct-tensorflow-import
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
from tensorflow.python.platform import tf_logging as logging
# Prefix of the line in a legacy ConverterError message whose remainder lists
# the ops that need a custom operator implementation.
_CUSTOM_OPS_HDR = "Custom ops: "
# Prefix of the line in a legacy ConverterError message whose remainder lists
# the ops that need "Select TF Ops".
_TF_OPS_HDR = "TF Select ops: "
class CompatibilityError(Exception):
  """Raised when an error occurs with TFLite compatibility.

  The compatibility decorator in this module raises it when it was created
  with `raise_exception=True` and the check logged at least one issue.
  """
class _Compatible:
  """A decorator to check TFLite compatibility.

  The decorated target must be callable and expose `get_concrete_function`
  (presumably a `tf.function` -- confirm against callers). The first call
  traces the target, runs the result through the TFLite converter, and logs
  and records any custom-op or "Select TF Ops" requirement the converter
  reports.
  """

  def __init__(self,
               target,
               raise_exception=False,
               converter_target_spec=None,
               converter_allow_custom_ops=None,
               debug=False):
    """Initialize the decorator object.

    Here is the description of the object variables.
    - _func : decorated function.
    - _obj_func : the decorated function bound through the descriptor
        protocol; for a method this is what provides the `self` instance as
        the first argument.
    - _verified : whether the compatibility is checked or not.
    - _log_messages : messages recorded by the compatibility check.

    Args:
      target: decorated function.
      raise_exception: to raise an exception on compatibility issues. User
        need to use get_compatibility_log() to check details.
      converter_target_spec: target_spec of TFLite converter parameter.
      converter_allow_custom_ops: allow_custom_ops of TFLite converter
        parameter.
      debug: to dump execution details of decorated function.
    """
    # Copies the target's metadata (and its __dict__) onto this object, so
    # the private state below must be assigned afterwards.
    functools.update_wrapper(self, target)
    self._func = target
    self._obj_func = None
    self._verified = False
    self._log_messages = []
    self._raise_exception = raise_exception
    self._converter_target_spec = converter_target_spec
    self._converter_allow_custom_ops = converter_allow_custom_ops
    self._debug = debug

  def __get__(self, instance, cls):
    """A Python descriptor interface.

    Binds the decorated function to `instance` so a decorated method gets
    its `self`.

    Note: the binding is stored on this decorator object, which is shared by
    every instance of the owning class, and the decorator itself is returned.
    The stored binding therefore reflects the most recent attribute access.
    """
    self._obj_func = self._func.__get__(instance, cls)
    return self

  def _get_func(self):
    """Returns decorated function object.

    For a class method, use self._obj_func to provide `self` instance.
    """
    if self._obj_func is not None:
      return self._obj_func
    return self._func

  def __call__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Calls decorated function object.

    Also verifies if the function is compatible with TFLite. Verification is
    skipped once the converter has been run.

    Returns:
      A execution result of the decorated function.

    Raises:
      CompatibilityError: if `raise_exception` was set and the check recorded
        compatibility issues.
    """
    if self._debug:
      call_args = [repr(arg) for arg in args]
      call_args.extend(f"{k}={v!r}" for k, v in kwargs.items())
      signature = ", ".join(call_args)
      print(
          f"DEBUG: Calling {self._get_func().__name__}({signature})",
          file=sys.stderr)

    if not self._verified:
      concrete_func = self._get_func().get_concrete_function(*args, **kwargs)
      converter = lite.TFLiteConverterV2.from_concrete_functions(
          [concrete_func])
      # Set provided converter parameters
      if self._converter_target_spec is not None:
        converter.target_spec = self._converter_target_spec
      if self._converter_allow_custom_ops is not None:
        converter.allow_custom_ops = self._converter_allow_custom_ops
      try:
        converter.convert()
      except convert.ConverterError as err:
        self._decode_error(err)
      finally:
        # Mark as checked whether or not the conversion succeeded, so the
        # converter is run only once.
        self._verified = True

    return self._get_func()(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    """Returns a concrete function of the decorated function."""
    return self._get_func().get_concrete_function(*args, **kwargs)

  def _record_message(self, message):
    """Appends `message` to the compatibility log and logs it as a warning."""
    self._log_messages.append(message)
    logging.warning(message)

  @staticmethod
  def _format_custom_ops_message(ops):
    """Returns the message for ops that need a custom operator."""
    return (
        f"CompatibilityError: op '{ops}' is(are) not natively "
        "supported by TensorFlow Lite. You need to provide a custom "
        "operator. https://www.tensorflow.org/lite/guide/ops_custom")

  @staticmethod
  def _format_select_tf_ops_message(ops):
    """Returns the message for ops that need "Select TF Ops"."""
    return (
        f"CompatibilityWarning: op '{ops}' require(s) \"Select TF Ops\" "
        "for model conversion for TensorFlow Lite. "
        "https://www.tensorflow.org/lite/guide/ops_select")

  def _dump_error_details(self, ops, locations):
    """Dump the list of ops and locations.

    Args:
      ops: op names, one entry per reported error.
      locations: error locations, parallel to `ops`.
    """
    callsite = converter_error_data_pb2.ConverterErrorData.CALLSITELOC
    for op, location in zip(ops, locations):
      # The location type is the same for every frame of one location, so it
      # is tested once per op.
      if location.type == callsite:
        callstack = [
            f" - {call.source.filename}:{call.source.line}"
            for call in location.call
        ]
      else:
        callstack = [str(call) for call in location.call]
      callstack_dump = "\n".join(callstack)
      self._record_message(f"Op: {op}\n{callstack_dump}\n")

  def _decode_error_legacy(self, err):
    """Parses the given legacy ConverterError for OSS."""
    for line in str(err).splitlines():
      # Check custom op usage error.
      if line.startswith(_CUSTOM_OPS_HDR):
        custom_ops = line[len(_CUSTOM_OPS_HDR):]
        self._record_message(self._format_custom_ops_message(custom_ops))
      # Check TensorFlow op usage error.
      elif line.startswith(_TF_OPS_HDR):
        tf_ops = line[len(_TF_OPS_HDR):]
        self._record_message(self._format_select_tf_ops_message(tf_ops))

  def _decode_converter_error(self, err):
    """Parses the given ConverterError which has detailed error information."""
    error_data_cls = converter_error_data_pb2.ConverterErrorData
    custom_ops = []
    custom_ops_location = []
    tf_ops = []
    tf_ops_location = []
    for error_data in err.errors:
      # Check custom op usage error.
      if error_data.error_code == error_data_cls.ERROR_NEEDS_CUSTOM_OPS:
        custom_ops.append(error_data.operator.name)
        custom_ops_location.append(error_data.location)
      # Check TensorFlow op usage error.
      elif error_data.error_code == error_data_cls.ERROR_NEEDS_FLEX_OPS:
        tf_ops.append(error_data.operator.name)
        tf_ops_location.append(error_data.location)

    if custom_ops:
      custom_ops_str = ", ".join(sorted(custom_ops))
      self._record_message(self._format_custom_ops_message(custom_ops_str))
      self._dump_error_details(custom_ops, custom_ops_location)

    if tf_ops:
      tf_ops_str = ", ".join(sorted(tf_ops))
      self._record_message(self._format_select_tf_ops_message(tf_ops_str))
      self._dump_error_details(tf_ops, tf_ops_location)

  def _decode_error(self, err):
    """Parses the given ConverterError and generates compatibility warnings.

    Raises:
      CompatibilityError: if `raise_exception` was set and any message has
        been recorded.
    """
    if hasattr(err, "errors"):
      self._decode_converter_error(err)
    else:
      self._decode_error_legacy(err)

    if self._raise_exception and self._log_messages:
      raise CompatibilityError(f"CompatibilityException at {repr(self._func)}")

  def get_compatibility_log(self):
    """Returns list of compatibility log messages.

    WARNING: This method should only be used for unit tests.

    Returns:
      The list of log messages by the recent compatibility check.

    Raises:
      RuntimeError: when the compatibility was NOT checked.
    """
    if not self._verified:
      raise RuntimeError("target compatibility isn't verified yet")
    return self._log_messages
def compatible(target=None, **kwargs):
  """Wraps _Compatible to allow for deferred calling.

  Supports both decorator forms: bare (`@compatible`) and parameterized
  (`@compatible(raise_exception=True)`).

  Args:
    target: the function to decorate, or None when the decorator is used with
      keyword arguments only.
    **kwargs: forwarded unchanged to `_Compatible`.

  Returns:
    A `_Compatible` object wrapping `target`, or, when `target` is None, a
    one-argument decorator that builds one with `kwargs`.
  """
  if target is not None:
    return _Compatible(target, **kwargs)

  # Parameterized use: the function to decorate arrives in a second call.
  def wrapper(target):
    return _Compatible(target, **kwargs)

  return wrapper