Undo LSTM refactoring.
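
Restores the in-call backend selection for both GRU and LSTM: instead of
wrapping gru_with_backend_selection / lstm_with_backend_selection in a
single defun_with_attributes at layer construction time, call() once again
dispatches eagerly between the standard and CuDNN implementations, and in
graph mode registers a CPU and a GPU defun under a shared API name so that
Grappler can choose one based on device placement.

A minimal standalone sketch of the restored eager-mode predicate (the names
below are illustrative placeholders, not the real Keras symbols):

  def can_use_cudnn(device_type, num_gpus, mask_right_padded_or_absent):
    # GPU is usable when the layer is explicitly placed on a GPU, or when
    # no device was specified and a GPU is available; the CuDNN kernels
    # additionally require the mask, if any, to be right-padded.
    on_gpu = device_type == 'GPU' or (device_type is None and num_gpus > 0)
    return on_gpu and mask_right_padded_or_absent

  assert can_use_cudnn('GPU', 1, True)
  assert not can_use_cudnn(None, 0, True)
  assert not can_use_cudnn('GPU', 1, False)
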
PiperOrigin-RevId: 321060640
Change-Id: Ibd2e5aa7481869ead6ed8d00de1f51c487fa760b
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index 58eb0bb..33babb5 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -385,17 +385,6 @@
else:
logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
- # The first two attributes are added to support TFLite use case.
- supportive_attributes = {
- 'time_major': time_major,
- 'go_backwards': go_backwards,
- _FUNCTION_API_NAME_ATTRIBUTE: 'gru_' + str(uuid.uuid4())
- }
- self.defun_gru_with_backend_selection = function.defun_with_attributes(
- gru_with_backend_selection,
- attributes=supportive_attributes,
- autograph=False)
-
def build(self, input_shape):
super(GRU, self).build(input_shape)
@@ -478,7 +467,7 @@
if dropout_mask is not None:
inputs = inputs * dropout_mask[0]
- gru_kwargs = {
+ gpu_gru_kwargs = {
'inputs': inputs,
'init_h': _read_variable_value(initial_state[0]),
'kernel': _read_variable_value(self.cell.kernel),
@@ -487,11 +476,29 @@
'mask': mask,
'time_major': self.time_major,
'go_backwards': self.go_backwards,
- 'sequence_lengths': sequence_lengths,
- 'zero_output_for_mask': self.zero_output_for_mask
+ 'sequence_lengths': sequence_lengths
}
- (last_output, outputs, new_h,
- runtime) = self.defun_gru_with_backend_selection(**gru_kwargs)
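+ # gpu_gru does not accept `zero_output_for_mask`, so it is only added to
+ # the kwargs of the standard implementation.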
+ normal_gru_kwargs = gpu_gru_kwargs.copy()
+ normal_gru_kwargs.update({
+ 'zero_output_for_mask': self.zero_output_for_mask,
+ })
+
+ if context.executing_eagerly():
+ device_type = _get_context_device_type()
+ can_use_gpu = (
+ # Either the user specified a GPU, or none was specified and a GPU is available.
+ (device_type == _GPU_DEVICE_NAME
+ or (device_type is None and context.num_gpus() > 0))
+ and
+ (mask is None or is_sequence_right_padded(mask, self.time_major)))
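+ # CuDNN kernels require right-padded sequences, so any other masking
+ # falls back to the standard implementation.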
+ # Under eager context, check the device placement and prefer the
+ # GPU implementation when GPU is available.
+ if can_use_gpu:
+ last_output, outputs, new_h, runtime = gpu_gru(**gpu_gru_kwargs)
+ else:
+ last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs)
+ else:
+ last_output, outputs, new_h, runtime = gru_with_backend_selection(
+ **normal_gru_kwargs)
states = [new_h]
return last_output, outputs, runtime, states
@@ -758,14 +765,24 @@
true_fn=input_right_padded,
false_fn=input_not_right_padded)
- # Chooses the implementation dynamicly based on the running device.
- (last_output, outputs, new_h,
- runtime) = control_flow_ops.execute_fn_for_device(
- {
- _CPU_DEVICE_NAME: lambda: standard_gru(**params),
- _GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(**params)
- }, lambda: standard_gru(**params))
+ # Each time a `tf.function` is called, we will give it a unique
+ # identifiable API name, so that Grappler won't get confused when it
+ # sees multiple GRU layers added into the same graph, and it will be
+ # able to pair up the different implementations across them.
+ api_name = 'gru_' + str(uuid.uuid4())
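+ # time_major and go_backwards are attached as attributes to support the
+ # TFLite use case.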
+ supportive_attribute = {
+ 'time_major': time_major,
+ 'go_backwards': go_backwards,
+ }
+ defun_standard_gru = _generate_defun_backend(
+ api_name, _CPU_DEVICE_NAME, standard_gru, supportive_attribute)
+ defun_gpu_gru = _generate_defun_backend(
+ api_name, _GPU_DEVICE_NAME, gpu_gru_with_fallback, supportive_attribute)
+ # Call the normal GRU impl and register the CuDNN impl function. Grappler
+ # will kick in during session execution to optimize the graph.
+ last_output, outputs, new_h, runtime = defun_standard_gru(**params)
+ function.register(defun_gpu_gru, **params)
return last_output, outputs, new_h, runtime
@@ -1080,18 +1097,6 @@
else:
logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name)
- # The first two attributes are added to support TFLite use case.
- supportive_attributes = {
- 'time_major': time_major,
- 'go_backwards': go_backwards,
- _FUNCTION_API_NAME_ATTRIBUTE: 'lstm_' + str(uuid.uuid4())
- }
-
- self.defun_lstm_with_backend_selection = function.defun_with_attributes(
- lstm_with_backend_selection,
- attributes=supportive_attributes,
- autograph=False)
-
def call(self, inputs, mask=None, training=None, initial_state=None):
# The input should be dense, padded with zeros. If a ragged input is fed
# into the layer, it is padded and the row lengths are used for masking.
@@ -1140,7 +1145,7 @@
dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
if dropout_mask is not None:
inputs = inputs * dropout_mask[0]
- lstm_kwargs = {
+ gpu_lstm_kwargs = {
'inputs': inputs,
'init_h': _read_variable_value(initial_state[0]),
'init_c': _read_variable_value(initial_state[1]),
@@ -1150,11 +1155,32 @@
'mask': mask,
'time_major': self.time_major,
'go_backwards': self.go_backwards,
- 'sequence_lengths': row_lengths,
- 'zero_output_for_mask': self.zero_output_for_mask,
+ 'sequence_lengths': row_lengths
}
- (last_output, outputs, new_h, new_c,
- runtime) = self.defun_lstm_with_backend_selection(**lstm_kwargs)
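+ # gpu_lstm does not accept `zero_output_for_mask`, so it is only added to
+ # the kwargs of the standard implementation.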
+ normal_lstm_kwargs = gpu_lstm_kwargs.copy()
+ normal_lstm_kwargs.update({
+ 'zero_output_for_mask': self.zero_output_for_mask,
+ })
+
+ if context.executing_eagerly():
+ device_type = _get_context_device_type()
+ can_use_gpu = (
+ # Either the user specified a GPU, or none was specified and a GPU is available.
+ (device_type == _GPU_DEVICE_NAME
+ or (device_type is None and context.num_gpus() > 0))
+ and
+ (mask is None or is_sequence_right_padded(mask, self.time_major)))
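+ # CuDNN kernels require right-padded sequences, so any other masking
+ # falls back to the standard implementation.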
+ # Under eager context, check the device placement and prefer the
+ # GPU implementation when GPU is available.
+ if can_use_gpu:
+ last_output, outputs, new_h, new_c, runtime = gpu_lstm(
+ **gpu_lstm_kwargs)
+ else:
+ last_output, outputs, new_h, new_c, runtime = standard_lstm(
+ **normal_lstm_kwargs)
+ else:
+ (last_output, outputs, new_h, new_c,
+ runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
states = [new_h, new_c]
@@ -1512,13 +1538,25 @@
true_fn=input_right_padded,
false_fn=input_not_right_padded)
- # Chooses the implementation dynamicly based on the running device.
- (last_output, outputs, new_h, new_c,
- runtime) = control_flow_ops.execute_fn_for_device(
- {
- _CPU_DEVICE_NAME: lambda: standard_lstm(**params),
- _GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(**params)
- }, lambda: standard_lstm(**params))
+ # Each time a `tf.function` is called, we will give it a unique
+ # identifiable API name, so that Grappler won't get confused when it
+ # sees multiple LSTM layers added into the same graph, and it will be
+ # able to pair up the different implementations across them.
+ api_name = 'lstm_' + str(uuid.uuid4())
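+ # time_major and go_backwards are attached as attributes to support the
+ # TFLite use case.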
+ supportive_attribute = {
+ 'time_major': time_major,
+ 'go_backwards': go_backwards,
+ }
+ defun_standard_lstm = _generate_defun_backend(
+ api_name, _CPU_DEVICE_NAME, standard_lstm, supportive_attribute)
+ defun_gpu_lstm = _generate_defun_backend(
+ api_name, _GPU_DEVICE_NAME, gpu_lstm_with_fallback, supportive_attribute)
+
+ # Call the normal LSTM impl and register the CuDNN impl function.
+ # Grappler will kick in during session execution to optimize the graph.
+ last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
+ **params)
+ function.register(defun_gpu_lstm, **params)
return last_output, outputs, new_h, new_c, runtime
@@ -1581,6 +1619,18 @@
axis=timestep_index)
+def _generate_defun_backend(unique_api_name, preferred_device, func,
+ supportive_attributes):
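+ """Wrap `func` in a defun tagged with an API name and preferred device."""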
+ function_attributes = {
+ _FUNCTION_API_NAME_ATTRIBUTE: unique_api_name,
+ _FUNCTION_DEVICE_ATTRIBUTE: preferred_device,
+ }
+ function_attributes.update(supportive_attributes)
+ return function.defun_with_attributes(func=func,
+ attributes=function_attributes,
+ autograph=False)
+
+
def _get_context_device_type():
"""Parse the current context and return the device type, eg CPU/GPU."""
current_device = context.context().device_name