Do not create an extra GradientTape in custom_gradient/recompute_grad. The
extra tape leads the tf.function gradient code to believe that the user
intends to compute higher-order derivatives, which requires generating a
forward function with all possible side outputs and is expensive. It also
does not work well with control flow and causes the added test to fail.

Instead, we create a VariableWatcher object that keeps track of variables that
have been accessed.
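
For reference, a condensed sketch of how the eager-mode custom_gradient path
records variable accesses with the new VariableWatcher scope instead of a
GradientTape (the helper name below is illustrative, not the exact code in
this change):

  from tensorflow.python.eager import tape as tape_lib

  def _eager_custom_gradient_sketch(f, *args, **kwargs):
    # Record which trainable variables are accessed while running f, without
    # creating a GradientTape (and thus without making tf.function generate
    # the expensive forward function needed for higher-order derivatives).
    with tape_lib.VariableWatcher() as variable_watcher:
      result, grad_fn = f(*args, **kwargs)
    all_inputs = list(args) + list(kwargs.values())
    # grad_fn only needs to return gradients for watched variables that are
    # not already passed in as inputs.
    variables = [
        v.deref()
        for v in set(v.ref() for v in variable_watcher.watched_variables())
        if all(v.deref() is not i for i in all_inputs)
    ]
    return result, grad_fn, variables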

PiperOrigin-RevId: 307157027
Change-Id: Ifd628b421dc725ad2366af2f6f63cf52dd1511e9
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d0a1b91..30cc424 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -331,6 +331,7 @@
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:memory_checker",
         "//tensorflow/python:nn_grad",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:random_ops",
@@ -662,6 +663,7 @@
     deps = [
         ":backprop",
         ":context",
+        ":tape",
         ":test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 4259224..b28aaa3 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -35,6 +35,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework.memory_checker import MemoryChecker
 from tensorflow.python.layers.pooling import max_pooling3d
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -1532,6 +1533,39 @@
         self.assertIn('gradient_tape/my_scope/', op.name)
     self.assertEqual(num_sin_ops_found, 2)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testRecomputeGradWithNestedFunctionAndWhileLoop(self):
+
+    @custom_gradient.recompute_grad
+    @def_function.function
+    def outer(x):
+
+      @def_function.function
+      def middle(y):
+
+        @def_function.function
+        def inner(z):
+          return z + 1
+
+        i = constant_op.constant(0.0)
+        c = lambda y, i: i < 10.
+        b = lambda y, i: (inner(y), i + 1.0)
+        y, i = control_flow_ops.while_loop(c, b, [y, i])
+
+        return y
+
+      return middle(x)
+
+    with MemoryChecker() as memory_checker:
+      for _ in range(5):
+        x = variables.Variable(1.0, name='x')
+        with backprop.GradientTape():
+          y = outer(x)
+          self.assertAllEqual(y, 11.0)
+
+    memory_checker.report()
+    memory_checker.assert_no_leak_if_all_possibly_except_one()
+
 
 class JacobianTest(test.TestCase):
 
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 398c8aa..b4a16f4 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -331,6 +331,22 @@
 //       appended to `tensors`.
 PyObject* TFE_Py_PackJVPs(PyObject* tensors);
 
+// Variable Watcher methods.
+
+// Creates a new variable watcher and adds it to the set of active variable
+// watchers.
+PyObject* TFE_Py_VariableWatcherNew();
+
+// Removes the passed variable watcher from the set of active variable watchers.
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher);
+
+// Notifies all variable watchers that a variable has been accessed.
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable);
+
+// Returns all variables watched by the given variable_watcher in the order
+// those variables were created.
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher);
+
 // Returns an EagerTensor of dimension [len(`tensors`)] containing
 // the `slice_dim`'th dimension of each tensor in `tensors`. In other words,
 // TFE_Py_TensorShapeSlice takes a slice of dimensions of tensors in
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 0091cf2..3100c15 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1375,38 +1375,24 @@
   return result;
 }
 
-class GradientTape
-    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
-                                             PyTapeTensor> {
+// Keeps track of all variables that have been accessed during execution.
+class VariableWatcher {
  public:
-  explicit GradientTape(bool persistent, bool watch_accessed_variables)
-      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
-                                        PyTapeTensor>(persistent),
-        watch_accessed_variables_(watch_accessed_variables) {}
+  VariableWatcher() {}
 
-  virtual ~GradientTape() {
+  ~VariableWatcher() {
     for (const IdAndVariable& v : watched_variables_) {
       Py_DECREF(v.variable);
     }
   }
 
-  void VariableAccessed(PyObject* v) {
-    if (watch_accessed_variables_) {
-      WatchVariable(v);
-    }
-  }
-
-  void WatchVariable(PyObject* v) {
+  tensorflow::int64 WatchVariable(PyObject* v) {
     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
-      return;
+      return -1;
     }
     tensorflow::int64 id = FastTensorId(handle.get());
 
-    if (!PyErr_Occurred()) {
-      this->Watch(id);
-    }
-
     tensorflow::mutex_lock l(watched_variables_mu_);
     auto insert_result = watched_variables_.emplace(id, v);
 
@@ -1415,6 +1401,8 @@
       // variable.
       Py_INCREF(v);
     }
+
+    return id;
   }
 
   PyObject* GetVariablesAsPyTuple() {
@@ -1445,12 +1433,45 @@
     }
   };
 
-  bool watch_accessed_variables_;
   tensorflow::mutex watched_variables_mu_;
   std::set<IdAndVariable, CompareById> watched_variables_
       TF_GUARDED_BY(watched_variables_mu_);
 };
 
+class GradientTape
+    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                             PyTapeTensor> {
+ public:
+  explicit GradientTape(bool persistent, bool watch_accessed_variables)
+      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                        PyTapeTensor>(persistent),
+        watch_accessed_variables_(watch_accessed_variables) {}
+
+  virtual ~GradientTape() {}
+
+  void VariableAccessed(PyObject* v) {
+    if (watch_accessed_variables_) {
+      WatchVariable(v);
+    }
+  }
+
+  void WatchVariable(PyObject* v) {
+    tensorflow::int64 id = variable_watcher_.WatchVariable(v);
+
+    if (!PyErr_Occurred()) {
+      this->Watch(id);
+    }
+  }
+
+  PyObject* GetVariablesAsPyTuple() {
+    return variable_watcher_.GetVariablesAsPyTuple();
+  }
+
+ private:
+  bool watch_accessed_variables_;
+  VariableWatcher variable_watcher_;
+};
+
 typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
                                               PyTapeTensor>
     ForwardAccumulator;
@@ -1535,6 +1556,41 @@
     "TFE_Py_ForwardAccumulator objects",                    /* tp_doc */
 };
 
+typedef struct {
+  PyObject_HEAD
+      /* Type-specific fields go here. */
+      VariableWatcher* variable_watcher;
+} TFE_Py_VariableWatcher;
+
+static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
+  delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+      ->variable_watcher;
+  Py_TYPE(variable_watcher)->tp_free(variable_watcher);
+}
+
+static PyTypeObject TFE_Py_VariableWatcher_Type = {
+    PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
+    sizeof(TFE_Py_VariableWatcher),                          /* tp_basicsize */
+    0,                                                       /* tp_itemsize */
+    &TFE_Py_VariableWatcher_Delete,                          /* tp_dealloc */
+    0,                                                       /* tp_print */
+    nullptr,                                                 /* tp_getattr */
+    nullptr,                                                 /* tp_setattr */
+    nullptr,                                                 /* tp_reserved */
+    nullptr,                                                 /* tp_repr */
+    nullptr,                                                 /* tp_as_number */
+    nullptr,                          /* tp_as_sequence */
+    nullptr,                          /* tp_as_mapping */
+    nullptr,                          /* tp_hash  */
+    nullptr,                          /* tp_call */
+    nullptr,                          /* tp_str */
+    nullptr,                          /* tp_getattro */
+    nullptr,                          /* tp_setattro */
+    nullptr,                          /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT,               /* tp_flags */
+    "TFE_Py_VariableWatcher objects", /* tp_doc */
+};
+
 // Note: in the current design no mutex is needed here because of the python
 // GIL, which is always held when any TFE_Py_* methods are called. We should
 // revisit this if/when decide to not hold the GIL while manipulating the tape
@@ -1548,6 +1604,18 @@
   return tape_set.get();
 }
 
+tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
+GetVariableWatcherSet() {
+  thread_local std::unique_ptr<
+      tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
+      variable_watcher_set = nullptr;
+  if (variable_watcher_set == nullptr) {
+    variable_watcher_set.reset(
+        new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
+  }
+  return variable_watcher_set.get();
+}
+
 // A linked hash set, where iteration is in insertion order.
 //
 // Nested accumulators rely on op recording happening in insertion order, so an
@@ -1670,6 +1738,16 @@
   }
 };
 
+class SafeVariableWatcherSet
+    : public SafeSetCopy<
+          tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
+ public:
+  SafeVariableWatcherSet()
+      : SafeSetCopy<
+            tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
+            *GetVariableWatcherSet()) {}
+};
+
 bool* ThreadTapeIsStopped() {
   thread_local bool thread_tape_is_stopped{false};
   return &thread_tape_is_stopped;
@@ -2037,6 +2115,36 @@
   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
 }
 
+PyObject* TFE_Py_VariableWatcherNew() {
+  TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
+  if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
+  TFE_Py_VariableWatcher* variable_watcher =
+      PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
+  variable_watcher->variable_watcher = new VariableWatcher();
+  Py_INCREF(variable_watcher);
+  GetVariableWatcherSet()->insert(variable_watcher);
+  return reinterpret_cast<PyObject*>(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
+  auto* stack = GetVariableWatcherSet();
+  stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
+  // We kept a reference to the variable watcher in the set to ensure it
+  // wouldn't get deleted under us; cleaning it up here.
+  Py_DECREF(variable_watcher);
+}
+
+void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
+  for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
+    variable_watcher->variable_watcher->WatchVariable(variable);
+  }
+}
+
+PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
+  return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
+      ->variable_watcher->GetVariablesAsPyTuple();
+}
+
 namespace {
 std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -3086,6 +3194,7 @@
       PyObject_GetAttrString(input, "_trainable"));
   if (trainable.get() == Py_False) return;
   TFE_Py_TapeVariableAccessed(input);
+  TFE_Py_VariableWatcherVariableAccessed(input);
 }
 
 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 2ecac8b..d1e8e52 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -58,6 +58,36 @@
   pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access
 
 
+class VariableWatcher(object):
+  """A scope that tracks all trainable variable accesses within it.
+
+  This explicitly ignores variables that are not marked as trainable.
+
+  Sample usage:
+
+  var = tf.Variable(0.0)
+  with VariableWatcher() as variable_watcher:
+    var.assign_add(1.0)
+
+  assert variable_watcher.watched_variables() == (var,)
+  """
+
+  def __init__(self):
+    self._variable_watcher = None
+
+  def __enter__(self):
+    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
+    return self
+
+  def __exit__(self, typ, value, traceback):
+    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)
+
+  def watched_variables(self):
+    """Returns a tuple of variables accessed under this scope."""
+    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
+        self._variable_watcher)
+
+
 def watch_variable(tape, variable):
   """Marks this variable to be watched by the given tape."""
   strategy, context = (
@@ -68,6 +98,7 @@
     variables = strategy.experimental_local_results(variable)
   for var in variables:
     pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def variable_accessed(variable):
@@ -84,6 +115,7 @@
     variables = strategy.experimental_local_results(variable)
   for var in variables:
     pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def variables_accessed(variables):
@@ -107,6 +139,7 @@
 
   for var in accessed:
     pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
+    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
 
 
 def pop_tape(tape):
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 48d3b8a..cf49aa2 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -21,6 +21,7 @@
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@
 # Importing nn_grad for the registration functions.
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import variables
 
 
 @custom_gradient.custom_gradient
@@ -166,5 +168,48 @@
     self.assertAllEqual(g, 1.0)
 
 
+class VariableWatcherTest(test.TestCase):
+
+  def testBasic(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0)
+    with tape.VariableWatcher() as variable_watcher:
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+  def testNonTrainableVariables(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0, trainable=False)
+    with tape.VariableWatcher() as variable_watcher:
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
+
+  def testMultipleScopes(self):
+    var1 = variables.Variable(0.0)
+    var2 = variables.Variable(1.0)
+    with tape.VariableWatcher() as variable_watcher1:
+      var1.assign_add(1.0)
+      with tape.VariableWatcher() as variable_watcher2:
+        var2.assign_add(2.0)
+
+    # variable_watcher1 should see both variables, while variable_watcher2
+    # should only see var2.
+    self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
+    self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
+
+  def testCreateVariables(self):
+    with tape.VariableWatcher() as variable_watcher:
+      var1 = variables.Variable(0.0)
+      var2 = variables.Variable(1.0)
+      var1.assign_add(1.0)
+      var2.assign_add(2.0)
+
+    self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 785813d..4040a4d 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -315,7 +315,7 @@
       v.ref() for v in current_var_scope.global_variables() +
       current_var_scope.local_variables()
   ])
-  with backprop.GradientTape() as tape:
+  with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args)
   after_vars = set([
       v.ref() for v in current_var_scope.global_variables() +
@@ -332,8 +332,9 @@
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
   inputs = args
-  variables_in_tape = frozenset([v.ref() for v in tape.watched_variables()
-                                ]) - frozenset(v.ref() for v in inputs)
+  variables_in_tape = frozenset([
+      v.ref() for v in variable_watcher.watched_variables()
+  ]) - frozenset(v.ref() for v in inputs)
   variables_in_subgraph = frozenset([
       v.ref()
       for v in get_dependent_variables(input_ops=inputs, output_ops=result)
@@ -405,14 +406,14 @@
 
 def _eager_mode_decorator(f, args, kwargs):
   """Implement custom gradient decorator for eager mode."""
-  with backprop.GradientTape() as tape:
+  with tape_lib.VariableWatcher() as variable_watcher:
     result, grad_fn = f(*args, **kwargs)
   all_inputs = list(args) + list(kwargs.values())
   # The variables that grad_fn needs to return gradients for are the set of
   # variables used that are *not* part of the inputs.
   variables = [
       v.deref()  # pylint: disable=g-complex-comprehension
-      for v in set(v.ref() for v in tape.watched_variables())
+      for v in set(v.ref() for v in variable_watcher.watched_variables())
       if all(v.deref() is not i for i in all_inputs)
   ]
   grad_argspec = tf_inspect.getfullargspec(grad_fn)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 64af900..26683c3 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -665,6 +665,23 @@
     return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
   });
 
+  // TFE_Py_VariableWatcher logic.
+  m.def("TFE_Py_VariableWatcherNew",
+        []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
+  m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
+    TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
+  });
+  m.def("TFE_Py_VariableWatcherVariableAccessed",
+        [](const py::handle& variable) {
+          TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
+        });
+  m.def("TFE_Py_VariableWatcherWatchedVariables",
+        [](const py::handle& variable_watcher) {
+          return tensorflow::PyoOrThrow(
+              TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
+        });
+
+  // TFE_Py_ForwardAccumulator logic.
   m.def("TFE_Py_ForwardAccumulatorNew", []() {
     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew());
   });
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index d7da4b7..40ada68 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -173,6 +173,10 @@
 TFE_Py_FastPathExecute_C
 TFE_Py_RecordGradient
 TFE_Py_TapeWatchedVariables
+TFE_Py_VariableWatcherNew
+TFE_Py_VariableWatcherRemove
+TFE_Py_VariableWatcherVariableAccessed
+TFE_Py_VariableWatcherWatchedVariables
 TFE_Py_ForwardAccumulatorNew
 TFE_Py_ForwardAccumulatorSetAdd
 TFE_Py_ForwardAccumulatorSetRemove
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index c8df1e3..229cf48 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -111,6 +111,7 @@
     "//tensorflow/python/distribute:distribute_test_lib_pip",
     "//tensorflow/python:loss_scale",
     "//tensorflow/python:loss_scale_optimizer",
+    "//tensorflow/python:memory_checker",
     "//tensorflow/python:meta_graph_testdata",
     "//tensorflow/python:util_example_parser_configuration",
     "//tensorflow/python/data/benchmarks:benchmark_base",