| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tensorflow.ops.reverse_sequence_op.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.compiler.tests import xla_test |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.platform import test |
| |
| |
class ReverseSequenceTest(xla_test.XLATestCase):
  """Checks `array_ops.reverse_sequence` when built in `self.test_scope()`."""

  def _testReverseSequence(self,
                           x,
                           batch_axis,
                           seq_axis,
                           seq_lengths,
                           truth,
                           expected_err_re=None):
    """Builds a reverse_sequence op on placeholders and checks its result.

    Args:
      x: NumPy array fed as the op input; its dtype selects the placeholder
        dtype.
      batch_axis: Forwarded to `array_ops.reverse_sequence`.
      seq_axis: Forwarded to `array_ops.reverse_sequence`.
      seq_lengths: NumPy array fed as the op's `seq_lengths`; its dtype selects
        the lengths placeholder dtype.
      truth: Expected output, compared with `assertAllClose` (atol=1e-10).
        Unused when `expected_err_re` is given.
      expected_err_re: If not None, the pattern handed to
        `assertRaisesOpError`; evaluation must raise instead of returning.
    """
    with self.cached_session():
      data_ph = array_ops.placeholder(dtypes.as_dtype(x.dtype))
      lengths_ph = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
      # Only the op under test is constructed inside the test scope.
      with self.test_scope():
        ans = array_ops.reverse_sequence(
            data_ph,
            batch_axis=batch_axis,
            seq_axis=seq_axis,
            seq_lengths=lengths_ph)

      feeds = {data_ph: x, lengths_ph: seq_lengths}
      if expected_err_re is not None:
        with self.assertRaisesOpError(expected_err_re):
          ans.eval(feed_dict=feeds)
        return

      tf_ans = ans.eval(feed_dict=feeds)
      self.assertAllClose(tf_ans, truth, atol=1e-10)

  def testSimple(self):
    """Reverses the first 1, 3 and 2 entries of three int32 rows."""
    inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    lengths = np.array([1, 3, 2], np.int32)
    # Row 0 is unchanged, row 1 is fully reversed, row 2 swaps its first two.
    want = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
    self._testReverseSequence(
        inputs, batch_axis=0, seq_axis=1, seq_lengths=lengths, truth=want)

  def _testBasic(self, dtype, len_dtype):
    """Runs a 5-D case with batch_axis=2 and seq_axis=0.

    Args:
      dtype: NumPy dtype used for the input and expected arrays.
      len_dtype: NumPy dtype used for the sequence-lengths array.
    """

    def to_test_layout(rows):
      # Append two trailing unit dims, then swap axes 0 and 2 so that the
      # sequence axis ends up at 0 and the batch axis at 2.
      arr = np.asarray(rows, dtype=dtype).reshape(3, 2, 4, 1, 1)
      return arr.transpose([2, 1, 0, 3, 4])

    x = to_test_layout([
        [[1, 2, 3, 4], [5, 6, 7, 8]],
        [[9, 10, 11, 12], [13, 14, 15, 16]],
        [[17, 18, 19, 20], [21, 22, 23, 24]],
    ])

    # In the pre-permutation layout: reverse dim 2 up to (0:3, none, 0:4)
    # along dim 0.
    seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)

    truth = to_test_layout([
        [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
        [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
        [[20, 19, 18, 17], [24, 23, 22, 21]],  # reverse 0:4 (all)
    ])

    # seq_axis and batch_axis were originally 2 and 0; the permutation above
    # swaps them.
    self._testReverseSequence(
        x, batch_axis=2, seq_axis=0, seq_lengths=seq_lengths, truth=truth)

  def testSeqLength(self):
    """Runs `_testBasic` for every pair from `all_types` x `int_types`."""
    for value_dtype in self.all_types:
      for length_dtype in self.int_types:
        self._testBasic(value_dtype, length_dtype)
| |
| |
# Script entry point: hand control to `test.main()` when run directly.
if __name__ == "__main__":
  test.main()