# blob: 31e4548bb28b258c077e7c41e9bbaa00ed06985f [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RNN cell wrapper v2 implementation."""
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import combinations
from tensorflow.python.keras import layers
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl
from tensorflow.python.keras.legacy_tf_layers import base as legacy_base_layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@combinations.generate(combinations.combine(mode=["graph", "eager"]))
class RNNCellWrapperTest(test.TestCase, parameterized.TestCase):
  """Exercises the v2 RNN cell wrappers in both graph and eager modes."""

  def testResidualWrapper(self):
    """ResidualWrapper output equals cell output plus input; state is equal."""
    inputs = ops.convert_to_tensor_v2_with_dispatch(
        np.array([[1., 1., 1.]]), dtype="float32")
    state = ops.convert_to_tensor_v2_with_dispatch(
        np.array([[0.1, 0.1, 0.1]]), dtype="float32")
    gru_cell = rnn_cell_impl.GRUCell(
        3,
        kernel_initializer=init_ops.constant_initializer(0.5),
        bias_initializer=init_ops.constant_initializer(0.5))
    plain_out, plain_state = gru_cell(inputs, state)

    residual_cell = rnn_cell_wrapper_v2.ResidualWrapper(gru_cell)
    tracked = residual_cell._trackable_children()
    # Only called to make sure config generation does not raise.
    residual_cell.get_config()
    self.assertIn("cell", tracked)
    self.assertIs(tracked["cell"], gru_cell)

    residual_out, residual_state = residual_cell(inputs, state)
    self.evaluate([variables_lib.global_variables_initializer()])
    (plain_out_val, residual_out_val,
     plain_state_val, residual_state_val) = self.evaluate(
         [plain_out, residual_out, plain_state, residual_state])
    # The wrapped output is the bare output shifted by the all-ones input.
    self.assertAllClose(residual_out_val, plain_out_val + [1., 1., 1.])
    # Wrapping must not change the state the cell produces.
    self.assertAllClose(plain_state_val, residual_state_val)

  def testResidualWrapperWithSlice(self):
    """A custom residual_fn may slice a wider input before adding it."""
    inputs = ops.convert_to_tensor_v2_with_dispatch(
        np.array([[1., 1., 1., 1., 1.]]), dtype="float32")
    state = ops.convert_to_tensor_v2_with_dispatch(
        np.array([[0.1, 0.1, 0.1]]), dtype="float32")
    gru_cell = rnn_cell_impl.GRUCell(
        3,
        kernel_initializer=init_ops.constant_initializer(0.5),
        bias_initializer=init_ops.constant_initializer(0.5))
    plain_out, plain_state = gru_cell(inputs, state)

    def residual_with_slice_fn(inp, out):
      # Keep only the first three input columns so the shapes line up.
      inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
      return inp_sliced + out

    sliced_residual_cell = rnn_cell_wrapper_v2.ResidualWrapper(
        gru_cell, residual_with_slice_fn)
    residual_out, residual_state = sliced_residual_cell(inputs, state)
    self.evaluate([variables_lib.global_variables_initializer()])
    values = self.evaluate(
        [plain_out, residual_out, plain_state, residual_state])
    # The wrapped output is the bare output plus the sliced all-ones input.
    self.assertAllClose(values[1], values[0] + [1., 1., 1.])
    # Wrapping must not change the state the cell produces.
    self.assertAllClose(values[2], values[3])

  def testDeviceWrapper(self):
    """DeviceWrapper tracks its cell and places the output on the device."""
    inputs = array_ops.zeros([1, 3])
    state = array_ops.zeros([1, 3])
    gru_cell = rnn_cell_impl.GRUCell(3)
    device_cell = rnn_cell_wrapper_v2.DeviceWrapper(gru_cell, "/cpu:0")
    tracked = device_cell._trackable_children()
    # Only called to make sure config generation does not raise.
    device_cell.get_config()
    self.assertIn("cell", tracked)
    self.assertIs(tracked["cell"], gru_cell)
    outputs, _ = device_cell(inputs, state)
    self.assertIn("cpu:0", outputs.device.lower())

  @parameterized.parameters(
      [[rnn_cell_impl.DropoutWrapper, rnn_cell_wrapper_v2.DropoutWrapper],
       [rnn_cell_impl.ResidualWrapper, rnn_cell_wrapper_v2.ResidualWrapper]])
  def testWrapperKerasStyle(self, wrapper, wrapper_v2):
    """Checks the `_keras_style` attribute of the v2 and legacy wrappers."""
    v2_wrapped = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
    # Missing and None are both accepted for the v2 wrapper.
    self.assertIsNone(getattr(v2_wrapped, "_keras_style", None))
    legacy_wrapped = wrapper(rnn_cell_impl.BasicRNNCell(1))
    self.assertFalse(legacy_wrapped._keras_style)

  @parameterized.parameters(
      [rnn_cell_wrapper_v2.DropoutWrapper, rnn_cell_wrapper_v2.ResidualWrapper])
  def testWrapperWeights(self, wrapper):
    """The wrapper reports exactly the weights of the cell it wraps."""
    inner_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell")
    wrapped_cell = wrapper(inner_cell)
    rnn_layer = layers.RNN(wrapped_cell)
    batch = ops.convert_to_tensor_v2_with_dispatch(
        [[[1]]], dtype=dtypes.float32)
    # Calling the layer once creates the variables.
    rnn_layer(batch)

    name_prefix = (
        "rnn/" + generic_utils.to_snake_case(wrapper.__name__) + "/")
    expected_weights = [
        name_prefix + suffix
        for suffix in ("kernel:0", "recurrent_kernel:0", "bias:0")
    ]
    self.assertLen(wrapped_cell.weights, 3)
    self.assertCountEqual(
        [v.name for v in wrapped_cell.weights], expected_weights)
    self.assertCountEqual(
        [v.name for v in wrapped_cell.trainable_variables], expected_weights)
    self.assertCountEqual(
        [v.name for v in wrapped_cell.non_trainable_variables], [])
    self.assertCountEqual(
        [v.name for v in wrapped_cell.cell.weights], expected_weights)

  @parameterized.parameters(
      [rnn_cell_wrapper_v2.DropoutWrapper, rnn_cell_wrapper_v2.ResidualWrapper])
  def testWrapperV2Caller(self, wrapper):
    """Inner cell weights are created under a name containing "_wrapper"."""
    with legacy_base_layer.keras_style_scope():
      stacked_cell = rnn_cell_impl.MultiRNNCell(
          [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
    wrapped_cell = wrapper(stacked_cell)
    step_input = ops.convert_to_tensor_v2_with_dispatch(
        [[1]], dtype=dtypes.float32)
    step_state = ops.convert_to_tensor_v2_with_dispatch(
        [[1]], dtype=dtypes.float32)
    # One state per stacked cell; the call result itself is not inspected.
    _ = wrapped_cell(step_input, [step_state, step_state])
    first_cell_weights = stacked_cell._cells[0].weights
    self.assertLen(first_cell_weights, expected_len=2)
    self.assertTrue(all("_wrapper" in v.name for v in first_cell_weights))

  @parameterized.parameters(
      [rnn_cell_wrapper_v2.DropoutWrapper, rnn_cell_wrapper_v2.ResidualWrapper])
  def testWrapperV2Build(self, wrapper):
    """Building the wrapper also marks the wrapped cell as built."""
    lstm_cell = rnn_cell_impl.LSTMCell(10)
    wrapped_cell = wrapper(lstm_cell)
    wrapped_cell.build((1,))
    self.assertTrue(lstm_cell.built)

  def testDeviceWrapperSerialization(self):
    """DeviceWrapper survives a get_config / from_config round trip."""
    wrapper_cls = rnn_cell_wrapper_v2.DeviceWrapper
    original = wrapper_cls(layers.LSTMCell(10), "/cpu:0")
    config = original.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertDictEqual(config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

  def testResidualWrapperSerialization(self):
    """ResidualWrapper round-trips, including lambda and def residual_fn."""
    wrapper_cls = rnn_cell_wrapper_v2.ResidualWrapper
    cell = layers.LSTMCell(10)

    # Default residual function.
    default_wrapper = wrapper_cls(cell)
    config = default_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertDictEqual(config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

    # Lambda residual function: 1 + 1 + 2 == 4 after reconstruction.
    lambda_wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o)
    config = lambda_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertEqual(restored._residual_fn(1, 2), 4)

    def residual_fn(inputs, outputs):
      return inputs * 3 + outputs

    # Named residual function: 1 * 3 + 2 == 5 after reconstruction.
    named_wrapper = wrapper_cls(cell, residual_fn=residual_fn)
    config = named_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertEqual(restored._residual_fn(1, 2), 5)

  def testDropoutWrapperSerialization(self):
    """DropoutWrapper round-trips, including custom state filter visitors."""
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper
    cell = layers.GRUCell(10)

    # Default state filter.
    default_wrapper = wrapper_cls(cell)
    config = default_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertDictEqual(config, restored.get_config())
    self.assertIsInstance(restored, wrapper_cls)

    # Lambda visitor that always answers True.
    lambda_wrapper = wrapper_cls(
        cell, dropout_state_filter_visitor=lambda s: True)
    config = lambda_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertTrue(restored._dropout_state_filter(None))

    def dropout_state_filter_visitor(unused_state):
      return False

    # Named visitor that always answers False.
    named_wrapper = wrapper_cls(
        cell, dropout_state_filter_visitor=dropout_state_filter_visitor)
    config = named_wrapper.get_config()
    restored = wrapper_cls.from_config(config)
    self.assertFalse(restored._dropout_state_filter(None))

  def testDropoutWrapperWithKerasLSTMCell(self):
    """DropoutWrapper raises ValueError for Keras LSTMCell and LSTMCellV2."""
    wrapper_cls = rnn_cell_wrapper_v2.DropoutWrapper

    keras_cell = layers.LSTMCell(10)
    with self.assertRaisesRegex(ValueError, "does not work with "):
      wrapper_cls(keras_cell)

    keras_cell_v2 = layers.LSTMCellV2(10)
    with self.assertRaisesRegex(ValueError, "does not work with "):
      wrapper_cls(keras_cell_v2)
if __name__ == "__main__":
  # Run the tests through the TensorFlow test runner when executed directly.
  test.main()