Make @tf.function-wrapped functions pickleable.
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 6f50325..1ed603c 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -422,6 +422,19 @@
self._input_signature = input_signature
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
+ def __getstate__(self):
+ """Custom pickling, to omit unpickleable objects."""
+ result = self.__dict__.copy()
+ del result["_lock"]
+ del result["_descriptor_cache"]
+ return result
+
+ def __setstate__(self, state):
+ """Restore from pickled state."""
+ self.__dict__ = state
+ self._lock = threading.Lock()
+ self._descriptor_cache = weakref.WeakKeyDictionary()
+
def _defun_with_scope(self, scope):
"""Creates a defun wrapped inside a variable creator scope."""
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 998350f..6a85821 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -19,6 +19,7 @@
import functools
import itertools
+import pickle
import re
import weakref
@@ -67,6 +68,8 @@
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
return {'loss': loss}
+def undecorated_function(x):
+ return x * 3.
class _HasDecoratedMethod(object):
@@ -747,6 +750,41 @@
# If the graph is deleted, then an exception is raised on reading `captures`
self.assertEmpty(graph.captures)
+ @parameterized.parameters(*itertools.product(
+ (None, (tensor_spec.TensorSpec([]),)), # input_signature
+ (True, False), # autograph
+ (None, converter.Feature.ALL), # autograph_options
+ (None, 'foo.bar'), # implements
+ (None, True, False), # relax_shapes
+ ))
+ def test_pickle(self, input_signature, autograph, autograph_options, implements,
+ relax_shapes):
+ """@function objects can be pickled and unpickled."""
+ # Can't pickle functions in __main__:
+ from tensorflow.python.eager.def_function_test import undecorated_function
+ original_py_function = undecorated_function
+
+ func = def_function.function(
+ func=original_py_function,
+ input_signature=input_signature,
+ autograph=autograph,
+ experimental_implements=implements,
+ experimental_autograph_options=autograph_options,
+ experimental_relax_shapes=relax_shapes,
+ )
+
+ cloned = pickle.loads(pickle.dumps(func))
+
+ self.assertEqual(func._name, cloned._name)
+ self.assertEqual(input_signature, cloned._input_signature)
+ self.assertEqual(autograph, cloned._autograph)
+ self.assertEqual(implements, cloned._implements)
+ self.assertEqual(autograph_options, cloned._experimental_autograph_options)
+ self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
+
+ x = array_ops.ones([])
+ self.assertEqual(self.evaluate(cloned(x)),
+ self.evaluate(func(x)))
if __name__ == '__main__':
ops.enable_eager_execution()