Undo LSTM refactoring.

Reverts GRU and LSTM in recurrent_v2.py from runtime device dispatch
(control_flow_ops.execute_fn_for_device) back to Grappler-based
implementation selection: each layer builds a uniquely named function pair
via defun_with_attributes, calls the standard implementation, and registers
the CuDNN one. Eager execution dispatches directly on device placement and
mask padding.

PiperOrigin-RevId: 321060640
Change-Id: Ibd2e5aa7481869ead6ed8d00de1f51c487fa760b
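
For context, a minimal sketch of the dual-registration pattern this revert
restores. It is not part of the patch: cpu_impl/gpu_impl are hypothetical
stand-ins for standard_gru/gpu_gru_with_fallback, defun_with_attributes and
register are TF-internal APIs from tensorflow.python.eager.function, and the
attribute keys follow _FUNCTION_API_NAME_ATTRIBUTE ('api_implements') and
_FUNCTION_DEVICE_ATTRIBUTE ('api_preferred_device') as defined in
recurrent_v2.py.

    import uuid

    import tensorflow as tf
    from tensorflow.python.eager import function

    def cpu_impl(x):
      return x * 2.  # generic implementation, runs on any device

    def gpu_impl(x):
      return x * 2.  # stand-in for a CuDNN-backed implementation

    # One unique API name per layer instance, so Grappler can pair the two
    # variants even when several RNN layers live in the same graph.
    api_name = 'rnn_' + str(uuid.uuid4())
    defun_cpu = function.defun_with_attributes(
        cpu_impl,
        attributes={'api_implements': api_name,
                    'api_preferred_device': 'CPU'},
        autograph=False)
    defun_gpu = function.defun_with_attributes(
        gpu_impl,
        attributes={'api_implements': api_name,
                    'api_preferred_device': 'GPU'},
        autograph=False)

    @tf.function
    def run(x):
      out = defun_cpu(x)               # call the generic variant ...
      function.register(defun_gpu, x)  # ... and register the GPU variant
      return out

    # Grappler's implementation_selector pass can later swap in defun_gpu
    # when the function ends up placed on a GPU.
    print(run(tf.ones([2, 3])))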
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index 58eb0bb..33babb5 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -385,17 +385,6 @@
       else:
         logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
 
-    # The first two attributes are added to support TFLite use case.
-    supportive_attributes = {
-        'time_major': time_major,
-        'go_backwards': go_backwards,
-        _FUNCTION_API_NAME_ATTRIBUTE: 'gru_' + str(uuid.uuid4())
-    }
-    self.defun_gru_with_backend_selection = function.defun_with_attributes(
-        gru_with_backend_selection,
-        attributes=supportive_attributes,
-        autograph=False)
-
   def build(self, input_shape):
     super(GRU, self).build(input_shape)
 
@@ -478,7 +467,7 @@
     if dropout_mask is not None:
       inputs = inputs * dropout_mask[0]
 
-    gru_kwargs = {
+    gpu_gru_kwargs = {
         'inputs': inputs,
         'init_h': _read_variable_value(initial_state[0]),
         'kernel': _read_variable_value(self.cell.kernel),
@@ -487,11 +476,30 @@
         'mask': mask,
         'time_major': self.time_major,
         'go_backwards': self.go_backwards,
-        'sequence_lengths': sequence_lengths,
-        'zero_output_for_mask': self.zero_output_for_mask
+        'sequence_lengths': sequence_lengths
     }
-    (last_output, outputs, new_h,
-     runtime) = self.defun_gru_with_backend_selection(**gru_kwargs)
+    normal_gru_kwargs = gpu_gru_kwargs.copy()
+    normal_gru_kwargs.update({
+        'zero_output_for_mask': self.zero_output_for_mask,
+    })
+
+    if context.executing_eagerly():
+      device_type = _get_context_device_type()
+      can_use_gpu = (
+          # Either user specified GPU or unspecified but GPU is available.
+          (device_type == _GPU_DEVICE_NAME
+           or (device_type is None and context.num_gpus() > 0))
+          and
+          (mask is None or is_sequence_right_padded(mask, self.time_major)))
+      # Under eager context, check the device placement and prefer the
+      # GPU implementation when GPU is available.
+      if can_use_gpu:
+        last_output, outputs, new_h, runtime = gpu_gru(**gpu_gru_kwargs)
+      else:
+        last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs)
+    else:
+      last_output, outputs, new_h, runtime = gru_with_backend_selection(
+          **normal_gru_kwargs)
 
     states = [new_h]
     return last_output, outputs, runtime, states
@@ -758,14 +765,24 @@
         true_fn=input_right_padded,
         false_fn=input_not_right_padded)
 
-  # Chooses the implementation dynamicly based on the running device.
-  (last_output, outputs, new_h,
-   runtime) = control_flow_ops.execute_fn_for_device(
-       {
-           _CPU_DEVICE_NAME: lambda: standard_gru(**params),
-           _GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(**params)
-       }, lambda: standard_gru(**params))
+  # Each time a `tf.function` is called, we will give it a unique
+  # identifiable API name, so that Grappler won't get confused when it
+  # sees multiple GRU layers added into the same graph, and it will be
+  # able to pair up the different implementations across them.
+  api_name = 'gru_' + str(uuid.uuid4())
+  supportive_attribute = {
+      'time_major': time_major,
+      'go_backwards': go_backwards,
+  }
+  defun_standard_gru = _generate_defun_backend(
+      api_name, _CPU_DEVICE_NAME, standard_gru, supportive_attribute)
+  defun_gpu_gru = _generate_defun_backend(
+      api_name, _GPU_DEVICE_NAME, gpu_gru_with_fallback, supportive_attribute)
 
+  # Call the normal GRU impl and register the CuDNN impl function. Grappler
+  # will kick in during session execution to optimize the graph.
+  last_output, outputs, new_h, runtime = defun_standard_gru(**params)
+  function.register(defun_gpu_gru, **params)
   return last_output, outputs, new_h, runtime
 
 
@@ -1080,18 +1097,6 @@
       else:
         logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
 
-    # The first two attributes are added to support TFLite use case.
-    supportive_attributes = {
-        'time_major': time_major,
-        'go_backwards': go_backwards,
-        _FUNCTION_API_NAME_ATTRIBUTE: 'lstm_' + str(uuid.uuid4())
-    }
-
-    self.defun_lstm_with_backend_selection = function.defun_with_attributes(
-        lstm_with_backend_selection,
-        attributes=supportive_attributes,
-        autograph=False)
-
   def call(self, inputs, mask=None, training=None, initial_state=None):
     # The input should be dense, padded with zeros. If a ragged input is fed
     # into the layer, it is padded and the row lengths are used for masking.
@@ -1140,7 +1145,7 @@
       dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
       if dropout_mask is not None:
         inputs = inputs * dropout_mask[0]
-      lstm_kwargs = {
+      gpu_lstm_kwargs = {
           'inputs': inputs,
           'init_h': _read_variable_value(initial_state[0]),
           'init_c': _read_variable_value(initial_state[1]),
@@ -1150,11 +1155,32 @@
           'mask': mask,
           'time_major': self.time_major,
           'go_backwards': self.go_backwards,
-          'sequence_lengths': row_lengths,
-          'zero_output_for_mask': self.zero_output_for_mask,
+          'sequence_lengths': row_lengths
       }
-      (last_output, outputs, new_h, new_c,
-       runtime) = self.defun_lstm_with_backend_selection(**lstm_kwargs)
+      normal_lstm_kwargs = gpu_lstm_kwargs.copy()
+      normal_lstm_kwargs.update({
+          'zero_output_for_mask': self.zero_output_for_mask,
+      })
+
+      if context.executing_eagerly():
+        device_type = _get_context_device_type()
+        can_use_gpu = (
+            # Either user specified GPU or unspecified but GPU is available.
+            (device_type == _GPU_DEVICE_NAME
+             or (device_type is None and context.num_gpus() > 0))
+            and
+            (mask is None or is_sequence_right_padded(mask, self.time_major)))
+        # Under eager context, check the device placement and prefer the
+        # GPU implementation when GPU is available.
+        if can_use_gpu:
+          last_output, outputs, new_h, new_c, runtime = gpu_lstm(
+              **gpu_lstm_kwargs)
+        else:
+          last_output, outputs, new_h, new_c, runtime = standard_lstm(
+              **normal_lstm_kwargs)
+      else:
+        (last_output, outputs, new_h, new_c,
+         runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
 
       states = [new_h, new_c]
 
@@ -1512,13 +1538,25 @@
         true_fn=input_right_padded,
         false_fn=input_not_right_padded)
 
-  # Chooses the implementation dynamicly based on the running device.
-  (last_output, outputs, new_h, new_c,
-   runtime) = control_flow_ops.execute_fn_for_device(
-       {
-           _CPU_DEVICE_NAME: lambda: standard_lstm(**params),
-           _GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(**params)
-       }, lambda: standard_lstm(**params))
+  # Each time a `tf.function` is called, we will give it a unique
+  # identifiable API name, so that Grappler won't get confused when it
+  # sees multiple LSTM layers added into the same graph, and it will be
+  # able to pair up the different implementations across them.
+  api_name = 'lstm_' + str(uuid.uuid4())
+  supportive_attribute = {
+      'time_major': time_major,
+      'go_backwards': go_backwards,
+  }
+  defun_standard_lstm = _generate_defun_backend(
+      api_name, _CPU_DEVICE_NAME, standard_lstm, supportive_attribute)
+  defun_gpu_lstm = _generate_defun_backend(
+      api_name, _GPU_DEVICE_NAME, gpu_lstm_with_fallback, supportive_attribute)
+
+  # Call the normal LSTM impl and register the CuDNN impl function. Grappler
+  # will kick in during session execution to optimize the graph.
+  last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
+      **params)
+  function.register(defun_gpu_lstm, **params)
 
   return last_output, outputs, new_h, new_c, runtime
 
@@ -1581,6 +1619,18 @@
                              axis=timestep_index)
 
 
+def _generate_defun_backend(unique_api_name, preferred_device, func,
+                            supportive_attributes):
+  function_attributes = {
+      _FUNCTION_API_NAME_ATTRIBUTE: unique_api_name,
+      _FUNCTION_DEVICE_ATTRIBUTE: preferred_device,
+  }
+  function_attributes.update(supportive_attributes)
+  return function.defun_with_attributes(func=func,
+                                        attributes=function_attributes,
+                                        autograph=False)
+
+
 def _get_context_device_type():
   """Parse the current context and return the device type, eg CPU/GPU."""
   current_device = context.context().device_name
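
End-user behavior is unchanged by this revert; a hedged usage sketch of the
code paths above (assuming a stock TF 2.x install):

    import numpy as np
    import tensorflow as tf

    # Builds one 'api_implements' function pair for the layer. In eager
    # mode the layer dispatches on device placement and mask padding; in
    # graph mode Grappler picks the CuDNN variant when placed on a GPU.
    layer = tf.keras.layers.LSTM(4, return_sequences=True)
    x = np.random.rand(2, 10, 8).astype(np.float32)
    y = layer(x)
    print(y.shape)  # (2, 10, 4)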