| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for RNN cells.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import itertools |
| import os |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors_impl |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import random_seed |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import rnn |
| from tensorflow.python.ops import rnn_cell |
| from tensorflow.python.ops import rnn_cell_impl |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables as variables_lib |
| from tensorflow.python.platform import test |
| from tensorflow.python.platform import tf_logging |
| from tensorflow.python.training.tracking import util as trackable_utils |
| from tensorflow.python.util import nest |
| |
| |
| class Plus1RNNCell(rnn_cell.RNNCell): |
| """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" |
| |
| @property |
| def output_size(self): |
| return 5 |
| |
| @property |
| def state_size(self): |
| return 5 |
| |
| def __call__(self, input_, state, scope=None): |
| return (input_ + 1, state + 1) |
| |
| |
| class DummyMultiDimensionalLSTM(rnn_cell.RNNCell): |
| """LSTM Cell generating (output, new_state) = (input + 1, state + 1). |
| |
| The input to this cell may have an arbitrary number of dimensions following |
| the leading 'Time' and 'Batch' dimensions. |
| """ |
| |
| def __init__(self, dims): |
| """Initialize the Multi-dimensional LSTM cell. |
| |
| Args: |
| dims: tuple that contains the dimensions of the output of the cell, |
| without including 'Time' or 'Batch' dimensions. |
| """ |
| if not isinstance(dims, tuple): |
| raise TypeError("The dimensions passed to DummyMultiDimensionalLSTM " |
| "should be a tuple of ints.") |
| self._dims = dims |
| self._output_size = tensor_shape.TensorShape(self._dims) |
| self._state_size = (tensor_shape.TensorShape(self._dims), |
| tensor_shape.TensorShape(self._dims)) |
| |
| @property |
| def output_size(self): |
| return self._output_size |
| |
| @property |
| def state_size(self): |
| return self._state_size |
| |
| def __call__(self, input_, state, scope=None): |
| h, c = state |
| return (input_ + 1, (h + 1, c + 1)) |
| |
| |
| class NestedRNNCell(rnn_cell.RNNCell): |
| """RNN Cell generating (output, new_state) = (input + 1, state + 1). |
| |
| The input, output, and state of this cell are each a tuple of two tensors. |
| """ |
| |
| @property |
| def output_size(self): |
| return (5, 5) |
| |
| @property |
| def state_size(self): |
| return (6, 6) |
| |
| def __call__(self, input_, state, scope=None): |
| h, c = state |
| x, y = input_ |
| return ((x + 1, y + 1), (h + 1, c + 1)) |
| |
| |
| class TestStateSaver(object): |
| |
| def __init__(self, batch_size, state_size): |
| self._batch_size = batch_size |
| self._state_size = state_size |
| self.saved_state = {} |
| |
| def state(self, name): |
| if isinstance(self._state_size, dict): |
| state_size = self._state_size[name] |
| else: |
| state_size = self._state_size |
| if isinstance(state_size, int): |
| state_size = (state_size,) |
| elif isinstance(state_size, tuple): |
| pass |
| else: |
| raise TypeError("state_size should either be an int or a tuple") |
| |
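| # The saver hands back an all-zeros tensor of shape (batch_size,) + state_size |
| # as the previously saved state. |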
| return array_ops.zeros((self._batch_size,) + state_size) |
| |
| def save_state(self, name, state): |
| self.saved_state[name] = state |
| return array_ops.identity(state) |
| |
| @property |
| def batch_size(self): |
| return self._batch_size |
| |
| @property |
| def state_size(self): |
| return self._state_size |
| |
| |
| class TestStateSaverWithCounters(TestStateSaver): |
| """Class wrapper around TestStateSaver. |
| |
| A dummy class used for testing static_state_saving_rnn. It helps verify that |
| the save_state and state functions are called the same number of times |
| whether we evaluate the RNN cell output and state together or each of them |
| separately. It inherits from TestStateSaver and adds counters for calls to |
| these functions. |
| """ |
| |
| @test_util.run_v1_only("b/124229375") |
| def __init__(self, batch_size, state_size): |
| super(TestStateSaverWithCounters, self).__init__(batch_size, state_size) |
| self._num_state_calls = variables_lib.VariableV1(0) |
| self._num_save_state_calls = variables_lib.VariableV1(0) |
| |
| def state(self, name): |
| with ops.control_dependencies( |
| [state_ops.assign_add(self._num_state_calls, 1)]): |
| return super(TestStateSaverWithCounters, self).state(name) |
| |
| def save_state(self, name, state): |
| with ops.control_dependencies([state_ops.assign_add( |
| self._num_save_state_calls, 1)]): |
| return super(TestStateSaverWithCounters, self).save_state(name, state) |
| |
| @property |
| def num_state_calls(self): |
| return self._num_state_calls |
| |
| @property |
| def num_save_state_calls(self): |
| return self._num_save_state_calls |
| |
| |
| class RNNTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testInvalidSequenceLengthShape(self): |
| cell = Plus1RNNCell() |
| inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))] |
| with self.assertRaisesRegexp(ValueError, "must be a vector"): |
| rnn.static_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=4) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testRNN(self): |
| cell = Plus1RNNCell() |
| batch_size = 2 |
| input_size = 5 |
| max_length = 8 # unrolled up to this length |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
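| # Note: list multiplication repeats the same placeholder object, so feeding |
| # inputs[0] below supplies the value for every time step. |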
| outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out, inp in zip(outputs, inputs): |
| self.assertEqual(out.get_shape(), inp.get_shape()) |
| self.assertEqual(out.dtype, inp.dtype) |
| |
| with self.session(use_gpu=True) as sess: |
| input_value = np.random.randn(batch_size, input_size) |
| values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) |
| |
| # Outputs |
| for v in values[:-1]: |
| self.assertAllClose(v, input_value + 1.0) |
| |
| # Final state |
| self.assertAllClose(values[-1], |
| max_length * np.ones( |
| (batch_size, input_size), dtype=np.float32)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDropout(self): |
| cell = Plus1RNNCell() |
| full_dropout_cell = rnn_cell.DropoutWrapper( |
| cell, input_keep_prob=1e-6, seed=0) |
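| # With input_keep_prob=1e-6 the inputs are effectively all dropped to zero, |
| # so Plus1RNNCell should emit all ones along the dropped path. |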
| (name, dep), = full_dropout_cell._checkpoint_dependencies |
| self.assertIs(dep, cell) |
| self.assertEqual("cell", name) |
| batch_size = 2 |
| input_size = 5 |
| max_length = 8 |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| with variable_scope.variable_scope("share_scope"): |
| outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| with variable_scope.variable_scope("drop_scope"): |
| dropped_outputs, _ = rnn.static_rnn( |
| full_dropout_cell, inputs, dtype=dtypes.float32) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out, inp in zip(outputs, inputs): |
| self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) |
| self.assertEqual(out.dtype, inp.dtype) |
| |
| with self.session(use_gpu=True) as sess: |
| input_value = np.random.randn(batch_size, input_size) |
| values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) |
| full_dropout_values = sess.run( |
| dropped_outputs, feed_dict={ |
| inputs[0]: input_value |
| }) |
| |
| for v in values[:-1]: |
| self.assertAllClose(v, input_value + 1.0) |
| for d_v in full_dropout_values[:-1]: # 1.0 added to dropped-out inputs (all zeros) |
| self.assertAllClose(d_v, np.ones_like(input_value)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDynamicCalculation(self): |
| cell = Plus1RNNCell() |
| sequence_length = array_ops.placeholder(dtypes.int64) |
| batch_size = 2 |
| input_size = 5 |
| max_length = 8 |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| with variable_scope.variable_scope("drop_scope"): |
| dynamic_outputs, dynamic_state = rnn.static_rnn( |
| cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) |
| self.assertEqual(len(dynamic_outputs), len(inputs)) |
| |
| with self.session(use_gpu=True) as sess: |
| input_value = np.random.randn(batch_size, input_size) |
| dynamic_values = sess.run( |
| dynamic_outputs, |
| feed_dict={ |
| inputs[0]: input_value, |
| sequence_length: [2, 3] |
| }) |
| dynamic_state_value = sess.run( |
| [dynamic_state], |
| feed_dict={ |
| inputs[0]: input_value, |
| sequence_length: [2, 3] |
| }) |
| |
| # outputs are fully calculated for t = 0, 1 |
| for v in dynamic_values[:2]: |
| self.assertAllClose(v, input_value + 1.0) |
| |
| # outputs at t = 2 are zero for entry 0, calculated for entry 1 |
| self.assertAllClose(dynamic_values[2], |
| np.vstack((np.zeros((input_size)), |
| 1.0 + input_value[1, :]))) |
| |
| # outputs at t = 3+ are zero |
| for v in dynamic_values[3:]: |
| self.assertAllEqual(v, np.zeros_like(input_value)) |
| |
| # the final states are: |
| # entry 0: the values from the calculation at t=1 |
| # entry 1: the values from the calculation at t=2 |
| self.assertAllEqual(dynamic_state_value[0], |
| np.vstack((1.0 * (1 + 1) * np.ones((input_size)), |
| 1.0 * (2 + 1) * np.ones((input_size))))) |
| |
| def _testScope(self, factory, prefix="prefix", use_outer_scope=True): |
| with self.session(use_gpu=True, graph=ops.Graph()): |
| if use_outer_scope: |
| with variable_scope.variable_scope(prefix) as scope: |
| factory(scope) |
| else: |
| factory(prefix) |
| |
| # Check that all the variable names start |
| # with the proper scope. |
| variables_lib.global_variables_initializer() |
| all_vars = variables_lib.global_variables() |
| prefix = prefix or "rnn" |
| scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] |
| tf_logging.info("RNN with scope: %s (%s)" % |
| (prefix, "scope" if use_outer_scope else "str")) |
| for v in scope_vars: |
| tf_logging.info(v.name) |
| self.assertEqual(len(scope_vars), len(all_vars)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testScope(self): |
| |
| def factory(scope): |
| cell = Plus1RNNCell() |
| batch_size = 2 |
| input_size = 5 |
| max_length = 8 # unrolled up to this length |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| return rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope=scope) |
| |
| self._testScope(factory, use_outer_scope=True) |
| self._testScope(factory, use_outer_scope=False) |
| self._testScope(factory, prefix=None, use_outer_scope=False) |
| |
| |
| class LSTMTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| def testDType(self): |
| # Test case for GitHub issue 16228 |
| # Not passing dtype in constructor results in default float32 |
| lstm = rnn_cell.LSTMCell(10) |
| input_tensor = array_ops.ones([10, 50]) |
| lstm.build(input_tensor.get_shape()) |
| self.assertEqual(lstm._bias.dtype.base_dtype, dtypes.float32) |
| |
| # Explicitly pass dtype in constructor |
| for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: |
| lstm = rnn_cell.LSTMCell(10, dtype=dtype) |
| input_tensor = array_ops.ones([10, 50]) |
| lstm.build(input_tensor.get_shape()) |
| self.assertEqual(lstm._bias.dtype.base_dtype, dtype) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testNoProjNoSharding(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| cell = rnn_cell.LSTMCell( |
| num_units, initializer=initializer, state_is_tuple=False) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| sess.run(outputs, feed_dict={inputs[0]: input_value}) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testCellClipping(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| cell_clip=0.0, |
| initializer=initializer, |
| state_is_tuple=False) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| values = sess.run(outputs, feed_dict={inputs[0]: input_value}) |
| |
| for value in values: |
| # If the cell state c is clipped to 0, then tanh(c) = 0, so the output m == 0. |
| self.assertAllEqual(value, np.zeros((batch_size, num_units))) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testNoProjNoShardingSimpleStateSaver(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| state_saver = TestStateSaver(batch_size, 2 * num_units) |
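| # With state_is_tuple=False the LSTM state concatenates c and m, so the |
| # saver must hold a state of width 2 * num_units. |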
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=False, |
| initializer=initializer, |
| state_is_tuple=False) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| with variable_scope.variable_scope("share_scope"): |
| outputs, state = rnn.static_state_saving_rnn( |
| cell, inputs, state_saver=state_saver, state_name="save_lstm") |
| self.assertEqual(len(outputs), len(inputs)) |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| (last_state_value, saved_state_value) = sess.run( |
| [state, state_saver.saved_state["save_lstm"]], |
| feed_dict={ |
| inputs[0]: input_value |
| }) |
| self.assertAllEqual(last_state_value, saved_state_value) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testNoProjNoShardingTupleStateSaver(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| state_saver = TestStateSaver(batch_size, num_units) |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=False, |
| initializer=initializer, |
| state_is_tuple=True) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| with variable_scope.variable_scope("share_scope"): |
| outputs, state = rnn.static_state_saving_rnn( |
| cell, inputs, state_saver=state_saver, state_name=("c", "m")) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), [batch_size, num_units]) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
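| # state is an LSTMStateTuple (a namedtuple), so adding the saved tensors |
| # below yields a flat 4-tuple that can be evaluated in a single run. |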
| last_and_saved_states = sess.run( |
| state + (state_saver.saved_state["c"], state_saver.saved_state["m"]), |
| feed_dict={ |
| inputs[0]: input_value |
| }) |
| self.assertEqual(4, len(last_and_saved_states)) |
| self.assertAllEqual(last_and_saved_states[:2], last_and_saved_states[2:]) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testNoProjNoShardingNestedTupleStateSaver(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| state_saver = TestStateSaver( |
| batch_size, { |
| "c0": num_units, |
| "m0": num_units, |
| "c1": num_units + 1, |
| "m1": num_units + 1, |
| "c2": num_units + 2, |
| "m2": num_units + 2, |
| "c3": num_units + 3, |
| "m3": num_units + 3 |
| }) |
| |
| def _cell(i): |
| return rnn_cell.LSTMCell( |
| num_units + i, |
| use_peepholes=False, |
| initializer=initializer, |
| state_is_tuple=True) |
| |
| # This creates a state tuple which has 4 sub-tuples of length 2 each. |
| cell = rnn_cell.MultiRNNCell( |
| [_cell(i) for i in range(4)], state_is_tuple=True) |
| |
| self.assertEqual(len(cell.state_size), 4) |
| for i in range(4): |
| self.assertEqual(len(cell.state_size[i]), 2) |
| |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| |
| state_names = (("c0", "m0"), ("c1", "m1"), ("c2", "m2"), ("c3", "m3")) |
| with variable_scope.variable_scope("share_scope"): |
| outputs, state = rnn.static_state_saving_rnn( |
| cell, inputs, state_saver=state_saver, state_name=state_names) |
| self.assertEqual(len(outputs), len(inputs)) |
| |
| # Final output comes from _cell(3) which has state size num_units + 3 |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), [batch_size, num_units + 3]) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| last_states = sess.run( |
| list(nest.flatten(state)), feed_dict={ |
| inputs[0]: input_value |
| }) |
| saved_states = sess.run( |
| list(state_saver.saved_state.values()), |
| feed_dict={ |
| inputs[0]: input_value |
| }) |
| self.assertEqual(8, len(last_states)) |
| self.assertEqual(8, len(saved_states)) |
| flat_state_names = nest.flatten(state_names) |
| named_saved_states = dict( |
| zip(state_saver.saved_state.keys(), saved_states)) |
| |
| for i in range(8): |
| self.assertAllEqual(last_states[i], |
| named_saved_states[flat_state_names[i]]) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testProjNoSharding(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=False) |
| outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| self.assertEqual(len(outputs), len(inputs)) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| sess.run(outputs, feed_dict={inputs[0]: input_value}) |
| |
| def _testStateTupleWithProjAndSequenceLength(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| sequence_length = [4, 6] |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| cell_notuple = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=False) |
| cell_tuple = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=True) |
| with variable_scope.variable_scope("root") as scope: |
| outputs_notuple, state_notuple = rnn.static_rnn( |
| cell_notuple, |
| inputs, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=scope) |
| scope.reuse_variables() |
| # TODO(ebrevdo): For this test, we ensure values are identical and |
| # therefore the weights here are tied. In the future, we may consider |
| # making the state_is_tuple property mutable so we can avoid |
| # having to do this - especially if users ever need to reuse |
| # the parameters from different RNNCell instances. Right now, |
| # this seems an unrealistic use case except for testing. |
| cell_tuple._scope = cell_notuple._scope # pylint: disable=protected-access |
| outputs_tuple, state_tuple = rnn.static_rnn( |
| cell_tuple, |
| inputs, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=scope) |
| self.assertEqual(len(outputs_notuple), len(inputs)) |
| self.assertEqual(len(outputs_tuple), len(inputs)) |
| self.assertTrue(isinstance(state_tuple, tuple)) |
| self.assertTrue(isinstance(state_notuple, ops.Tensor)) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| outputs_notuple_v = sess.run( |
| outputs_notuple, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs_tuple_v = sess.run( |
| outputs_tuple, feed_dict={ |
| inputs[0]: input_value |
| }) |
| self.assertAllEqual(outputs_notuple_v, outputs_tuple_v) |
| |
| (state_notuple_v,) = sess.run( |
| (state_notuple,), feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value}) |
| self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testProjSharding(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| num_proj_shards = 3 |
| num_unit_shards = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| num_unit_shards=num_unit_shards, |
| num_proj_shards=num_proj_shards, |
| initializer=initializer, |
| state_is_tuple=False) |
| |
| outputs, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| |
| self.assertEqual(len(outputs), len(inputs)) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| sess.run(outputs, feed_dict={inputs[0]: input_value}) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDoubleInput(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| num_proj_shards = 3 |
| num_unit_shards = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float64, shape=(None, input_size)) |
| ] |
| |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| num_unit_shards=num_unit_shards, |
| num_proj_shards=num_proj_shards, |
| initializer=initializer, |
| state_is_tuple=False) |
| |
| outputs, _ = rnn.static_rnn( |
| cell, |
| inputs, |
| initial_state=cell.zero_state(batch_size, dtypes.float64)) |
| |
| self.assertEqual(len(outputs), len(inputs)) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.asarray( |
| np.random.randn(batch_size, input_size), dtype=np.float64) |
| values = sess.run(outputs, feed_dict={inputs[0]: input_value}) |
| self.assertEqual(values[0].dtype, input_value.dtype) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testShardNoShardEquivalentOutput(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| num_proj_shards = 3 |
| num_unit_shards = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| initializer = init_ops.constant_initializer(0.001) |
| |
| cell_noshard = rnn_cell.LSTMCell( |
| num_units, |
| num_proj=num_proj, |
| use_peepholes=True, |
| initializer=initializer, |
| num_unit_shards=num_unit_shards, |
| num_proj_shards=num_proj_shards, |
| state_is_tuple=False) |
| |
| cell_shard = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| initializer=initializer, |
| num_proj=num_proj, |
| state_is_tuple=False) |
| |
| with variable_scope.variable_scope("noshard_scope"): |
| outputs_noshard, state_noshard = rnn.static_rnn( |
| cell_noshard, inputs, dtype=dtypes.float32) |
| with variable_scope.variable_scope("shard_scope"): |
| outputs_shard, state_shard = rnn.static_rnn( |
| cell_shard, inputs, dtype=dtypes.float32) |
| |
| self.assertEqual(len(outputs_noshard), len(inputs)) |
| self.assertEqual(len(outputs_noshard), len(outputs_shard)) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| feeds = dict((x, input_value) for x in inputs) |
| values_noshard = sess.run(outputs_noshard, feed_dict=feeds) |
| values_shard = sess.run(outputs_shard, feed_dict=feeds) |
| state_values_noshard = sess.run([state_noshard], feed_dict=feeds) |
| state_values_shard = sess.run([state_shard], feed_dict=feeds) |
| self.assertEqual(len(values_noshard), len(values_shard)) |
| self.assertEqual(len(state_values_noshard), len(state_values_shard)) |
| for (v_noshard, v_shard) in zip(values_noshard, values_shard): |
| self.assertAllClose(v_noshard, v_shard, atol=1e-3) |
| for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard): |
| self.assertAllClose(s_noshard, s_shard, atol=1e-3) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDoubleInputWithDropoutAndDynamicCalculation(self): |
| """Smoke test for using LSTM with doubles, dropout, dynamic calculation.""" |
| |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| num_proj_shards = 3 |
| num_unit_shards = 2 |
| max_length = 8 |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| sequence_length = array_ops.placeholder(dtypes.int64) |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float64, shape=(None, input_size)) |
| ] |
| |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| num_unit_shards=num_unit_shards, |
| num_proj_shards=num_proj_shards, |
| initializer=initializer, |
| state_is_tuple=False) |
| dropout_cell = rnn_cell.DropoutWrapper(cell, 0.5, seed=0) |
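| # The positional 0.5 is the input keep probability; the output and state |
| # keep probabilities stay at their defaults of 1.0. |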
| |
| outputs, state = rnn.static_rnn( |
| dropout_cell, |
| inputs, |
| sequence_length=sequence_length, |
| initial_state=cell.zero_state(batch_size, dtypes.float64)) |
| |
| self.assertEqual(len(outputs), len(inputs)) |
| |
| variables_lib.global_variables_initializer().run(feed_dict={ |
| sequence_length: [2, 3] |
| }) |
| input_value = np.asarray( |
| np.random.randn(batch_size, input_size), dtype=np.float64) |
| values = sess.run( |
| outputs, feed_dict={ |
| inputs[0]: input_value, |
| sequence_length: [2, 3] |
| }) |
| state_value = sess.run( |
| [state], feed_dict={ |
| inputs[0]: input_value, |
| sequence_length: [2, 3] |
| }) |
| self.assertEqual(values[0].dtype, input_value.dtype) |
| self.assertEqual(state_value[0].dtype, input_value.dtype) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testSharingWeightsWithReuse(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) |
| initializer_d = init_ops.random_uniform_initializer( |
| -1, 1, seed=self._seed + 1) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=False) |
| cell_d = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer_d, |
| state_is_tuple=False) |
| |
| with variable_scope.variable_scope("share_scope"): |
| outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| with variable_scope.variable_scope("share_scope", reuse=True): |
| outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| with variable_scope.variable_scope("diff_scope"): |
| outputs2, _ = rnn.static_rnn(cell_d, inputs, dtype=dtypes.float32) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| output_values = sess.run( |
| outputs0 + outputs1 + outputs2, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs0_values = output_values[:max_length] |
| outputs1_values = output_values[max_length:2 * max_length] |
| outputs2_values = output_values[2 * max_length:] |
| self.assertEqual(len(outputs0_values), len(outputs1_values)) |
| self.assertEqual(len(outputs0_values), len(outputs2_values)) |
| for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values): |
| # Same weights used by both RNNs so outputs should be the same. |
| self.assertAllEqual(o1, o2) |
| # Different weights used so outputs should be different. |
| self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testSharingWeightsWithDifferentNamescope(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=False) |
| |
| with ops.name_scope("scope0"): |
| with variable_scope.variable_scope("share_scope"): |
| outputs0, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| with ops.name_scope("scope1"): |
| with variable_scope.variable_scope("share_scope", reuse=True): |
| outputs1, _ = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) |
| |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| output_values = sess.run( |
| outputs0 + outputs1, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs0_values = output_values[:max_length] |
| outputs1_values = output_values[max_length:] |
| self.assertEqual(len(outputs0_values), len(outputs1_values)) |
| for out0, out1 in zip(outputs0_values, outputs1_values): |
| self.assertAllEqual(out0, out1) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDynamicRNNAllowsUnknownTimeDimension(self): |
| inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20]) |
| cell = rnn_cell.GRUCell(30) |
| # Smoke test, this should not raise an error |
| rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDynamicRNNWithTupleStates(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| sequence_length = [4, 6] |
| in_graph_mode = not context.executing_eagerly() |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| if in_graph_mode: |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| else: |
| inputs = max_length * [ |
| constant_op.constant( |
| np.random.randn(batch_size, input_size).astype(np.float32)) |
| ] |
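| # Stack the per-step inputs into a single time-major tensor of shape |
| # [max_length, batch_size, input_size] for dynamic_rnn. |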
| inputs_c = array_ops.stack(inputs) |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| num_proj=num_proj, |
| initializer=initializer, |
| state_is_tuple=True) |
| with variable_scope.variable_scope("root") as scope: |
| outputs_static, state_static = rnn.static_rnn( |
| cell, |
| inputs, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=scope) |
| scope.reuse_variables() |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs_c, |
| dtype=dtypes.float32, |
| time_major=True, |
| sequence_length=sequence_length, |
| scope=scope) |
| self.assertTrue(isinstance(state_static, rnn_cell.LSTMStateTuple)) |
| self.assertTrue(isinstance(state_dynamic, rnn_cell.LSTMStateTuple)) |
| self.assertIs(state_static[0], state_static.c) |
| self.assertIs(state_static[1], state_static.h) |
| self.assertIs(state_dynamic[0], state_dynamic.c) |
| self.assertIs(state_dynamic[1], state_dynamic.h) |
| |
| if in_graph_mode: |
| variables_lib.global_variables_initializer().run() |
| input_value = np.random.randn(batch_size, input_size) |
| outputs_static = sess.run( |
| outputs_static, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs_dynamic = sess.run( |
| outputs_dynamic, feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_static = sess.run( |
| state_static, feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_dynamic = sess.run( |
| state_dynamic, feed_dict={ |
| inputs[0]: input_value |
| }) |
| |
| if in_graph_mode: |
| self.assertAllEqual(outputs_static, outputs_dynamic) |
| else: |
| self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) |
| self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDynamicRNNWithNestedTupleStates(self): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| num_proj = 4 |
| max_length = 8 |
| sequence_length = [4, 6] |
| in_graph_mode = not context.executing_eagerly() |
| with self.session(graph=ops.Graph()) as sess: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| if in_graph_mode: |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None, input_size)) |
| ] |
| else: |
| inputs = max_length * [ |
| constant_op.constant( |
| np.random.randn(batch_size, input_size).astype(np.float32)) |
| ] |
| inputs_c = array_ops.stack(inputs) |
| |
| def _cell(i): |
| return rnn_cell.LSTMCell( |
| num_units + i, |
| use_peepholes=True, |
| num_proj=num_proj + i, |
| initializer=initializer, |
| state_is_tuple=True) |
| |
| # This creates a state tuple which has 4 sub-tuples of length 2 each. |
| cell = rnn_cell.MultiRNNCell( |
| [_cell(i) for i in range(4)], state_is_tuple=True) |
| |
| self.assertEqual(len(cell.state_size), 4) |
| for i in range(4): |
| self.assertEqual(len(cell.state_size[i]), 2) |
| |
| test_zero = cell.zero_state(1, dtypes.float32) |
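| # The zero state mirrors cell.state_size: four (c, m) pairs whose widths |
| # are checked below. |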
| self.assertEqual(len(test_zero), 4) |
| for i in range(4): |
| self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0]) |
| self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1]) |
| |
| with variable_scope.variable_scope("root") as scope: |
| outputs_static, state_static = rnn.static_rnn( |
| cell, |
| inputs, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=scope) |
| scope.reuse_variables() |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs_c, |
| dtype=dtypes.float32, |
| time_major=True, |
| sequence_length=sequence_length, |
| scope=scope) |
| |
| if in_graph_mode: |
| input_value = np.random.randn(batch_size, input_size) |
| variables_lib.global_variables_initializer().run() |
| outputs_static = sess.run( |
| outputs_static, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs_dynamic = sess.run( |
| outputs_dynamic, feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_static = sess.run( |
| nest.flatten(state_static), feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_dynamic = sess.run( |
| nest.flatten(state_dynamic), feed_dict={ |
| inputs[0]: input_value |
| }) |
| |
| if in_graph_mode: |
| self.assertAllEqual(outputs_static, outputs_dynamic) |
| else: |
| self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) |
| state_static = nest.flatten(state_static) |
| state_dynamic = nest.flatten(state_dynamic) |
| self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) |
| |
| def _testDynamicEquivalentToStaticRNN(self, use_sequence_length): |
| time_steps = 8 |
| num_units = 3 |
| num_proj = 4 |
| input_size = 5 |
| batch_size = 2 |
| |
| input_values = np.random.randn(time_steps, batch_size, input_size).astype( |
| np.float32) |
| |
| if use_sequence_length: |
| sequence_length = np.random.randint(0, time_steps, size=batch_size) |
| else: |
| sequence_length = None |
| |
| in_graph_mode = not context.executing_eagerly() |
| |
| # TODO(b/68017812): Eager ignores operation seeds, so we need to create a |
| # single cell and reuse it across the static and dynamic RNNs. Remove this |
| # special case once b/68017812 is fixed. |
| if not in_graph_mode: |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| initializer=initializer, |
| num_proj=num_proj, |
| state_is_tuple=False) |
| |
| ########### Step 1: Run static graph and generate readouts |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| if in_graph_mode: |
| concat_inputs = array_ops.placeholder( |
| dtypes.float32, shape=(time_steps, batch_size, input_size)) |
| else: |
| concat_inputs = constant_op.constant(input_values) |
| inputs = array_ops.unstack(concat_inputs) |
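| # inputs is now a Python list with one [batch_size, input_size] tensor per |
| # time step, as expected by static_rnn. |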
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| |
| # TODO(akshayka): Remove special case once b/68017812 is fixed. |
| if in_graph_mode: |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| initializer=initializer, |
| num_proj=num_proj, |
| state_is_tuple=False) |
| |
| with variable_scope.variable_scope("dynamic_scope"): |
| outputs_static, state_static = rnn.static_rnn( |
| cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) |
| |
| if in_graph_mode: |
| # Generate gradients and run sessions to obtain outputs |
| feeds = {concat_inputs: input_values} |
| # Initialize |
| variables_lib.global_variables_initializer().run(feed_dict=feeds) |
| # Generate gradients of sum of outputs w.r.t. inputs |
| static_gradients = gradients_impl.gradients( |
| outputs_static + [state_static], [concat_inputs]) |
| # Generate gradients of individual outputs w.r.t. inputs |
| static_individual_gradients = nest.flatten([ |
| gradients_impl.gradients(y, [concat_inputs]) |
| for y in [outputs_static[0], outputs_static[-1], state_static] |
| ]) |
| # Generate gradients of individual variables w.r.t. inputs |
| trainable_variables = ops.get_collection( |
| ops.GraphKeys.TRAINABLE_VARIABLES) |
| assert len(trainable_variables) > 1, ( |
| "Count of trainable variables: %d" % len(trainable_variables)) |
| # pylint: disable=bad-builtin |
| static_individual_variable_gradients = nest.flatten([ |
| gradients_impl.gradients(y, trainable_variables) |
| for y in [outputs_static[0], outputs_static[-1], state_static] |
| ]) |
| # Test forward pass |
| values_static = sess.run(outputs_static, feed_dict=feeds) |
| (state_value_static,) = sess.run((state_static,), feed_dict=feeds) |
| |
| # Test gradients to inputs and variables w.r.t. outputs & final state |
| static_grad_values = sess.run(static_gradients, feed_dict=feeds) |
| |
| static_individual_grad_values = sess.run( |
| static_individual_gradients, feed_dict=feeds) |
| |
| static_individual_var_grad_values = sess.run( |
| static_individual_variable_gradients, feed_dict=feeds) |
| |
| ########## Step 2: Run dynamic graph and generate readouts |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| if in_graph_mode: |
| concat_inputs = array_ops.placeholder( |
| dtypes.float32, shape=(time_steps, batch_size, input_size)) |
| else: |
| concat_inputs = constant_op.constant(input_values) |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| |
| # TODO(akshayka): Remove this special case once b/68017812 is |
| # fixed. |
| if in_graph_mode: |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=True, |
| initializer=initializer, |
| num_proj=num_proj, |
| state_is_tuple=False) |
| |
| with variable_scope.variable_scope("dynamic_scope"): |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs=concat_inputs, |
| sequence_length=sequence_length, |
| time_major=True, |
| dtype=dtypes.float32) |
| split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps) |
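| # split_outputs_dynamic lists the per-step outputs so they can be compared |
| # one-to-one with the static_rnn outputs and gradients. |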
| |
| if in_graph_mode: |
| feeds = {concat_inputs: input_values} |
| |
| # Initialize |
| variables_lib.global_variables_initializer().run(feed_dict=feeds) |
| |
| # Generate gradients of sum of outputs w.r.t. inputs |
| dynamic_gradients = gradients_impl.gradients( |
| split_outputs_dynamic + [state_dynamic], [concat_inputs]) |
| |
| # Generate gradients of several individual outputs w.r.t. inputs |
| dynamic_individual_gradients = nest.flatten([ |
| gradients_impl.gradients(y, [concat_inputs]) |
| for y in [ |
| split_outputs_dynamic[0], split_outputs_dynamic[-1], |
| state_dynamic |
| ] |
| ]) |
| |
| # Generate gradients of individual variables w.r.t. inputs |
| trainable_variables = ops.get_collection( |
| ops.GraphKeys.TRAINABLE_VARIABLES) |
| assert len(trainable_variables) > 1, ( |
| "Count of trainable variables: %d" % len(trainable_variables)) |
| dynamic_individual_variable_gradients = nest.flatten([ |
| gradients_impl.gradients(y, trainable_variables) |
| for y in [ |
| split_outputs_dynamic[0], split_outputs_dynamic[-1], |
| state_dynamic |
| ] |
| ]) |
| |
| # Test forward pass |
| values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) |
| (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) |
| |
| # Test gradients to inputs and variables w.r.t. outputs & final state |
| dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) |
| |
| dynamic_individual_grad_values = sess.run( |
| dynamic_individual_gradients, feed_dict=feeds) |
| |
| dynamic_individual_var_grad_values = sess.run( |
| dynamic_individual_variable_gradients, feed_dict=feeds) |
| |
| ######### Step 3: Comparisons |
| if not in_graph_mode: |
| values_static = outputs_static |
| values_dynamic = split_outputs_dynamic |
| state_value_static = state_static |
| state_value_dynamic = state_dynamic |
| |
| self.assertEqual(len(values_static), len(values_dynamic)) |
| for (value_static, value_dynamic) in zip(values_static, values_dynamic): |
| self.assertAllClose(value_static, value_dynamic) |
| self.assertAllClose(state_value_static, state_value_dynamic) |
| |
| if in_graph_mode: |
| |
| self.assertAllClose(static_grad_values, dynamic_grad_values) |
| |
| self.assertEqual( |
| len(static_individual_grad_values), |
| len(dynamic_individual_grad_values)) |
| self.assertEqual( |
| len(static_individual_var_grad_values), |
| len(dynamic_individual_var_grad_values)) |
| |
| for i, (a, b) in enumerate( |
| zip(static_individual_grad_values, dynamic_individual_grad_values)): |
| tf_logging.info("Comparing individual gradients iteration %d" % i) |
| self.assertAllClose(a, b) |
| |
| for i, (a, b) in enumerate( |
| zip(static_individual_var_grad_values, |
| dynamic_individual_var_grad_values)): |
| tf_logging.info( |
| "Comparing individual variable gradients iteration %d" % i) |
| self.assertAllClose(a, b) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDynamicEquivalentToStaticRNN(self): |
| self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDynamicEquivalentToStaticRNNWithSequenceLength(self): |
| self._testDynamicEquivalentToStaticRNN(use_sequence_length=True) |
| |
| |
| class BidirectionalRNNTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| def _createBidirectionalRNN(self, use_shape, use_sequence_length, scope=None): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| sequence_length = array_ops.placeholder( |
| dtypes.int64) if use_sequence_length else None |
| cell_fw = rnn_cell.LSTMCell( |
| num_units, input_size, initializer=initializer, state_is_tuple=False) |
| cell_bw = rnn_cell.LSTMCell( |
| num_units, input_size, initializer=initializer, state_is_tuple=False) |
| inputs = max_length * [ |
| array_ops.placeholder( |
| dtypes.float32, |
| shape=(batch_size, input_size) if use_shape else (None, input_size)) |
| ] |
| outputs, state_fw, state_bw = rnn.static_bidirectional_rnn( |
| cell_fw, |
| cell_bw, |
| inputs, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=scope) |
| self.assertEqual(len(outputs), len(inputs)) |
| for out in outputs: |
| self.assertEqual(out.get_shape().as_list(), |
| [batch_size if use_shape else None, 2 * num_units]) |
| |
| input_value = np.random.randn(batch_size, input_size) |
| outputs = array_ops.stack(outputs) |
| |
| return input_value, inputs, outputs, state_fw, state_bw, sequence_length |
| |
| def _testBidirectionalRNN(self, use_shape): |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( |
| self._createBidirectionalRNN(use_shape, True)) |
| variables_lib.global_variables_initializer().run() |
| # Run with pre-specified sequence length of 2, 3 |
| out, s_fw, s_bw = sess.run( |
| [outputs, state_fw, state_bw], |
| feed_dict={ |
| inputs[0]: input_value, |
| sequence_length: [2, 3] |
| }) |
| |
| # Since the forward and backward LSTM cells were initialized with the |
| # same parameters, the forward and backward outputs have to be the same, |
| # but reversed in time. The format is output[time][batch][depth], and |
| # due to depth concatenation (as num_units=3 for both RNNs): |
| # - forward output: out[][][depth] for 0 <= depth < 3 |
| # - backward output: out[][][depth] for 3 <= depth < 6 |
| # |
| # First sequence in batch is length=2 |
| # Check that the time=0 forward output is equal to time=1 backward output |
| self.assertAllClose(out[0][0][0], out[1][0][3]) |
| self.assertAllClose(out[0][0][1], out[1][0][4]) |
| self.assertAllClose(out[0][0][2], out[1][0][5]) |
| # Check that the time=1 forward output is equal to time=0 backward output |
| self.assertAllClose(out[1][0][0], out[0][0][3]) |
| self.assertAllClose(out[1][0][1], out[0][0][4]) |
| self.assertAllClose(out[1][0][2], out[0][0][5]) |
| |
| # Second sequence in batch is length=3 |
| # Check that the time=0 forward output is equal to time=2 backward output |
| self.assertAllClose(out[0][1][0], out[2][1][3]) |
| self.assertAllClose(out[0][1][1], out[2][1][4]) |
| self.assertAllClose(out[0][1][2], out[2][1][5]) |
| # Check that the time=1 forward output is equal to time=1 backward output |
| self.assertAllClose(out[1][1][0], out[1][1][3]) |
| self.assertAllClose(out[1][1][1], out[1][1][4]) |
| self.assertAllClose(out[1][1][2], out[1][1][5]) |
| # Check that the time=2 forward output is equal to time=0 backward output |
| self.assertAllClose(out[2][1][0], out[0][1][3]) |
| self.assertAllClose(out[2][1][1], out[0][1][4]) |
| self.assertAllClose(out[2][1][2], out[0][1][5]) |
| # Via the reasoning above, the forward and backward final state should be |
| # exactly the same |
| self.assertAllClose(s_fw, s_bw) |
| |
| def _testBidirectionalRNNWithoutSequenceLength(self, use_shape): |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| input_value, inputs, outputs, state_fw, state_bw, _ = ( |
| self._createBidirectionalRNN(use_shape, False)) |
| variables_lib.global_variables_initializer().run() |
| out, s_fw, s_bw = sess.run( |
| [outputs, state_fw, state_bw], feed_dict={ |
| inputs[0]: input_value |
| }) |
| |
| # Since the forward and backward LSTM cells were initialized with the |
| # same parameters, the forward and backward outputs have to be the same, |
| # but reversed in time. The format is output[time][batch][depth], and |
| # due to depth concatenation (as num_units=3 for both RNNs): |
| # - forward output: out[][][depth] for 0 <= depth < 3 |
| # - backward output: out[][][depth] for 3 <= depth < 6 |
| # |
| # Both sequences in batch are length=8. Check that the time=i |
| # forward output is equal to time=8-1-i backward output |
| for i in range(8): |
| self.assertAllClose(out[i][0][0:3], out[8 - 1 - i][0][3:6]) |
| self.assertAllClose(out[i][1][0:3], out[8 - 1 - i][1][3:6]) |
| # Via the reasoning above, the forward and backward final state should be |
| # exactly the same |
| self.assertAllClose(s_fw, s_bw) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBidirectionalRNN(self): |
| self._testBidirectionalRNN(use_shape=False) |
| self._testBidirectionalRNN(use_shape=True) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBidirectionalRNNWithoutSequenceLength(self): |
| self._testBidirectionalRNNWithoutSequenceLength(use_shape=False) |
| self._testBidirectionalRNNWithoutSequenceLength(use_shape=True) |
| |
| def _createBidirectionalDynamicRNN(self, |
| use_shape, |
| use_state_tuple, |
| use_time_major, |
| use_sequence_length, |
| scope=None): |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| max_length = 8 |
| |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| sequence_length = ( |
| array_ops.placeholder(dtypes.int64) if use_sequence_length else None) |
| cell_fw = rnn_cell.LSTMCell( |
| num_units, initializer=initializer, state_is_tuple=use_state_tuple) |
| cell_bw = rnn_cell.LSTMCell( |
| num_units, initializer=initializer, state_is_tuple=use_state_tuple) |
| inputs = max_length * [ |
| array_ops.placeholder( |
| dtypes.float32, |
| shape=(batch_size if use_shape else None, input_size)) |
| ] |
| inputs_c = array_ops.stack(inputs) |
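| # The stacked inputs are time-major [max_length, batch, input_size]; they are |
| # transposed to batch-major below when use_time_major is False. |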
| if not use_time_major: |
| inputs_c = array_ops.transpose(inputs_c, [1, 0, 2]) |
| outputs, states = rnn.bidirectional_dynamic_rnn( |
| cell_fw, |
| cell_bw, |
| inputs_c, |
| sequence_length, |
| dtype=dtypes.float32, |
| time_major=use_time_major, |
| scope=scope) |
| outputs = array_ops.concat(outputs, 2) |
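| # Concatenate the forward and backward outputs along the depth dimension, |
| # giving 2 * num_units features per step. |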
| state_fw, state_bw = states |
| outputs_shape = [None, max_length, 2 * num_units] |
| if use_shape: |
| outputs_shape[0] = batch_size |
| if use_time_major: |
| outputs_shape[0], outputs_shape[1] = outputs_shape[1], outputs_shape[0] |
| self.assertEqual(outputs.get_shape().as_list(), outputs_shape) |
| |
| input_value = np.random.randn(batch_size, input_size) |
| |
| return input_value, inputs, outputs, state_fw, state_bw, sequence_length |
| |
| def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple, |
| use_time_major, use_sequence_length): |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( |
| self._createBidirectionalDynamicRNN( |
| use_shape, use_state_tuple, use_time_major, use_sequence_length)) |
| variables_lib.global_variables_initializer().run() |
| # Run with pre-specified sequence length of 2, 3 |
| feed_dict = ({sequence_length: [2, 3]} if use_sequence_length else {}) |
| feed_dict.update({inputs[0]: input_value}) |
| if use_state_tuple: |
| out, c_fw, m_fw, c_bw, m_bw = sess.run( |
| [outputs, state_fw[0], state_fw[1], state_bw[0], state_bw[1]], |
| feed_dict=feed_dict) |
| s_fw = (c_fw, m_fw) |
| s_bw = (c_bw, m_bw) |
| else: |
| feed_dict.update({inputs[0]: input_value}) |
| out, s_fw, s_bw = sess.run( |
| [outputs, state_fw, state_bw], feed_dict=feed_dict) |
| |
| # Since the forward and backward LSTM cells were initialized with the |
| # same parameters, the forward and backward outputs have to be the same, |
| # but reversed in time. The format is output[time][batch][depth], and |
| # due to depth concatenation (as num_units=3 for both RNNs): |
| # - forward output: out[][][depth] for 0 <= depth < 3 |
| # - backward output: out[][][depth] for 3 <= depth < 6 |
| # |
| if not use_time_major: |
| out = np.swapaxes(out, 0, 1) |
| |
| if use_sequence_length: |
| # First sequence in batch is length=2 |
| # Check that the t=0 forward output is equal to t=1 backward output |
| self.assertEqual(out[0][0][0], out[1][0][3]) |
| self.assertEqual(out[0][0][1], out[1][0][4]) |
| self.assertEqual(out[0][0][2], out[1][0][5]) |
| # Check that the t=1 forward output is equal to t=0 backward output |
| self.assertEqual(out[1][0][0], out[0][0][3]) |
| self.assertEqual(out[1][0][1], out[0][0][4]) |
| self.assertEqual(out[1][0][2], out[0][0][5]) |
| |
| # Second sequence in batch is length=3 |
| # Check that the t=0 forward output is equal to t=2 backward output |
| self.assertEqual(out[0][1][0], out[2][1][3]) |
| self.assertEqual(out[0][1][1], out[2][1][4]) |
| self.assertEqual(out[0][1][2], out[2][1][5]) |
| # Check that the t=1 forward output is equal to t=1 backward output |
| self.assertEqual(out[1][1][0], out[1][1][3]) |
| self.assertEqual(out[1][1][1], out[1][1][4]) |
| self.assertEqual(out[1][1][2], out[1][1][5]) |
| # Check that the t=2 forward output is equal to t=0 backward output |
| self.assertEqual(out[2][1][0], out[0][1][3]) |
| self.assertEqual(out[2][1][1], out[0][1][4]) |
| self.assertEqual(out[2][1][2], out[0][1][5]) |
| # Via the reasoning above, the forward and backward final state should |
| # be exactly the same |
| self.assertAllClose(s_fw, s_bw) |
| else: # not use_sequence_length |
| max_length = 8 # from _createBidirectionalDynamicRNN |
| for t in range(max_length): |
| self.assertAllEqual(out[t, :, 0:3], out[max_length - t - 1, :, 3:6]) |
| self.assertAllClose(s_fw, s_bw) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBidirectionalDynamicRNN(self): |
| # Generate all 2^4 option combinations, |
| # from (True, True, True, True) to (False, False, False, False). |
| options = itertools.product([True, False], repeat=4) |
| for option in options: |
| self._testBidirectionalDynamicRNN( |
| use_shape=option[0], |
| use_state_tuple=option[1], |
| use_time_major=option[2], |
| use_sequence_length=option[3]) |
| |
| def _testScope(self, factory, prefix="prefix", use_outer_scope=True): |
| # REMARKS: factory(scope) is a function accepting a scope |
| # as an argument; the scope can be None, a string, |
| # or a VariableScope instance. |
| with self.session(use_gpu=True, graph=ops.Graph()): |
| if use_outer_scope: |
| with variable_scope.variable_scope(prefix) as scope: |
| factory(scope) |
| else: |
| factory(prefix) |
| |
| # Check that all the variable names start |
| # with the proper scope. |
| variables_lib.global_variables_initializer() |
| all_vars = variables_lib.global_variables() |
| prefix = prefix or "bidirectional_rnn" |
| scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] |
| tf_logging.info("BiRNN with scope: %s (%s)" % |
| (prefix, "scope" if use_outer_scope else "str")) |
| for v in scope_vars: |
| tf_logging.info(v.name) |
| self.assertEqual(len(scope_vars), len(all_vars)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBidirectionalRNNScope(self): |
| |
| def factory(scope): |
| return self._createBidirectionalRNN( |
| use_shape=True, use_sequence_length=True, scope=scope) |
| |
| self._testScope(factory, use_outer_scope=True) |
| self._testScope(factory, use_outer_scope=False) |
| self._testScope(factory, prefix=None, use_outer_scope=False) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBidirectionalDynamicRNNScope(self): |
| |
| def get_factory(use_time_major): |
| |
| def factory(scope): |
| return self._createBidirectionalDynamicRNN( |
| use_shape=True, |
| use_state_tuple=True, |
| use_sequence_length=True, |
| use_time_major=use_time_major, |
| scope=scope) |
| |
| return factory |
| |
| self._testScope(get_factory(True), use_outer_scope=True) |
| self._testScope(get_factory(True), use_outer_scope=False) |
| self._testScope(get_factory(True), prefix=None, use_outer_scope=False) |
| self._testScope(get_factory(False), use_outer_scope=True) |
| self._testScope(get_factory(False), use_outer_scope=False) |
| self._testScope(get_factory(False), prefix=None, use_outer_scope=False) |
| |
| |
| class MultiDimensionalLSTMTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testMultiDimensionalLSTMAllRNNContainers(self): |
| feature_dims = (3, 4, 5) |
| input_size = feature_dims |
| batch_size = 2 |
| max_length = 8 |
| sequence_length = [4, 6] |
| with self.session(graph=ops.Graph()) as sess: |
| inputs = max_length * [ |
| array_ops.placeholder(dtypes.float32, shape=(None,) + input_size) |
| ] |
| inputs_using_dim = max_length * [ |
| array_ops.placeholder( |
| dtypes.float32, shape=(batch_size,) + input_size) |
| ] |
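| # Stack the per-time-step placeholders into a single time-major tensor |
| # for dynamic_rnn. |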
| inputs_c = array_ops.stack(inputs) |
| # Create a cell for the whole test. This is fine because the cell has no |
| # variables. |
| cell = DummyMultiDimensionalLSTM(feature_dims) |
| state_saver = TestStateSaver(batch_size, input_size) |
| outputs_static, state_static = rnn.static_rnn( |
| cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length) |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs_c, |
| dtype=dtypes.float32, |
| time_major=True, |
| sequence_length=sequence_length) |
| outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn( |
| cell, |
| cell, |
| inputs_using_dim, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length) |
| outputs_sav, state_sav = rnn.static_state_saving_rnn( |
| cell, |
| inputs_using_dim, |
| sequence_length=sequence_length, |
| state_saver=state_saver, |
| state_name=("h", "c")) |
| |
| self.assertEqual(outputs_dynamic.get_shape().as_list(), |
| inputs_c.get_shape().as_list()) |
| for out, inp in zip(outputs_static, inputs): |
| self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) |
| for out, inp in zip(outputs_bid, inputs_using_dim): |
| input_shape_list = inp.get_shape().as_list() |
| # fwd and bwd activations are concatenated along the second dim. |
| input_shape_list[1] *= 2 |
| self.assertEqual(out.get_shape().as_list(), input_shape_list) |
| |
| variables_lib.global_variables_initializer().run() |
| |
| input_total_size = (batch_size,) + input_size |
| input_value = np.random.randn(*input_total_size) |
| outputs_static_v = sess.run( |
| outputs_static, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs_dynamic_v = sess.run( |
| outputs_dynamic, feed_dict={ |
| inputs[0]: input_value |
| }) |
| outputs_bid_v = sess.run( |
| outputs_bid, feed_dict={ |
| inputs_using_dim[0]: input_value |
| }) |
| outputs_sav_v = sess.run( |
| outputs_sav, feed_dict={ |
| inputs_using_dim[0]: input_value |
| }) |
| |
| self.assertAllEqual(outputs_static_v, outputs_dynamic_v) |
| self.assertAllEqual(outputs_static_v, outputs_sav_v) |
| outputs_static_array = np.array(outputs_static_v) |
| outputs_static_array_double = np.concatenate( |
| (outputs_static_array, outputs_static_array), axis=2) |
| outputs_bid_array = np.array(outputs_bid_v) |
| self.assertAllEqual(outputs_static_array_double, outputs_bid_array) |
| |
| state_static_v = sess.run( |
| state_static, feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_dynamic_v = sess.run( |
| state_dynamic, feed_dict={ |
| inputs[0]: input_value |
| }) |
| state_bid_fw_v = sess.run( |
| state_fw, feed_dict={ |
| inputs_using_dim[0]: input_value |
| }) |
| state_bid_bw_v = sess.run( |
| state_bw, feed_dict={ |
| inputs_using_dim[0]: input_value |
| }) |
| state_sav_v = sess.run( |
| state_sav, feed_dict={ |
| inputs_using_dim[0]: input_value |
| }) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) |
| |
| |
| class NestedLSTMTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testNestedIOLSTMAllRNNContainers(self): |
| input_size = 5 |
| batch_size = 2 |
| state_size = 6 |
| max_length = 8 |
| sequence_length = [4, 6] |
| with self.session(graph=ops.Graph()) as sess: |
| state_saver = TestStateSaver(batch_size, state_size) |
| single_input = (array_ops.placeholder( |
| dtypes.float32, shape=(None, input_size)), |
| array_ops.placeholder( |
| dtypes.float32, shape=(None, input_size))) |
| inputs = max_length * [single_input] |
| inputs_c = (array_ops.stack([input_[0] for input_ in inputs]), |
| array_ops.stack([input_[1] for input_ in inputs])) |
| single_input_using_dim = (array_ops.placeholder( |
| dtypes.float32, shape=(batch_size, input_size)), |
| array_ops.placeholder( |
| dtypes.float32, |
| shape=(batch_size, input_size))) |
| inputs_using_dim = max_length * [single_input_using_dim] |
| |
| # Create a cell for the whole test. This is fine because the cell has no |
| # variables. |
| cell = NestedRNNCell() |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs_c, |
| dtype=dtypes.float32, |
| time_major=True, |
| sequence_length=sequence_length) |
| outputs_static, state_static = rnn.static_rnn( |
| cell, inputs, dtype=dtypes.float32, sequence_length=sequence_length) |
| outputs_bid, state_fw, state_bw = rnn.static_bidirectional_rnn( |
| cell, |
| cell, |
| inputs_using_dim, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length) |
| outputs_sav, state_sav = rnn.static_state_saving_rnn( |
| cell, |
| inputs_using_dim, |
| sequence_length=sequence_length, |
| state_saver=state_saver, |
| state_name=("h", "c")) |
| |
| def _assert_same_shape(input1, input2, double=False): |
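| # Checks that corresponding flattened elements of input1 and input2 have |
| # the same shape; with double=True the second dimension of input1 is |
| # doubled first, to account for the forward/backward concatenation in |
| # the bidirectional outputs. |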
| flat_input1 = nest.flatten(input1) |
| flat_input2 = nest.flatten(input2) |
| for inp1, inp2 in zip(flat_input1, flat_input2): |
| input_shape = inp1.get_shape().as_list() |
| if double: |
| input_shape[1] *= 2 |
| self.assertEqual(input_shape, inp2.get_shape().as_list()) |
| |
| _assert_same_shape(inputs_c, outputs_dynamic) |
| _assert_same_shape(inputs, outputs_static) |
| _assert_same_shape(inputs_using_dim, outputs_sav) |
| _assert_same_shape(inputs_using_dim, outputs_bid, double=True) |
| |
| variables_lib.global_variables_initializer().run() |
| |
| input_total_size = (batch_size, input_size) |
| input_value = (np.random.randn(*input_total_size), |
| np.random.randn(*input_total_size)) |
| outputs_dynamic_v = sess.run( |
| outputs_dynamic, feed_dict={ |
| single_input: input_value |
| }) |
| outputs_static_v = sess.run( |
| outputs_static, feed_dict={ |
| single_input: input_value |
| }) |
| outputs_sav_v = sess.run( |
| outputs_sav, feed_dict={ |
| single_input_using_dim: input_value |
| }) |
| outputs_bid_v = sess.run( |
| outputs_bid, feed_dict={ |
| single_input_using_dim: input_value |
| }) |
| |
| self.assertAllEqual(outputs_static_v, |
| np.transpose(outputs_dynamic_v, (1, 0, 2, 3))) |
| self.assertAllEqual(outputs_static_v, outputs_sav_v) |
| outputs_static_array = np.array(outputs_static_v) |
| outputs_static_array_double = np.concatenate( |
| (outputs_static_array, outputs_static_array), axis=3) |
| outputs_bid_array = np.array(outputs_bid_v) |
| self.assertAllEqual(outputs_static_array_double, outputs_bid_array) |
| |
| state_dynamic_v = sess.run( |
| state_dynamic, feed_dict={ |
| single_input: input_value |
| }) |
| state_static_v = sess.run( |
| state_static, feed_dict={ |
| single_input: input_value |
| }) |
| state_bid_fw_v = sess.run( |
| state_fw, feed_dict={ |
| single_input_using_dim: input_value |
| }) |
| state_bid_bw_v = sess.run( |
| state_bw, feed_dict={ |
| single_input_using_dim: input_value |
| }) |
| state_sav_v = sess.run( |
| state_sav, feed_dict={ |
| single_input_using_dim: input_value |
| }) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) |
| self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) |
| |
| |
| class StateSaverRNNTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| def _factory(self, scope, state_saver): |
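| # The state saver holds the flat (non-tuple) LSTM state, i.e. c and h |
| # concatenated, so num_units is half of state_saver.state_size. |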
| num_units = state_saver.state_size // 2 |
| batch_size = state_saver.batch_size |
| input_size = 5 |
| max_length = 8 |
| initializer = init_ops.random_uniform_initializer( |
| -0.01, 0.01, seed=self._seed) |
| cell = rnn_cell.LSTMCell( |
| num_units, |
| use_peepholes=False, |
| initializer=initializer, |
| state_is_tuple=False) |
| inputs = max_length * [ |
| array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size)) |
| ] |
| out, state = rnn.static_state_saving_rnn( |
| cell, |
| inputs, |
| state_saver=state_saver, |
| state_name="save_lstm", |
| scope=scope) |
| return out, state, state_saver |
| |
| def _testScope(self, prefix="prefix", use_outer_scope=True): |
| num_units = 3 |
| batch_size = 2 |
| state_saver = TestStateSaver(batch_size, 2 * num_units) |
| |
| with self.session(use_gpu=True, graph=ops.Graph()): |
| if use_outer_scope: |
| with variable_scope.variable_scope(prefix) as scope: |
| self._factory(scope=scope, state_saver=state_saver) |
| else: |
| self._factory(scope=prefix, state_saver=state_saver) |
| variables_lib.global_variables_initializer() |
| |
| # Check that all the variable names start |
| # with the proper scope. |
| all_vars = variables_lib.global_variables() |
| prefix = prefix or "rnn" |
| scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] |
| tf_logging.info("RNN with scope: %s (%s)" % |
| (prefix, "scope" if use_outer_scope else "str")) |
| for v in scope_vars: |
| tf_logging.info(v.name) |
| self.assertEqual(len(scope_vars), len(all_vars)) |
| |
| def testStateSaverRNNScope(self): |
| self._testScope(use_outer_scope=True) |
| self._testScope(use_outer_scope=False) |
| self._testScope(prefix=None, use_outer_scope=False) |
| |
| def testStateSaverCallsSaveState(self): |
| """Test that number of calls to state and save_state is equal. |
| |
| Test if the order of actual evaluating or skipping evaluation of out, |
| state tensors, which are the output tensors from static_state_saving_rnn, |
| have influence on number of calls to save_state and state methods of |
| state_saver object (the number of calls should be same.) |
| """ |
| self.skipTest("b/124196246 Breakage for sess.run([out, ...]): 2 != 1") |
| |
| num_units = 3 |
| batch_size = 2 |
| state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units) |
| out, state, state_saver = self._factory(scope=None, state_saver=state_saver) |
| |
| with self.cached_session() as sess: |
| sess.run(variables_lib.global_variables_initializer()) |
| sess.run(variables_lib.local_variables_initializer()) |
| |
| _, _, num_state_calls, num_save_state_calls = sess.run([ |
| out, |
| state, |
| state_saver.num_state_calls, |
| state_saver.num_save_state_calls]) |
| self.assertEqual(num_state_calls, num_save_state_calls) |
| |
| _, num_state_calls, num_save_state_calls = sess.run([ |
| out, |
| state_saver.num_state_calls, |
| state_saver.num_save_state_calls]) |
| self.assertEqual(num_state_calls, num_save_state_calls) |
| |
| _, num_state_calls, num_save_state_calls = sess.run([ |
| state, |
| state_saver.num_state_calls, |
| state_saver.num_save_state_calls]) |
| self.assertEqual(num_state_calls, num_save_state_calls) |
| |
| |
| class GRUTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDynamic(self): |
| time_steps = 8 |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| |
| input_values = np.random.randn(time_steps, batch_size, input_size) |
| |
| sequence_length = np.random.randint(0, time_steps, size=batch_size) |
| |
| with self.session(use_gpu=True, graph=ops.Graph()) as sess: |
| concat_inputs = array_ops.placeholder( |
| dtypes.float32, shape=(time_steps, batch_size, input_size)) |
| |
| cell = rnn_cell.GRUCell(num_units=num_units) |
| |
| with variable_scope.variable_scope("dynamic_scope"): |
| outputs_dynamic, state_dynamic = rnn.dynamic_rnn( |
| cell, |
| inputs=concat_inputs, |
| sequence_length=sequence_length, |
| time_major=True, |
| dtype=dtypes.float32) |
| |
| feeds = {concat_inputs: input_values} |
| |
| # Initialize |
| variables_lib.global_variables_initializer().run(feed_dict=feeds) |
| |
| sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds) |
| |
| def _testScope(self, factory, prefix="prefix", use_outer_scope=True): |
| with self.session(use_gpu=True, graph=ops.Graph()): |
| if use_outer_scope: |
| with variable_scope.variable_scope(prefix) as scope: |
| factory(scope) |
| else: |
| factory(prefix) |
| variables_lib.global_variables_initializer() |
| |
| # Check that all the variable names start |
| # with the proper scope. |
| all_vars = variables_lib.global_variables() |
| prefix = prefix or "rnn" |
| scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] |
| tf_logging.info("RNN with scope: %s (%s)" % |
| (prefix, "scope" if use_outer_scope else "str")) |
| for v in scope_vars: |
| tf_logging.info(v.name) |
| self.assertEqual(len(scope_vars), len(all_vars)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDynamicScope(self): |
| time_steps = 8 |
| num_units = 3 |
| input_size = 5 |
| batch_size = 2 |
| sequence_length = np.random.randint(0, time_steps, size=batch_size) |
| |
| def factory(scope): |
| concat_inputs = array_ops.placeholder( |
| dtypes.float32, shape=(time_steps, batch_size, input_size)) |
| cell = rnn_cell.GRUCell(num_units=num_units) |
| return rnn.dynamic_rnn( |
| cell, |
| inputs=concat_inputs, |
| sequence_length=sequence_length, |
| time_major=True, |
| dtype=dtypes.float32, |
| scope=scope) |
| |
| self._testScope(factory, use_outer_scope=True) |
| self._testScope(factory, use_outer_scope=False) |
| self._testScope(factory, prefix=None, use_outer_scope=False) |
| |
| |
| class RawRNNTest(test.TestCase): |
| |
| def setUp(self): |
| self._seed = 23489 |
| np.random.seed(self._seed) |
| |
| @test_util.run_v1_only("b/124229375") |
| def _testRawRNN(self, max_time): |
| with self.session(graph=ops.Graph()) as sess: |
| batch_size = 16 |
| input_depth = 4 |
| num_units = 3 |
| |
| inputs = array_ops.placeholder( |
| shape=(max_time, batch_size, input_depth), dtype=dtypes.float32) |
| sequence_length = array_ops.placeholder( |
| shape=(batch_size,), dtype=dtypes.int32) |
| inputs_ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) |
| inputs_ta = inputs_ta.unstack(inputs) |
| |
| cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) |
| |
| def loop_fn(time_, cell_output, cell_state, unused_loop_state): |
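| # raw_rnn calls loop_fn once before the first time step (cell_output is |
| # None) and then once per step; it must return (elements_finished, |
| # next_input, next_cell_state, emit_output, next_loop_state). |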
| emit_output = cell_output # == None for time == 0 |
| if cell_output is None: # time == 0 |
| next_state = cell.zero_state(batch_size, dtypes.float32) |
| else: |
| next_state = cell_state # copy state through |
| elements_finished = (time_ >= sequence_length) |
| finished = math_ops.reduce_all(elements_finished) |
| # For the very final iteration, we must emit a dummy input |
| next_input = control_flow_ops.cond( |
| finished, |
| lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), |
| lambda: inputs_ta.read(time_)) |
| return (elements_finished, next_input, next_state, emit_output, None) |
| |
| reuse_scope = variable_scope.get_variable_scope() |
| |
| outputs_ta, final_state, _ = rnn.raw_rnn(cell, loop_fn, scope=reuse_scope) |
| outputs = outputs_ta.stack() |
| |
| reuse_scope.reuse_variables() |
| outputs_dynamic_rnn, final_state_dynamic_rnn = rnn.dynamic_rnn( |
| cell, |
| inputs, |
| time_major=True, |
| dtype=dtypes.float32, |
| sequence_length=sequence_length, |
| scope=reuse_scope) |
| |
| variables = variables_lib.trainable_variables() |
| gradients = gradients_impl.gradients([outputs, final_state], |
| [inputs] + variables) |
| gradients_dynamic_rnn = gradients_impl.gradients( |
| [outputs_dynamic_rnn, final_state_dynamic_rnn], [inputs] + variables) |
| |
| variables_lib.global_variables_initializer().run() |
| |
| rand_input = np.random.randn(max_time, batch_size, input_depth) |
| if max_time == 0: |
| rand_seq_len = np.zeros(batch_size) |
| else: |
| rand_seq_len = np.random.randint(max_time, size=batch_size) |
| |
| # To ensure same output lengths for dynamic_rnn and raw_rnn |
| rand_seq_len[0] = max_time |
| |
| (outputs_val, outputs_dynamic_rnn_val, final_state_val, |
| final_state_dynamic_rnn_val) = sess.run( |
| [outputs, outputs_dynamic_rnn, final_state, final_state_dynamic_rnn], |
| feed_dict={ |
| inputs: rand_input, |
| sequence_length: rand_seq_len |
| }) |
| |
| self.assertAllClose(outputs_dynamic_rnn_val, outputs_val) |
| self.assertAllClose(final_state_dynamic_rnn_val, final_state_val) |
| |
| # NOTE: Because with 0 time steps, raw_rnn does not have shape |
| # information about the input, it is impossible to perform |
| # gradients comparisons as the gradients eval will fail. So |
| # this case skips the gradients test. |
| if max_time > 0: |
| self.assertEqual(len(gradients), len(gradients_dynamic_rnn)) |
| gradients_val = sess.run( |
| gradients, |
| feed_dict={ |
| inputs: rand_input, |
| sequence_length: rand_seq_len |
| }) |
| gradients_dynamic_rnn_val = sess.run( |
| gradients_dynamic_rnn, |
| feed_dict={ |
| inputs: rand_input, |
| sequence_length: rand_seq_len |
| }) |
| self.assertEqual(len(gradients_val), len(gradients_dynamic_rnn_val)) |
| input_gradients_val = gradients_val[0] |
| input_gradients_dynamic_rnn_val = gradients_dynamic_rnn_val[0] |
| self.assertAllClose(input_gradients_val, |
| input_gradients_dynamic_rnn_val) |
| for i in range(1, len(gradients_val)): |
| self.assertAllClose(gradients_dynamic_rnn_val[i], gradients_val[i]) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testRawRNNZeroLength(self): |
| # NOTE: Because with 0 time steps, raw_rnn does not have shape |
| # information about the input, it is impossible to perform |
| # gradients comparisons as the gradients eval will fail. So this |
| # case skips the gradients test. |
| self._testRawRNN(max_time=0) |
| |
| def testRawRNN(self): |
| self._testRawRNN(max_time=10) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testLoopState(self): |
| with self.session(graph=ops.Graph()): |
| max_time = 10 |
| batch_size = 16 |
| input_depth = 4 |
| num_units = 3 |
| |
| inputs = np.random.randn(max_time, batch_size, input_depth) |
| inputs_ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) |
| inputs_ta = inputs_ta.unstack(inputs) |
| |
| cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) |
| |
| def loop_fn(time_, cell_output, cell_state, loop_state): |
| if cell_output is None: |
| loop_state = constant_op.constant([0]) |
| next_state = cell.zero_state(batch_size, dtypes.float32) |
| else: |
| loop_state = array_ops.stack([array_ops.squeeze(loop_state) + 1]) |
| next_state = cell_state |
| emit_output = cell_output # == None for time == 0 |
| elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) |
| finished = math_ops.reduce_all(elements_finished) |
| # For the very final iteration, we must emit a dummy input |
| next_input = control_flow_ops.cond( |
| finished, |
| lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), |
| lambda: inputs_ta.read(time_)) |
| return (elements_finished, next_input, next_state, emit_output, |
| loop_state) |
| |
| r = rnn.raw_rnn(cell, loop_fn) |
| loop_state = r[-1] |
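| # loop_state starts at [0] and is incremented by 1 at every time step, |
| # so after max_time steps it should equal [max_time]. |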
| self.assertEqual([10], loop_state.eval()) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testLoopStateWithTensorArray(self): |
| with self.session(graph=ops.Graph()): |
| max_time = 4 |
| batch_size = 16 |
| input_depth = 4 |
| num_units = 3 |
| |
| inputs = np.random.randn(max_time, batch_size, input_depth) |
| inputs_ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) |
| inputs_ta = inputs_ta.unstack(inputs) |
| |
| cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) |
| |
| def loop_fn(time_, cell_output, cell_state, loop_state): |
| if cell_output is None: |
| loop_state = tensor_array_ops.TensorArray( |
| dynamic_size=True, |
| size=0, |
| dtype=dtypes.int32, |
| clear_after_read=False) |
| loop_state = loop_state.write(0, 1) |
| next_state = cell.zero_state(batch_size, dtypes.float32) |
| else: |
| loop_state = loop_state.write(time_, |
| loop_state.read(time_ - 1) + time_) |
| next_state = cell_state |
| emit_output = cell_output # == None for time == 0 |
| elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) |
| finished = math_ops.reduce_all(elements_finished) |
| # For the very final iteration, we must emit a dummy input |
| next_input = control_flow_ops.cond( |
| finished, |
| lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), |
| lambda: inputs_ta.read(time_)) |
| return (elements_finished, next_input, next_state, emit_output, |
| loop_state) |
| |
| r = rnn.raw_rnn(cell, loop_fn) |
| loop_state = r[-1] |
| loop_state = loop_state.stack() |
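| # The loop state TensorArray is written with loop_state[0] == 1 and |
| # loop_state[t] == loop_state[t - 1] + t, so the stacked values should |
| # be [1, 2, 4, 7, 11]. |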
| self.assertAllEqual([1, 2, 2 + 2, 4 + 3, 7 + 4], loop_state.eval()) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testEmitDifferentStructureThanCellOutput(self): |
| with self.session(graph=ops.Graph()) as sess: |
| max_time = 10 |
| batch_size = 16 |
| input_depth = 4 |
| num_units = 3 |
| |
| inputs = np.random.randn(max_time, batch_size, input_depth) |
| inputs_ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) |
| inputs_ta = inputs_ta.unstack(inputs) |
| # Verify emit shapes may be unknown by feeding a placeholder that |
| # determines an emit shape. |
| unknown_dim = array_ops.placeholder(dtype=dtypes.int32) |
| |
| cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) |
| |
| def loop_fn(time_, cell_output, cell_state, _): |
| if cell_output is None: |
| emit_output = (array_ops.zeros([2, 3], dtype=dtypes.int32), |
| array_ops.zeros([unknown_dim], dtype=dtypes.int64)) |
| next_state = cell.zero_state(batch_size, dtypes.float32) |
| else: |
| emit_output = (array_ops.ones([batch_size, 2, 3], dtype=dtypes.int32), |
| array_ops.ones( |
| [batch_size, unknown_dim], dtype=dtypes.int64)) |
| next_state = cell_state |
| elements_finished = array_ops.tile([time_ >= max_time], [batch_size]) |
| finished = math_ops.reduce_all(elements_finished) |
| # For the very final iteration, we must emit a dummy input |
| next_input = control_flow_ops.cond( |
| finished, |
| lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), |
| lambda: inputs_ta.read(time_)) |
| return (elements_finished, next_input, next_state, emit_output, None) |
| |
| r = rnn.raw_rnn(cell, loop_fn) |
| output_ta = r[0] |
| self.assertEqual(2, len(output_ta)) |
| self.assertEqual([dtypes.int32, dtypes.int64], |
| [ta.dtype for ta in output_ta]) |
| output = [ta.stack() for ta in output_ta] |
| output_vals = sess.run(output, feed_dict={unknown_dim: 1}) |
| self.assertAllEqual( |
| np.ones((max_time, batch_size, 2, 3), np.int32), output_vals[0]) |
| self.assertAllEqual( |
| np.ones((max_time, batch_size, 1), np.int64), output_vals[1]) |
| |
| def _testScope(self, factory, prefix="prefix", use_outer_scope=True): |
| with self.session(use_gpu=True, graph=ops.Graph()): |
| if use_outer_scope: |
| with variable_scope.variable_scope(prefix) as scope: |
| factory(scope) |
| else: |
| factory(prefix) |
| variables_lib.global_variables_initializer() |
| |
| # Check that all the variable names start |
| # with the proper scope. |
| all_vars = variables_lib.global_variables() |
| prefix = prefix or "rnn" |
| scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] |
| tf_logging.info("RNN with scope: %s (%s)" % |
| (prefix, "scope" if use_outer_scope else "str")) |
| for v in scope_vars: |
| tf_logging.info(v.name) |
| self.assertEqual(len(scope_vars), len(all_vars)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testRawRNNScope(self): |
| max_time = 10 |
| batch_size = 16 |
| input_depth = 4 |
| num_units = 3 |
| |
| def factory(scope): |
| inputs = array_ops.placeholder( |
| shape=(max_time, batch_size, input_depth), dtype=dtypes.float32) |
| sequence_length = array_ops.placeholder( |
| shape=(batch_size,), dtype=dtypes.int32) |
| inputs_ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=array_ops.shape(inputs)[0]) |
| inputs_ta = inputs_ta.unstack(inputs) |
| |
| cell = rnn_cell.LSTMCell(num_units, state_is_tuple=True) |
| |
| def loop_fn(time_, cell_output, cell_state, unused_loop_state): |
| emit_output = cell_output # == None for time == 0 |
| if cell_output is None: # time == 0 |
| next_state = cell.zero_state(batch_size, dtypes.float32) |
| else: |
| next_state = cell_state |
| |
| elements_finished = (time_ >= sequence_length) |
| finished = math_ops.reduce_all(elements_finished) |
| # For the very final iteration, we must emit a dummy input |
| next_input = control_flow_ops.cond( |
| finished, |
| lambda: array_ops.zeros([batch_size, input_depth], dtype=dtypes.float32), |
| lambda: inputs_ta.read(time_)) |
| return (elements_finished, next_input, next_state, emit_output, None) |
| |
| return rnn.raw_rnn(cell, loop_fn, scope=scope) |
| |
| self._testScope(factory, use_outer_scope=True) |
| self._testScope(factory, use_outer_scope=False) |
| self._testScope(factory, prefix=None, use_outer_scope=False) |
| |
| |
| class DeviceWrapperCell(rnn_cell.RNNCell): |
| """Class to ensure cell calculation happens on a specific device.""" |
| |
| def __init__(self, cell, device): |
| self._cell = cell |
| self._device = device |
| |
| @property |
| def output_size(self): |
| return self._cell.output_size |
| |
| @property |
| def state_size(self): |
| return self._cell.state_size |
| |
| def __call__(self, input_, state, scope=None): |
| if self._device is not None: |
| with ops.device(self._device): |
| return self._cell(input_, state, scope=scope) |
| else: |
| return self._cell(input_, state, scope=scope) |
| |
| |
| class TensorArrayOnCorrectDeviceTest(test.TestCase): |
| |
| def _execute_rnn_on(self, |
| rnn_device=None, |
| cell_device=None, |
| input_device=None): |
| batch_size = 3 |
| time_steps = 7 |
| input_size = 5 |
| num_units = 10 |
| |
| cell = rnn_cell.LSTMCell(num_units, use_peepholes=True) |
| gpu_cell = DeviceWrapperCell(cell, cell_device) |
| inputs = np.random.randn(batch_size, time_steps, input_size).astype( |
| np.float32) |
| sequence_length = np.random.randint(0, time_steps, size=batch_size) |
| |
| if input_device is not None: |
| with ops.device(input_device): |
| inputs = constant_op.constant(inputs) |
| |
| if rnn_device is not None: |
| with ops.device(rnn_device): |
| outputs, _ = rnn.dynamic_rnn( |
| gpu_cell, |
| inputs, |
| sequence_length=sequence_length, |
| dtype=dtypes.float32) |
| else: |
| outputs, _ = rnn.dynamic_rnn( |
| gpu_cell, |
| inputs, |
| sequence_length=sequence_length, |
| dtype=dtypes.float32) |
| |
| with self.session(use_gpu=True) as sess: |
| opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) |
| run_metadata = config_pb2.RunMetadata() |
| variables_lib.global_variables_initializer().run() |
| sess.run(outputs, options=opts, run_metadata=run_metadata) |
| |
| return run_metadata |
| |
| def _retrieve_cpu_gpu_stats(self, run_metadata): |
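| # Split the per-device step stats from run_metadata into node stats for |
| # the device ending in "cpu:0" and the device ending in "gpu:0" (None if |
| # no matching device recorded stats). |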
| cpu_stats = None |
| gpu_stats = None |
| step_stats = run_metadata.step_stats |
| for ds in step_stats.dev_stats: |
| if "cpu:0" in ds.device[-5:].lower(): |
| cpu_stats = ds.node_stats |
| if "gpu:0" == ds.device[-5:].lower(): |
| gpu_stats = ds.node_stats |
| return cpu_stats, gpu_stats |
| |
| @test_util.run_v1_only("b/124229375") |
| def testRNNOnCPUCellOnGPU(self): |
| if not test.is_gpu_available(): |
| return # Test requires access to a GPU |
| |
| gpu_dev = test.gpu_device_name() |
| run_metadata = self._execute_rnn_on( |
| rnn_device="/cpu:0", cell_device=gpu_dev) |
| cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) |
| |
| def _assert_in(op_str, in_stats, out_stats): |
| self.assertTrue(any(op_str in s.node_name for s in in_stats)) |
| self.assertFalse(any(op_str in s.node_name for s in out_stats)) |
| |
| # Writes happen at output of RNN cell |
| _assert_in("TensorArrayWrite", gpu_stats, cpu_stats) |
| # Gather happens on final TensorArray |
| _assert_in("TensorArrayGather", gpu_stats, cpu_stats) |
| # Reads happen at input to RNN cell |
| _assert_in("TensorArrayRead", cpu_stats, gpu_stats) |
| # Scatters happen to get initial input into TensorArray |
| _assert_in("TensorArrayScatter", cpu_stats, gpu_stats) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testRNNOnCPUCellOnCPU(self): |
| if not test.is_gpu_available(): |
| return # Test requires access to a GPU |
| |
| gpu_dev = test.gpu_device_name() |
| run_metadata = self._execute_rnn_on( |
| rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) |
| cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) |
| |
| def _assert_in(op_str, in_stats, out_stats): |
| self.assertTrue(any(op_str in s.node_name for s in in_stats)) |
| self.assertFalse(any(op_str in s.node_name for s in out_stats)) |
| |
| # All TensorArray operations happen on CPU |
| _assert_in("TensorArray", cpu_stats, gpu_stats) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testInputOnGPUCellNotDeclared(self): |
| if not test.is_gpu_available(): |
| return # Test requires access to a GPU |
| |
| gpu_dev = test.gpu_device_name() |
| run_metadata = self._execute_rnn_on(input_device=gpu_dev) |
| cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) |
| |
| def _assert_in(op_str, in_stats, out_stats): |
| self.assertTrue(any(op_str in s.node_name for s in in_stats)) |
| self.assertFalse(any(op_str in s.node_name for s in out_stats)) |
| |
| # Everything happens on GPU |
| _assert_in("TensorArray", gpu_stats, cpu_stats) |
| |
| |
| class RNNCellTest(test.TestCase, parameterized.TestCase): |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicRNNCell(self): |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m = array_ops.zeros([1, 2]) |
| cell = rnn_cell_impl.BasicRNNCell(2) |
| g, _ = cell(x, m) |
| self.assertEqual([ |
| "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, |
| "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME |
| ], [v.name for v in cell.trainable_variables]) |
| self.assertFalse(cell.non_trainable_variables) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run([g], { |
| x: np.array([[1., 1.]]), |
| m: np.array([[0.1, 0.1]]) |
| }) |
| self.assertEqual(res[0].shape, (1, 2)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicRNNCellNotTrainable(self): |
| with self.cached_session() as sess: |
| |
| def not_trainable_getter(getter, *args, **kwargs): |
| kwargs["trainable"] = False |
| return getter(*args, **kwargs) |
| |
| with variable_scope.variable_scope( |
| "root", |
| initializer=init_ops.constant_initializer(0.5), |
| custom_getter=not_trainable_getter): |
| x = array_ops.zeros([1, 2]) |
| m = array_ops.zeros([1, 2]) |
| cell = rnn_cell_impl.BasicRNNCell(2) |
| g, _ = cell(x, m) |
| self.assertFalse(cell.trainable_variables) |
| self.assertEqual([ |
| "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, |
| "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME |
| ], [v.name for v in cell.non_trainable_variables]) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run([g], { |
| x: np.array([[1., 1.]]), |
| m: np.array([[0.1, 0.1]]) |
| }) |
| self.assertEqual(res[0].shape, (1, 2)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testGRUCell(self): |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m = array_ops.zeros([1, 2]) |
| g, _ = rnn_cell_impl.GRUCell(2)(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run([g], { |
| x: np.array([[1., 1.]]), |
| m: np.array([[0.1, 0.1]]) |
| }) |
| # Smoke test |
| self.assertAllClose(res[0], [[0.175991, 0.175991]]) |
| with variable_scope.variable_scope( |
| "other", initializer=init_ops.constant_initializer(0.5)): |
| # Test GRUCell with input_size != num_units. |
| x = array_ops.zeros([1, 3]) |
| m = array_ops.zeros([1, 2]) |
| g, _ = rnn_cell_impl.GRUCell(2)(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run([g], { |
| x: np.array([[1., 1., 1.]]), |
| m: np.array([[0.1, 0.1]]) |
| }) |
| # Smoke test |
| self.assertAllClose(res[0], [[0.156736, 0.156736]]) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicLSTMCell(self): |
| for dtype in [dtypes.float16, dtypes.float32]: |
| np_dtype = dtype.as_numpy_dtype |
| with self.session(graph=ops.Graph()) as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2], dtype=dtype) |
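| # Two stacked BasicLSTMCell(2) with state_is_tuple=False: each cell has |
| # a flat [c, h] state of size 4, so the combined state has size 8. |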
| m = array_ops.zeros([1, 8], dtype=dtype) |
| cell = rnn_cell_impl.MultiRNNCell( |
| [ |
| rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) |
| for _ in range(2) |
| ], |
| state_is_tuple=False) |
| self.assertEqual(cell.dtype, None) |
| self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) |
| self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) |
| cell.get_config() # Should not throw an error |
| g, out_m = cell(x, m) |
| # Layer infers the input type. |
| self.assertEqual(cell.dtype, dtype.name) |
| expected_variable_names = [ |
| "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % |
| rnn_cell_impl._WEIGHTS_VARIABLE_NAME, |
| "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % |
| rnn_cell_impl._BIAS_VARIABLE_NAME, |
| "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % |
| rnn_cell_impl._WEIGHTS_VARIABLE_NAME, |
| "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % |
| rnn_cell_impl._BIAS_VARIABLE_NAME |
| ] |
| self.assertEqual(expected_variable_names, |
| [v.name for v in cell.trainable_variables]) |
| self.assertFalse(cell.non_trainable_variables) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run([g, out_m], { |
| x: np.array([[1., 1.]]), |
| m: 0.1 * np.ones([1, 8]) |
| }) |
| self.assertEqual(len(res), 2) |
| variables = variables_lib.global_variables() |
| self.assertEqual(expected_variable_names, [v.name for v in variables]) |
| # The numbers in results were not calculated, this is just a |
| # smoke test. |
| self.assertAllClose(res[0], np.array( |
| [[0.240, 0.240]], dtype=np_dtype), 1e-2) |
| expected_mem = np.array( |
| [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], |
| dtype=np_dtype) |
| self.assertAllClose(res[1], expected_mem, 1e-2) |
| with variable_scope.variable_scope( |
| "other", initializer=init_ops.constant_initializer(0.5)): |
| # Test BasicLSTMCell with input_size != num_units. |
| x = array_ops.zeros([1, 3], dtype=dtype) |
| m = array_ops.zeros([1, 4], dtype=dtype) |
| g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run( |
| [g, out_m], { |
| x: np.array([[1., 1., 1.]], dtype=np_dtype), |
| m: 0.1 * np.ones([1, 4], dtype=np_dtype) |
| }) |
| self.assertEqual(len(res), 2) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicLSTMCellDimension0Error(self): |
| """Tests that dimension 0 in both(x and m) shape must be equal.""" |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| num_units = 2 |
| state_size = num_units * 2 |
| batch_size = 3 |
| input_size = 4 |
| x = array_ops.zeros([batch_size, input_size]) |
| m = array_ops.zeros([batch_size - 1, state_size]) |
| with self.assertRaises(ValueError): |
| g, out_m = rnn_cell_impl.BasicLSTMCell( |
| num_units, state_is_tuple=False)(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| sess.run( |
| [g, out_m], { |
| x: 1 * np.ones([batch_size, input_size]), |
| m: 0.1 * np.ones([batch_size - 1, state_size]) |
| }) |
| |
| def testBasicLSTMCellStateSizeError(self): |
| """Tests that state_size must be num_units * 2.""" |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| num_units = 2 |
| state_size = num_units * 3 # state_size must be num_units * 2 |
| batch_size = 3 |
| input_size = 4 |
| x = array_ops.zeros([batch_size, input_size]) |
| m = array_ops.zeros([batch_size, state_size]) |
| with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): |
| g, out_m = rnn_cell_impl.BasicLSTMCell( |
| num_units, state_is_tuple=False)(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| sess.run( |
| [g, out_m], { |
| x: 1 * np.ones([batch_size, input_size]), |
| m: 0.1 * np.ones([batch_size, state_size]) |
| }) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicLSTMCellStateTupleType(self): |
| with self.cached_session(): |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m0 = (array_ops.zeros([1, 2]),) * 2 |
| m1 = (array_ops.zeros([1, 2]),) * 2 |
| cell = rnn_cell_impl.MultiRNNCell( |
| [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], |
| state_is_tuple=True) |
| self.assertTrue(isinstance(cell.state_size, tuple)) |
| self.assertTrue( |
| isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple)) |
| self.assertTrue( |
| isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple)) |
| |
| # Pass in regular tuples |
| _, (out_m0, out_m1) = cell(x, (m0, m1)) |
| self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) |
| self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) |
| |
| # Pass in LSTMStateTuples |
| variable_scope.get_variable_scope().reuse_variables() |
| zero_state = cell.zero_state(1, dtypes.float32) |
| self.assertTrue(isinstance(zero_state, tuple)) |
| self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)) |
| self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)) |
| _, (out_m0, out_m1) = cell(x, zero_state) |
| self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) |
| self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testBasicLSTMCellWithStateTuple(self): |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m0 = array_ops.zeros([1, 4]) |
| m1 = array_ops.zeros([1, 4]) |
| cell = rnn_cell_impl.MultiRNNCell( |
| [ |
| rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) |
| for _ in range(2) |
| ], |
| state_is_tuple=True) |
| g, (out_m0, out_m1) = cell(x, (m0, m1)) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run( |
| [g, out_m0, out_m1], { |
| x: np.array([[1., 1.]]), |
| m0: 0.1 * np.ones([1, 4]), |
| m1: 0.1 * np.ones([1, 4]) |
| }) |
| self.assertEqual(len(res), 3) |
| # The numbers in results were not calculated, this is just a smoke test. |
| # Note, however, these values should match the original |
| # version having state_is_tuple=False. |
| self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) |
| expected_mem0 = np.array( |
| [[0.68967271, 0.68967271, 0.44848421, 0.44848421]]) |
| expected_mem1 = np.array( |
| [[0.39897051, 0.39897051, 0.24024698, 0.24024698]]) |
| self.assertAllClose(res[1], expected_mem0) |
| self.assertAllClose(res[2], expected_mem1) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testLSTMCell(self): |
| with self.cached_session() as sess: |
| num_units = 8 |
| num_proj = 6 |
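| # With a projection layer and state_is_tuple=False, the flat state is |
| # the concatenation of c (size num_units) and the projected h |
| # (size num_proj). |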
| state_size = num_units + num_proj |
| batch_size = 3 |
| input_size = 2 |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([batch_size, input_size]) |
| m = array_ops.zeros([batch_size, state_size]) |
| cell = rnn_cell_impl.LSTMCell( |
| num_units=num_units, |
| num_proj=num_proj, |
| forget_bias=1.0, |
| state_is_tuple=False) |
| output, state = cell(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run( |
| [output, state], { |
| x: np.array([[1., 1.], [2., 2.], [3., 3.]]), |
| m: 0.1 * np.ones((batch_size, state_size)) |
| }) |
| self.assertEqual(len(res), 2) |
| # The numbers in results were not calculated, this is mostly just a |
| # smoke test. |
| self.assertEqual(res[0].shape, (batch_size, num_proj)) |
| self.assertEqual(res[1].shape, (batch_size, state_size)) |
| # Different inputs so different outputs and states |
| for i in range(1, batch_size): |
| self.assertTrue( |
| float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) |
| self.assertTrue( |
| float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testLSTMCellVariables(self): |
| with self.cached_session(): |
| num_units = 8 |
| num_proj = 6 |
| state_size = num_units + num_proj |
| batch_size = 3 |
| input_size = 2 |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([batch_size, input_size]) |
| m = array_ops.zeros([batch_size, state_size]) |
| cell = rnn_cell_impl.LSTMCell( |
| num_units=num_units, |
| num_proj=num_proj, |
| forget_bias=1.0, |
| state_is_tuple=False) |
| cell(x, m) # Execute to create variables |
| variables = variables_lib.global_variables() |
| self.assertEqual(variables[0].op.name, "root/lstm_cell/kernel") |
| self.assertEqual(variables[1].op.name, "root/lstm_cell/bias") |
| self.assertEqual(variables[2].op.name, |
| "root/lstm_cell/projection/kernel") |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testWrapperCheckpointing(self): |
| for wrapper_type in [ |
| rnn_cell_impl.DropoutWrapper, |
| rnn_cell_impl.ResidualWrapper, |
| lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: |
| cell = rnn_cell_impl.BasicRNNCell(1) |
| wrapper = wrapper_type(cell) |
| wrapper(array_ops.ones([1, 1]), |
| state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) |
| self.evaluate([v.initializer for v in cell.variables]) |
| checkpoint = trackable_utils.Checkpoint(wrapper=wrapper) |
| prefix = os.path.join(self.get_temp_dir(), "ckpt") |
| self.evaluate(cell._bias.assign([40.])) |
| save_path = checkpoint.save(prefix) |
| self.evaluate(cell._bias.assign([0.])) |
| checkpoint.restore(save_path).assert_consumed().run_restore_ops() |
| self.assertAllEqual([40.], self.evaluate(cell._bias)) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testResidualWrapper(self): |
| wrapper_type = rnn_cell_impl.ResidualWrapper |
| x = ops.convert_to_tensor(np.array([[1., 1., 1.]])) |
| m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) |
| base_cell = rnn_cell_impl.GRUCell( |
| 3, kernel_initializer=init_ops.constant_initializer(0.5), |
| bias_initializer=init_ops.constant_initializer(0.5)) |
| g, m_new = base_cell(x, m) |
| wrapper_object = wrapper_type(base_cell) |
| (name, dep), = wrapper_object._checkpoint_dependencies |
| wrapper_object.get_config() # Should not throw an error |
| self.assertIs(dep, base_cell) |
| self.assertEqual("cell", name) |
| |
| g_res, m_new_res = wrapper_object(x, m) |
| self.evaluate([variables_lib.global_variables_initializer()]) |
| res = self.evaluate([g, g_res, m_new, m_new_res]) |
| # Residual connections |
| self.assertAllClose(res[1], res[0] + [1., 1., 1.]) |
| # States are left untouched |
| self.assertAllClose(res[2], res[3]) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testResidualWrapperWithSlice(self): |
| wrapper_type = rnn_cell_impl.ResidualWrapper |
| x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]])) |
| m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) |
| base_cell = rnn_cell_impl.GRUCell( |
| 3, kernel_initializer=init_ops.constant_initializer(0.5), |
| bias_initializer=init_ops.constant_initializer(0.5)) |
| g, m_new = base_cell(x, m) |
| |
| def residual_with_slice_fn(inp, out): |
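| # The input has depth 5 while the cell output has depth 3, so slice the |
| # first 3 input channels before adding them as the residual. |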
| inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) |
| return inp_sliced + out |
| |
| g_res, m_new_res = wrapper_type( |
| base_cell, residual_with_slice_fn)(x, m) |
| self.evaluate([variables_lib.global_variables_initializer()]) |
| res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate( |
| [g, g_res, m_new, m_new_res]) |
| # Residual connections |
| self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) |
| # States are left untouched |
| self.assertAllClose(res_m_new, res_m_new_res) |
| |
| def testDeviceWrapper(self): |
| wrapper_type = rnn_cell_impl.DeviceWrapper |
| x = array_ops.zeros([1, 3]) |
| m = array_ops.zeros([1, 3]) |
| cell = rnn_cell_impl.GRUCell(3) |
| wrapped_cell = wrapper_type(cell, "/cpu:0") |
| (name, dep), = wrapped_cell._checkpoint_dependencies |
| wrapped_cell.get_config() # Should not throw an error |
| self.assertIs(dep, cell) |
| self.assertEqual("cell", name) |
| |
| outputs, _ = wrapped_cell(x, m) |
| self.assertIn("cpu:0", outputs.device.lower()) |
| |
| def _retrieve_cpu_gpu_stats(self, run_metadata): |
| cpu_stats = None |
| gpu_stats = None |
| step_stats = run_metadata.step_stats |
| for ds in step_stats.dev_stats: |
| if "cpu:0" in ds.device[-5:].lower(): |
| cpu_stats = ds.node_stats |
| if "gpu:0" == ds.device[-5:].lower(): |
| gpu_stats = ds.node_stats |
| return cpu_stats, gpu_stats |
| |
| @test_util.run_v1_only("b/124229375") |
| def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): |
| if not test.is_gpu_available(): |
| # Can't perform this test w/o a GPU |
| return |
| |
| gpu_dev = test.gpu_device_name() |
| with self.session(use_gpu=True) as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 1, 3]) |
| cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev) |
| with ops.device("/cpu:0"): |
| outputs, _ = rnn.dynamic_rnn( |
| cell=cell, inputs=x, dtype=dtypes.float32) |
| run_metadata = config_pb2.RunMetadata() |
| opts = config_pb2.RunOptions( |
| trace_level=config_pb2.RunOptions.FULL_TRACE) |
| |
| sess.run([variables_lib.global_variables_initializer()]) |
| _ = sess.run(outputs, options=opts, run_metadata=run_metadata) |
| |
| cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) |
| self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) |
| self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testMultiRNNCell(self): |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m = array_ops.zeros([1, 4]) |
| multi_rnn_cell = rnn_cell_impl.MultiRNNCell( |
| [rnn_cell_impl.GRUCell(2) for _ in range(2)], |
| state_is_tuple=False) |
| _, ml = multi_rnn_cell(x, m) |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run(ml, { |
| x: np.array([[1., 1.]]), |
| m: np.array([[0.1, 0.1, 0.1, 0.1]]) |
| }) |
| # The numbers in results were not calculated, this is just a smoke test. |
| self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) |
| self.assertEqual(len(multi_rnn_cell.weights), 2 * 4) |
| self.assertTrue( |
| all(x.dtype == dtypes.float32 for x in multi_rnn_cell.weights)) |
| |
| @test_util.run_v1_only("b/124229375") |
| def testMultiRNNCellWithStateTuple(self): |
| with self.cached_session() as sess: |
| with variable_scope.variable_scope( |
| "root", initializer=init_ops.constant_initializer(0.5)): |
| x = array_ops.zeros([1, 2]) |
| m_bad = array_ops.zeros([1, 4]) |
| m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) |
| |
| # Test that passing a non-tuple state raises an error. |
| with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): |
| rnn_cell_impl.MultiRNNCell( |
| [rnn_cell_impl.GRUCell(2) for _ in range(2)], |
| state_is_tuple=True)(x, m_bad) |
| |
| _, ml = rnn_cell_impl.MultiRNNCell( |
| [rnn_cell_impl.GRUCell(2) for _ in range(2)], |
| state_is_tuple=True)(x, m_good) |
| |
| sess.run([variables_lib.global_variables_initializer()]) |
| res = sess.run( |
| ml, { |
| x: np.array([[1., 1.]]), |
| m_good[0]: np.array([[0.1, 0.1]]), |
| m_good[1]: np.array([[0.1, 0.1]]) |
| }) |
| |
| # The numbers in results were not calculated, this is just a |
| # smoke test. However, these numbers should match those of |
| # the test testMultiRNNCell. |
| self.assertAllClose(res[0], [[0.175991, 0.175991]]) |
| self.assertAllClose(res[1], [[0.13248, 0.13248]]) |
| |
| def testDeviceWrapperSerialization(self): |
| wrapper_cls = rnn_cell_impl.DeviceWrapper |
| cell = rnn_cell_impl.LSTMCell(10) |
| wrapper = wrapper_cls(cell, "/cpu:0") |
| config = wrapper.get_config() |
| |
| # Replace the cell in the config with a real cell instance to work around |
| # the reverse keras dependency issue. |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| self.assertDictEqual(config, reconstructed_wrapper.get_config()) |
| self.assertIsInstance(reconstructed_wrapper, wrapper_cls) |
| |
| def testResidualWrapperSerialization(self): |
| wrapper_cls = rnn_cell_impl.ResidualWrapper |
| cell = rnn_cell_impl.LSTMCell(10) |
| wrapper = wrapper_cls(cell) |
| config = wrapper.get_config() |
| |
| # Replace the cell in the config with a real cell instance to work around |
| # the reverse keras dependency issue. |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| self.assertDictEqual(config, reconstructed_wrapper.get_config()) |
| self.assertIsInstance(reconstructed_wrapper, wrapper_cls) |
| |
| wrapper = wrapper_cls(cell, residual_fn=lambda i, o: i + i + o) |
| config = wrapper.get_config() |
| |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| # Assert the reconstructed function will perform the math correctly. |
| self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 4) |
| |
| def residual_fn(inputs, outputs): |
| return inputs * 3 + outputs |
| |
| wrapper = wrapper_cls(cell, residual_fn=residual_fn) |
| config = wrapper.get_config() |
| |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| # Assert the reconstructed function will perform the math correctly. |
| self.assertEqual(reconstructed_wrapper._residual_fn(1, 2), 5) |
| |
| def testDropoutWrapperSerialization(self): |
| wrapper_cls = rnn_cell_impl.DropoutWrapper |
| cell = rnn_cell_impl.LSTMCell(10) |
| wrapper = wrapper_cls(cell) |
| config = wrapper.get_config() |
| |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| self.assertDictEqual(config, reconstructed_wrapper.get_config()) |
| self.assertIsInstance(reconstructed_wrapper, wrapper_cls) |
| |
| wrapper = wrapper_cls(cell, dropout_state_filter_visitor=lambda s: True) |
| config = wrapper.get_config() |
| |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| self.assertTrue(reconstructed_wrapper._dropout_state_filter(None)) |
| |
| def dropout_state_filter_visitor(unused_state): |
| return False |
| |
| wrapper = wrapper_cls( |
| cell, dropout_state_filter_visitor=dropout_state_filter_visitor) |
| config = wrapper.get_config() |
| |
| config_copy = config.copy() |
| config_copy["cell"] = rnn_cell_impl.LSTMCell.from_config( |
| config_copy["cell"]["config"]) |
| reconstructed_wrapper = wrapper_cls.from_config(config_copy) |
| self.assertFalse(reconstructed_wrapper._dropout_state_filter(None)) |
| |
| |
| @test_util.run_all_in_graph_and_eager_modes |
| class DropoutWrapperTest(test.TestCase, parameterized.TestCase): |
| |
| def _testDropoutWrapper(self, |
| batch_size=None, |
| time_steps=None, |
| parallel_iterations=None, |
| wrapper_type=None, |
| scope="root", |
| **kwargs): |
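| # Runs a small LSTM wrapped in wrapper_type through dynamic_rnn and |
| # returns the evaluated (outputs, final_state); any extra **kwargs (e.g. |
| # the keep_prob arguments) are forwarded to the wrapper constructor. |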
| if batch_size is None and time_steps is None: |
| # 2 time steps, batch size 1, depth 3 |
| batch_size = 1 |
| time_steps = 2 |
| x = constant_op.constant( |
| [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) |
| m = rnn_cell_impl.LSTMStateTuple( |
| *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2) |
| else: |
| x = constant_op.constant( |
| np.random.randn(time_steps, batch_size, 3).astype(np.float32)) |
| m = rnn_cell_impl.LSTMStateTuple(*[ |
| constant_op. |
| constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)] * 2) |
| outputs, final_state = rnn.dynamic_rnn( |
| cell=wrapper_type( |
| rnn_cell_impl.LSTMCell( |
| 3, initializer=init_ops.constant_initializer(0.5)), |
| dtype=x.dtype, **kwargs), |
| time_major=True, |
| parallel_iterations=parallel_iterations, |
| inputs=x, |
| initial_state=m, |
| scope=scope) |
| self.evaluate([variables_lib.global_variables_initializer()]) |
| res = self.evaluate([outputs, final_state]) |
| self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) |
| self.assertEqual(res[1].c.shape, (batch_size, 3)) |
| self.assertEqual(res[1].h.shape, (batch_size, 3)) |
| return res |
| |
| def testDropoutWrapperProperties(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| cell = rnn_cell_impl.BasicRNNCell(10) |
| wrapper = wrapper_type(cell) |
| # Github issue 15810 |
| self.assertEqual(wrapper.wrapped_cell, cell) |
| self.assertEqual(wrapper.state_size, 10) |
| self.assertEqual(wrapper.output_size, 10) |
| |
| def testDropoutWrapperZeroState(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| |
| class _Cell(rnn_cell_impl.BasicRNNCell): |
| |
| def zero_state(self, batch_size=None, dtype=None): |
| return "wrapped_cell_zero_state" |
| wrapper = wrapper_type(_Cell(10)) |
| self.assertEqual(wrapper.zero_state(10, dtypes.float32), |
| "wrapped_cell_zero_state") |
| |
| def testDropoutWrapperKeepAllConstantInput(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep = array_ops.ones([]) |
| res = self._testDropoutWrapper( |
| input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, |
| wrapper_type=wrapper_type) |
| true_full_output = np.array( |
| [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], |
| dtype=np.float32) |
| true_full_final_c = np.array( |
| [[1.949385, 1.949385, 1.949385]], dtype=np.float32) |
| self.assertAllClose(true_full_output, res[0]) |
| self.assertAllClose(true_full_output[1], res[1].h) |
| self.assertAllClose(true_full_final_c, res[1].c) |
| |
| def testDropoutWrapperKeepAll(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep = variable_scope.get_variable("all", initializer=1.0) |
| res = self._testDropoutWrapper( |
| input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, |
| wrapper_type=wrapper_type) |
| true_full_output = np.array( |
| [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], |
| dtype=np.float32) |
| true_full_final_c = np.array( |
| [[1.949385, 1.949385, 1.949385]], dtype=np.float32) |
| self.assertAllClose(true_full_output, res[0]) |
| self.assertAllClose(true_full_output[1], res[1].h) |
| self.assertAllClose(true_full_final_c, res[1].c) |
| |
| def testDropoutWrapperWithSeed(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_some = 0.5 |
| random_seed.set_random_seed(2) |
| ## Use parallel_iterations = 1 in both calls to |
| ## _testDropoutWrapper to ensure the (per-time step) dropout is |
| ## consistent across both calls. Otherwise the seed may not end |
| ## up being munged consistently across both graphs. |
| res_standard_1 = self._testDropoutWrapper( |
| input_keep_prob=keep_some, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_some, |
| seed=10, |
| parallel_iterations=1, |
| wrapper_type=wrapper_type, |
| scope="root_1") |
| random_seed.set_random_seed(2) |
| res_standard_2 = self._testDropoutWrapper( |
| input_keep_prob=keep_some, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_some, |
| seed=10, |
| parallel_iterations=1, |
| wrapper_type=wrapper_type, |
| scope="root_2") |
| self.assertAllClose(res_standard_1[0], res_standard_2[0]) |
| self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) |
| self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h) |
| |
| def testDropoutWrapperKeepNoOutput(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_all = variable_scope.get_variable("all", initializer=1.0) |
| keep_none = variable_scope.get_variable("none", initializer=1e-6) |
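    # With output_keep_prob ~= 0 every emitted output should be (nearly) zero,
    # while the state, which is not dropped here, matches the full run.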
| res = self._testDropoutWrapper( |
| input_keep_prob=keep_all, |
| output_keep_prob=keep_none, |
| state_keep_prob=keep_all, |
| wrapper_type=wrapper_type) |
| true_full_output = np.array( |
| [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], |
| dtype=np.float32) |
| true_full_final_c = np.array( |
| [[1.949385, 1.949385, 1.949385]], dtype=np.float32) |
| self.assertAllClose(np.zeros(res[0].shape), res[0]) |
| self.assertAllClose(true_full_output[1], res[1].h) |
| self.assertAllClose(true_full_final_c, res[1].c) |
| |
| def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_all = variable_scope.get_variable("all", initializer=1.0) |
| keep_none = variable_scope.get_variable("none", initializer=1e-6) |
    # Even though we apply dropout to the state, by default DropoutWrapper
    # never drops out the memory ("c") term of an LSTMStateTuple.
| res = self._testDropoutWrapper( |
| input_keep_prob=keep_all, |
| output_keep_prob=keep_all, |
| state_keep_prob=keep_none, |
| wrapper_type=wrapper_type) |
| true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) |
| true_full_output = np.array( |
| [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], |
| dtype=np.float32) |
| self.assertAllClose(true_full_output[0], res[0][0]) |
    # The second output differs because the h state fed into the second time
    # step was dropped (zeroed out).
| self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) |
| # h state has been set to zero |
| self.assertAllClose(np.zeros(res[1].h.shape), res[1].h) |
| # c state of an LSTMStateTuple is NEVER modified. |
| self.assertAllClose(true_c_state, res[1].c) |
| |
| def testDropoutWrapperKeepNoInput(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_all = variable_scope.get_variable("all", initializer=1.0) |
| keep_none = variable_scope.get_variable("none", initializer=1e-6) |
| true_full_output = np.array( |
| [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], |
| dtype=np.float32) |
| true_full_final_c = np.array( |
| [[1.949385, 1.949385, 1.949385]], dtype=np.float32) |
    # All outputs differ from the no-dropout reference values because the
    # inputs are zeroed out.
| res = self._testDropoutWrapper( |
| input_keep_prob=keep_none, |
| output_keep_prob=keep_all, |
| state_keep_prob=keep_all, |
| wrapper_type=wrapper_type) |
| self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) |
| self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) |
| self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) |
| |
| def testDropoutWrapperRecurrentOutput(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_some = 0.8 |
| keep_all = variable_scope.get_variable("all", initializer=1.0) |
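    # variational_recurrent=True reuses one dropout mask across all time
    # steps, so every step's output should show the same zero pattern.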
| res = self._testDropoutWrapper( |
| input_keep_prob=keep_all, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_all, |
| variational_recurrent=True, |
| wrapper_type=wrapper_type, |
| input_size=3, |
| batch_size=5, |
| time_steps=7) |
| # Ensure the same dropout pattern for all time steps |
| output_mask = np.abs(res[0]) > 1e-6 |
| for m in output_mask[1:]: |
| self.assertAllClose(output_mask[0], m) |
| |
| def testDropoutWrapperRecurrentStateInputAndOutput(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_some = 0.9 |
| res = self._testDropoutWrapper( |
| input_keep_prob=keep_some, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_some, |
| variational_recurrent=True, |
| wrapper_type=wrapper_type, |
| input_size=3, |
| batch_size=5, |
| time_steps=7) |
| |
| # Smoke test for the state/input masks. |
| output_mask = np.abs(res[0]) > 1e-6 |
| for time_step in output_mask: |
| # Ensure the same dropout output pattern for all time steps |
| self.assertAllClose(output_mask[0], time_step) |
| for batch_entry in time_step: |
| # Assert all batch entries get the same mask |
| self.assertAllClose(batch_entry, time_step[0]) |
| |
| # For state, ensure all batch entries have the same mask |
| state_c_mask = np.abs(res[1].c) > 1e-6 |
| state_h_mask = np.abs(res[1].h) > 1e-6 |
| for batch_entry in state_c_mask: |
| self.assertAllClose(batch_entry, state_c_mask[0]) |
| for batch_entry in state_h_mask: |
| self.assertAllClose(batch_entry, state_h_mask[0]) |
| |
| def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self): |
| wrapper_type = rnn_cell_impl.DropoutWrapper |
| keep_some = 0.9 |
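    # Seed the graph-level RNG and numpy identically before each run; combined
    # with the fixed seed passed to the wrapper, both runs must match exactly.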
| random_seed.set_random_seed(2347) |
| np.random.seed(23487) |
| res0 = self._testDropoutWrapper( |
| input_keep_prob=keep_some, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_some, |
| variational_recurrent=True, |
| wrapper_type=wrapper_type, |
| input_size=3, |
| batch_size=5, |
| time_steps=7, |
| seed=-234987, |
| scope="root_0") |
| random_seed.set_random_seed(2347) |
| np.random.seed(23487) |
| res1 = self._testDropoutWrapper( |
| input_keep_prob=keep_some, |
| output_keep_prob=keep_some, |
| state_keep_prob=keep_some, |
| variational_recurrent=True, |
| wrapper_type=wrapper_type, |
| input_size=3, |
| batch_size=5, |
| time_steps=7, |
| seed=-234987, |
| scope="root_1") |
| |
| output_mask = np.abs(res0[0]) > 1e-6 |
| for time_step in output_mask: |
| # Ensure the same dropout output pattern for all time steps |
| self.assertAllClose(output_mask[0], time_step) |
| for batch_entry in time_step: |
| # Assert all batch entries get the same mask |
| self.assertAllClose(batch_entry, time_step[0]) |
| |
| # For state, ensure all batch entries have the same mask |
| state_c_mask = np.abs(res0[1].c) > 1e-6 |
| state_h_mask = np.abs(res0[1].h) > 1e-6 |
| for batch_entry in state_c_mask: |
| self.assertAllClose(batch_entry, state_c_mask[0]) |
| for batch_entry in state_h_mask: |
| self.assertAllClose(batch_entry, state_h_mask[0]) |
| |
| # Ensure seeded calculation is identical. |
| self.assertAllClose(res0[0], res1[0]) |
| self.assertAllClose(res0[1].c, res1[1].c) |
| self.assertAllClose(res0[1].h, res1[1].h) |
| |
| |
| if __name__ == "__main__": |
| test.main() |