Rolling forward RNN performance fix.

The caching_device option should only be used when running with tf.Session (graph mode), which is not the case under TF 2.0 eager execution; otherwise it results in stale value reads.

This change flips the default value to False and lets the RNN estimator opt in explicitly.
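
A rough opt-in illustration (assuming the public tf.keras.layers aliases
forward kwargs to these cell constructors); in graph mode, for example inside
an Estimator model_fn, a caller could request caching per cell:

    # Hypothetical sketch; `enable_caching_device` is the kwarg popped by the
    # cell constructors in this change and defaults to False.
    cell = tf.keras.layers.LSTMCell(128, enable_caching_device=True)
    layer = tf.keras.layers.RNN(cell)
    outputs = layer(inputs)  # variable reads are cached on the worker device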

PiperOrigin-RevId: 273413220
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 702b828..90be4cc 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -344,8 +344,8 @@
       aggregation: Indicates how a distributed variable will be aggregated.
         Accepted values are constants defined in the class
         `tf.VariableAggregation`.
-      **kwargs: Additional keyword arguments. Accepted values are `getter` and
-        `collections`.
+      **kwargs: Additional keyword arguments. Accepted values are `getter`,
+        `collections`, `experimental_autocast` and `caching_device`.
 
     Returns:
       The created variable. Usually either a `Variable` or `ResourceVariable`
@@ -362,13 +362,16 @@
       shape = ()
     # Validate optional keyword arguments.
     for kwarg in kwargs:
-      if kwarg not in ['getter', 'collections', 'experimental_autocast']:
+      if kwarg not in ['getter', 'collections', 'experimental_autocast',
+                       'caching_device']:
         raise TypeError('Unknown keyword argument:', kwarg)
     getter = kwargs.pop('getter', base_layer_utils.make_variable)
     collections_arg = kwargs.pop('collections', None)
     # 'experimental_autocast' can be set to False by the caller to indicate an
     # AutoCastVariable should never be created.
     autocast = kwargs.pop('experimental_autocast', True)
+    # See the docstring of tf.Variable for details about caching_device.
+    caching_device = kwargs.pop('caching_device', None)
 
     if dtype is None:
       dtype = self.dtype or backend.floatx()
@@ -414,6 +417,13 @@
       def getter(*args, **kwargs):  # pylint: disable=function-redefined
         variable = old_getter(*args, **kwargs)
         return autocast_variable.create_autocast_variable(variable)
+      # The caching_device does not work with the mixed precision API, so
+      # disable it if it is specified.
+      # TODO(b/142020079): Reenable it once the bug is fixed.
+      if caching_device is not None:
+        tf_logging.warn('`caching_device` does not work with mixed precision '
+                        'API. Ignoring the user-specified `caching_device`.')
+        caching_device = None
 
     variable = self._add_variable_with_custom_getter(
         name=name,
@@ -431,7 +441,8 @@
         use_resource=use_resource,
         collections=collections_arg,
         synchronization=synchronization,
-        aggregation=aggregation)
+        aggregation=aggregation,
+        caching_device=caching_device)
     backend.track_variable(variable)
 
     if regularizer is not None:
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 37ac80d..2813e98 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -24,6 +24,7 @@
 import numpy as np
 
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
@@ -36,6 +37,7 @@
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -1283,6 +1285,7 @@
                dropout=0.,
                recurrent_dropout=0.,
                **kwargs):
+    self._enable_caching_device = kwargs.pop('enable_caching_device', False)
     super(SimpleRNNCell, self).__init__(**kwargs)
     self.units = units
     self.activation = activations.get(activation)
@@ -1307,25 +1310,29 @@
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
+    default_caching_device = _caching_device(self)
     self.kernel = self.add_weight(
         shape=(input_shape[-1], self.units),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
     if self.use_bias:
       self.bias = self.add_weight(
           shape=(self.units,),
           name='bias',
           initializer=self.bias_initializer,
           regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+          constraint=self.bias_constraint,
+          caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -1709,6 +1716,7 @@
                implementation=1,
                reset_after=False,
                **kwargs):
+    self._enable_caching_device = kwargs.pop('enable_caching_device', False)
     super(GRUCell, self).__init__(**kwargs)
     self.units = units
     self.activation = activations.get(activation)
@@ -1737,18 +1745,21 @@
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
     input_dim = input_shape[-1]
+    default_caching_device = _caching_device(self)
     self.kernel = self.add_weight(
         shape=(input_dim, self.units * 3),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units * 3),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
 
     if self.use_bias:
       if not self.reset_after:
@@ -1763,7 +1774,8 @@
                                   name='bias',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
-                                  constraint=self.bias_constraint)
+                                  constraint=self.bias_constraint,
+                                  caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -1841,9 +1853,7 @@
         # biases: bias_z_i, bias_r_i, bias_h_i
         matrix_x = K.bias_add(matrix_x, input_bias)
 
-      x_z = matrix_x[:, :self.units]
-      x_r = matrix_x[:, self.units: 2 * self.units]
-      x_h = matrix_x[:, 2 * self.units:]
+      x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=-1)
 
       if 0. < self.recurrent_dropout < 1.:
         h_tm1 = h_tm1 * rec_dp_mask[0]
@@ -1857,14 +1867,14 @@
         # hidden state projected separately for update/reset and new
         matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
 
-      recurrent_z = matrix_inner[:, :self.units]
-      recurrent_r = matrix_inner[:, self.units:2 * self.units]
+      recurrent_z, recurrent_r, recurrent_h = array_ops.split(
+          matrix_inner, [self.units, self.units, -1], axis=-1)
 
       z = self.recurrent_activation(x_z + recurrent_z)
       r = self.recurrent_activation(x_r + recurrent_r)
 
       if self.reset_after:
-        recurrent_h = r * matrix_inner[:, 2 * self.units:]
+        recurrent_h = r * recurrent_h
       else:
         recurrent_h = K.dot(r * h_tm1,
                             self.recurrent_kernel[:, 2 * self.units:])
@@ -2259,6 +2269,7 @@
                recurrent_dropout=0.,
                implementation=1,
                **kwargs):
+    self._enable_caching_device = kwargs.pop('enable_caching_device', False)
     super(LSTMCell, self).__init__(**kwargs)
     self.units = units
     self.activation = activations.get(activation)
@@ -2292,19 +2303,22 @@
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
+    default_caching_device = _caching_device(self)
     input_dim = input_shape[-1]
     self.kernel = self.add_weight(
         shape=(input_dim, self.units * 4),
         name='kernel',
         initializer=self.kernel_initializer,
         regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+        constraint=self.kernel_constraint,
+        caching_device=default_caching_device)
     self.recurrent_kernel = self.add_weight(
         shape=(self.units, self.units * 4),
         name='recurrent_kernel',
         initializer=self.recurrent_initializer,
         regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
+        constraint=self.recurrent_constraint,
+        caching_device=default_caching_device)
 
     if self.use_bias:
       if self.unit_forget_bias:
@@ -2322,7 +2336,8 @@
           name='bias',
           initializer=bias_initializer,
           regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+          constraint=self.bias_constraint,
+          caching_device=default_caching_device)
     else:
       self.bias = None
     self.built = True
@@ -2911,3 +2926,50 @@
     return nest.map_structure(create_zeros, state_size)
   else:
     return create_zeros(state_size)
+
+
+def _caching_device(rnn_cell):
+  """Returns the caching device for the RNN variable.
+
+  This is useful for distributed training, when the variable is not located on
+  the same device as the training worker. Enabling the device cache lets the
+  worker read the variable once and cache it locally, rather than fetching it
+  from the remote device at every time step.
+
+  Note that this assumes the variables the cell needs at each time step keep
+  the same value throughout the forward pass and are only updated during
+  backprop. This holds for all the default cells (SimpleRNN, GRU, LSTM). If the
+  cell body relies on any variable that is updated at every time step, caching
+  will cause it to read a stale value.
+
+  Args:
+    rnn_cell: the RNN cell instance.
+  """
+  if context.executing_eagerly():
+    # caching_device is not supported in eager mode.
+    return None
+  if not getattr(rnn_cell, '_enable_caching_device', False):
+    return None
+  # Don't set a caching device when running in a loop, since it is possible that
+  # train steps could be wrapped in a tf.while_loop. In that scenario caching
+  # prevents forward computations in loop iterations from re-reading the
+  # updated weights.
+  if control_flow_util.IsInWhileLoop(ops.get_default_graph()):
+    logging.warn('Variable read device caching has been disabled because the '
+                 'RNN is inside a tf.while_loop context, where caching would '
+                 'cause stale value reads in the forward pass. This could '
+                 'slow down training due to duplicated variable reads. Please '
+                 'consider updating your code to remove the tf.while_loop if '
+                 'possible.')
+    return None
+  if rnn_cell._dtype_policy.should_cast_variables:
+    logging.warn('Variable read device caching has been disabled since it '
+                 'does not work with the mixed precision API. This is '
+                 'likely to cause a slowdown for RNN training due to '
+                 'duplicated variable reads at each time step, which can '
+                 'be significant when variables live on remote workers. '
+                 'Please consider disabling the mixed precision API if '
+                 'performance has been affected.')
+    return None
+  # Cache the value on the device that accesses the variable.
+  return lambda op: op.device
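
Illustrative note, not part of the patch: the callable returned above is passed
through add_weight to the variable constructor; in graph mode,
tf.compat.v1.get_variable accepts caching_device as either a device string or a
function of the variable's read op. A minimal standalone sketch:

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()  # caching_device is graph-mode only
    with tf.compat.v1.variable_scope('rnn'):
      kernel = tf.compat.v1.get_variable(
          'kernel', shape=[4, 4],
          # Cache each read on whichever device issues the read op, mirroring
          # the `lambda op: op.device` returned by _caching_device.
          caching_device=lambda op: op.device)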