Do not create an extra GradientTape in custom_gradient/recompute_grad. The extra tape leads the tf.function gradient code to believe that the user intends to compute higher-order derivatives. This requires generating a forward function with all possible side outputs, which is expensive. It also does not work well with control flow and causes the newly added test (testRecomputeGradWithNestedFunctionAndWhileLoop) to fail.
Instead, we create a VariableWatcher object that keeps track of the trainable
variables that have been accessed.
PiperOrigin-RevId: 307157027
Change-Id: Ifd628b421dc725ad2366af2f6f63cf52dd1511e9
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d0a1b91..30cc424 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -331,6 +331,7 @@
"//tensorflow/python:embedding_ops",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:memory_checker",
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
"//tensorflow/python:random_ops",
@@ -662,6 +663,7 @@
deps = [
":backprop",
":context",
+ ":tape",
":test",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 4259224..b28aaa3 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -35,6 +35,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.framework.memory_checker import MemoryChecker
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1532,6 +1533,39 @@
self.assertIn('gradient_tape/my_scope/', op.name)
self.assertEqual(num_sin_ops_found, 2)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
+
+ @custom_gradient.recompute_grad
+ @def_function.function
+ def outer(x):
+
+ @def_function.function
+ def middle(y):
+
+ @def_function.function
+ def inner(z):
+ return z + 1
+
+ i = constant_op.constant(0.0)
+ c = lambda y, i: i < 10.
+ b = lambda y, i: (inner(y), i + 1.0)
+ y, i = control_flow_ops.while_loop(c, b, [y, i])
+
+ return y
+
+ return middle(x)
+
+ with MemoryChecker() as memory_checker:
+ for _ in range(5):
+ x = variables.Variable(1.0, name='x')
+ with backprop.GradientTape():
+ y = outer(x)
+ self.assertAllEqual(y, 11.0)
+
+ memory_checker.report()
+ memory_checker.assert_no_leak_if_all_possibly_except_one()
+
class JacobianTest(test.TestCase):
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 398c8aa..b4a16f4 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -331,6 +331,22 @@
// appended to `tensors`.
PyObject* TFE_Py_PackJVPs(PyObject* tensors);
+// Variable Watcher methods.
+
+// Creates a new variable watcher and adds it to the set of active variable
+// watchers.
+PyObject* TFE_Py_VariableWatcherNew();
+
+// Removes the passed variable watcher from the set of active variable watchers.
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
+
+// Notifies all variable watchers that a variable has been accessed.
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
+
+// Returns all variables watched by the given variable_watcher in the order
+// those variables were created.
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
+
// Returns an EagerTensor of dimension [len(`tensors`)] containing
// the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
// TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 0091cf2..3100c15 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1375,38 +1375,24 @@
return result;
}
-class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
- PyTapeTensor> {
+// Keeps track of all variables that have been accessed during execution.
+class VariableWatcher {
public:
- explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
- PyTapeTensor>(persistent),
- watch_accessed_variables_(watch_accessed_variables) {}
+ VariableWatcher() {}
- virtual ~GradientTape() {
+ ~VariableWatcher() {
for (const IdAndVariable& v : watched_variables_) {
Py_DECREF(v.variable);
}
}
- void VariableAccessed(PyObject* v) {
- if (watch_accessed_variables_) {
- WatchVariable(v);
- }
- }
-
- void WatchVariable(PyObject* v) {
+ tensorflow::int64 WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
- return;
+ return -1;
}
tensorflow::int64 id = FastTensorId(handle.get());
- if (!PyErr_Occurred()) {
- this->Watch(id);
- }
-
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
@@ -1415,6 +1401,8 @@
// variable.
Py_INCREF(v);
}
+
+ return id;
}
PyObject* GetVariablesAsPyTuple() {
@@ -1445,12 +1433,45 @@
}
};
- bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
TF_GUARDED_BY(watched_variables_mu_);
};
+class GradientTape
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit GradientTape(bool persistent, bool watch_accessed_variables)
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
+ watch_accessed_variables_(watch_accessed_variables) {}
+
+ virtual ~GradientTape() {}
+
+ void VariableAccessed(PyObject* v) {
+ if (watch_accessed_variables_) {
+ WatchVariable(v);
+ }
+ }
+
+ void WatchVariable(PyObject* v) {
+ tensorflow::int64 id = variable_watcher_.WatchVariable(v);
+
+ if (!PyErr_Occurred()) {
+ this->Watch(id);
+ }
+ }
+
+ PyObject* GetVariablesAsPyTuple() {
+ return variable_watcher_.GetVariablesAsPyTuple();
+ }
+
+ private:
+ bool watch_accessed_variables_;
+ VariableWatcher variable_watcher_;
+};
+
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
PyTapeTensor>
ForwardAccumulator;
@@ -1535,6 +1556,41 @@
"TFE_Py_ForwardAccumulator objects", /* tp_doc */
};
+typedef struct {
+ PyObject_HEAD
+ /* Type-specific fields go here. */
+ VariableWatcher* variable_watcher;
+} TFE_Py_VariableWatcher;
+
+static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
+ delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+ ->variable_watcher;
+ Py_TYPE(variable_watcher)->tp_free(variable_watcher);
+}
+
+static PyTypeObject TFE_Py_VariableWatcher_Type = {
+ PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
+ sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
+ 0, /* tp_print */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_reserved */
+ nullptr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ "TFE_Py_VariableWatcher objects", /* tp_doc */
+};
+
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
@@ -1548,6 +1604,18 @@
return tape_set.get();
}
+tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
+GetVariableWatcherSet() {
+ thread_local std::unique_ptr<
+ tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
+ variable_watcher_set = nullptr;
+ if (variable_watcher_set == nullptr) {
+ variable_watcher_set.reset(
+ new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
+ }
+ return variable_watcher_set.get();
+}
+
// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
@@ -1670,6 +1738,16 @@
}
};
+class SafeVariableWatcherSet
+ : public SafeSetCopy<
+ tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
+ public:
+ SafeVariableWatcherSet()
+ : SafeSetCopy<
+ tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
+ *GetVariableWatcherSet()) {}
+};
+
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
@@ -2037,6 +2115,36 @@
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
+PyObject* TFE_Py_VariableWatcherNew() {
+ TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
+ if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
+ TFE_Py_VariableWatcher* variable_watcher =
+ PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
+ variable_watcher->variable_watcher = new VariableWatcher();
+ Py_INCREF(variable_watcher);
+ GetVariableWatcherSet()->insert(variable_watcher);
+ return reinterpret_cast<PyObject*>(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
+ auto* stack = GetVariableWatcherSet();
+ stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
+ // We kept a reference to the variable watcher in the set to ensure it
+ // wouldn't get deleted under us; cleaning it up here.
+ Py_DECREF(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
+ for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
+ variable_watcher->variable_watcher->WatchVariable(variable);
+ }
+}
+
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
+ return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+ ->variable_watcher->GetVariablesAsPyTuple();
+}
+
namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -3086,6 +3194,7 @@
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeVariableAccessed(input);
+ TFE_Py_VariableWatcherVariableAccessed(input);
}
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 2ecac8b..d1e8e52 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -58,6 +58,36 @@
pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
+class VariableWatcher(object):
+ """A scope that tracks all trainable variable accesses within it.
+
+ This explicitly ignores variables that are not marked as trainable.
+
+ Sample usage:
+
+ var = tf.Variable(0.0)
+ with VariableWatcher() as variable_watcher:
+ var.assign_add(1.0)
+
+ assert variable_watcher.watched_variables() == (var,)
+ """
+
+ def __init__(self):
+ self._variable_watcher = None
+
+ def __enter__(self):
+ self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
+ return self
+
+ def __exit__(self, typ, value, traceback):
+ pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)
+
+ def watched_variables(self):
+ """Returns a tuple of variables accessed under this scope."""
+ return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
+ self._variable_watcher)
+
+
def watch_variable(tape, variable):
"""Marks this variable to be watched by the given tape."""
strategy, context = (
@@ -68,6 +98,7 @@
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access
+ pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def variable_accessed(variable):
@@ -84,6 +115,7 @@
variables = strategy.experimental_local_results(variable)
for var in variables:
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+ pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def variables_accessed(variables):
@@ -107,6 +139,7 @@
for var in accessed:
pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+ pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
def pop_tape(tape):
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 48d3b8a..cf49aa2 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -21,6 +21,7 @@
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@
# Importing nn_grad for the registration functions.
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import variables
@custom_gradient.custom_gradient
@@ -166,5 +168,48 @@
self.assertAllEqual(g, 1.0)
+class VariableWatcherTest(test.TestCase):
+
+ def testBasic(self):
+ var1 = variables.Variable(0.0)
+ var2 = variables.Variable(1.0)
+ with tape.VariableWatcher() as variable_watcher:
+ var1.assign_add(1.0)
+ var2.assign_add(2.0)
+
+ self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+ def testNonTrainableVariables(self):
+ var1 = variables.Variable(0.0)
+ var2 = variables.Variable(1.0, trainable=False)
+ with tape.VariableWatcher() as variable_watcher:
+ var1.assign_add(1.0)
+ var2.assign_add(2.0)
+
+ self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
+
+ def testMultipleScopes(self):
+ var1 = variables.Variable(0.0)
+ var2 = variables.Variable(1.0)
+ with tape.VariableWatcher() as variable_watcher1:
+ var1.assign_add(1.0)
+ with tape.VariableWatcher() as variable_watcher2:
+ var2.assign_add(2.0)
+
+ # variable_watcher1 should see both vars and variable_watcher2 only sees
+ # var2
+ self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
+ self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
+
+ def testCreateVariables(self):
+ with tape.VariableWatcher() as variable_watcher:
+ var1 = variables.Variable(0.0)
+ var2 = variables.Variable(1.0)
+ var1.assign_add(1.0)
+ var2.assign_add(2.0)
+
+ self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 785813d..4040a4d 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -315,7 +315,7 @@
v.ref() for v in current_var_scope.global_variables() +
current_var_scope.local_variables()
])
- with backprop.GradientTape() as tape:
+ with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args)
after_vars = set([
v.ref() for v in current_var_scope.global_variables() +
@@ -332,8 +332,9 @@
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
inputs = args
- variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
- ]) - frozenset(v.ref() for v in inputs)
+ variables_in_tape = frozenset([
+ v.ref() for v in variable_watcher.watched_variables()
+ ]) - frozenset(v.ref() for v in inputs)
variables_in_subgraph = frozenset([
v.ref()
for v in get_dependent_variables(input_ops=inputs, output_ops=result)
@@ -405,14 +406,14 @@
def _eager_mode_decorator(f, args, kwargs):
"""Implement custom gradient decorator for eager mode."""
- with backprop.GradientTape() as tape:
+ with tape_lib.VariableWatcher() as variable_watcher:
result, grad_fn = f(*args, **kwargs)
all_inputs = list(args) + list(kwargs.values())
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
variables = [
v.deref() # pylint: disable=g-complex-comprehension
- for v in set(v.ref() for v in tape.watched_variables())
+ for v in set(v.ref() for v in variable_watcher.watched_variables())
if all(v.deref() is not i for i in all_inputs)
]
grad_argspec = tf_inspect.getfullargspec(grad_fn)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 64af900..26683c3 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -665,6 +665,23 @@
return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
});
+ // TFE_Py_VariableWatcher logic.
+ m.def("TFE_Py_VariableWatcherNew",
+ []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
+ m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
+ TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
+ });
+ m.def("TFE_Py_VariableWatcherVariableAccessed",
+ [](const py::handle& variable) {
+ TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
+ });
+ m.def("TFE_Py_VariableWatcherWatchedVariables",
+ [](const py::handle& variable_watcher) {
+ return tensorflow::PyoOrThrow(
+ TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
+ });
+
+ // TFE_Py_ForwardAccumulator logic.
m.def("TFE_Py_ForwardAccumulatorNew", []() {
return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
});
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index d7da4b7..40ada68 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -173,6 +173,10 @@
TFE_Py_FastPathExecute_C
TFE_Py_RecordGradient
TFE_Py_TapeWatchedVariables
+TFE_Py_VariableWatcherNew
+TFE_Py_VariableWatcherRemove
+TFE_Py_VariableWatcherVariableAccessed
+TFE_Py_VariableWatcherWatchedVariables
TFE_Py_ForwardAccumulatorNew
TFE_Py_ForwardAccumulatorSetAdd
TFE_Py_ForwardAccumulatorSetRemove
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index c8df1e3..229cf48 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -111,6 +111,7 @@
"//tensorflow/python/distribute:distribute_test_lib_pip",
"//tensorflow/python:loss_scale",
"//tensorflow/python:loss_scale_optimizer",
+ "//tensorflow/python:memory_checker",
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:util_example_parser_configuration",
"//tensorflow/python/data/benchmarks:benchmark_base",