blob: 60c2337743b44e9bad61c4d65280eb2b1a1ad9ea [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.reverse_sequence_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ReverseSequenceTest(xla_test.XLATestCase):
  """XLA tests for `array_ops.reverse_sequence`."""

  def _testReverseSequence(self,
                           x,
                           batch_axis,
                           seq_axis,
                           seq_lengths,
                           truth,
                           expected_err_re=None):
    """Runs reverse_sequence on `x` and checks the outcome.

    Args:
      x: numpy array fed as the input tensor.
      batch_axis: axis of `x` that indexes the batch entries.
      seq_axis: axis of `x` along which each entry is partially reversed.
      seq_lengths: numpy array fed as the per-batch-entry reversal lengths.
      truth: expected numpy result, compared with `assertAllClose`.
      expected_err_re: if not None, evaluation is instead expected to raise an
        op error whose message matches this regex (and `truth` is unused).
    """
    with self.cached_session():
      p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
      lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
      # Only the op under test is built inside the device test scope; the
      # placeholders above are created outside it.
      with self.test_scope():
        ans = array_ops.reverse_sequence(
            p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
      feed = {p: x, lengths: seq_lengths}
      if expected_err_re is None:
        tf_ans = ans.eval(feed_dict=feed)
        self.assertAllClose(tf_ans, truth, atol=1e-10)
      else:
        with self.assertRaisesOpError(expected_err_re):
          ans.eval(feed_dict=feed)

  def testSimple(self):
    """Reverses the first seq_lengths[i] columns of row i of a 3x3 matrix."""
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
    self._testReverseSequence(
        x,
        batch_axis=0,
        seq_axis=1,
        seq_lengths=np.array([1, 3, 2], np.int32),
        truth=expected)

  def _testBasic(self, dtype, len_dtype):
    """Checks a 5-D case where the sequence axis precedes the batch axis.

    Args:
      dtype: numpy dtype of the input data.
      len_dtype: numpy dtype of `seq_lengths`.
    """
    x = np.asarray(
        [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
         [[17, 18, 19, 20], [21, 22, 23, 24]]],
        dtype=dtype)
    x = x.reshape(3, 2, 4, 1, 1)
    x = x.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2 -> (4, 2, 3, 1, 1)

    # Before the transpose: for batch entry b along dim 0, reverse the first
    # seq_lengths[b] elements of dim 2 (lengths 3, 0 and 4 respectively).
    seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)

    truth_orig = np.asarray(
        [
            [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
            [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
            [[20, 19, 18, 17], [24, 23, 22, 21]],  # reverse 0:4 (all)
        ],
        dtype=dtype)
    truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
    truth = truth_orig.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2

    # After the transpose the roles swap: originally seq_axis=2, batch_axis=0.
    seq_axis = 0
    batch_axis = 2
    self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)

  def testSeqLength(self):
    """Runs `_testBasic` for each data dtype and valid seq_lengths dtype."""
    # The ReverseSequence op only accepts int32 or int64 for `seq_lengths`.
    # Other integer dtypes the test harness may enable (e.g. int8, uint32)
    # would be rejected at graph construction, so they are filtered out here.
    # A tuple-membership test also matches np.dtype objects, not only numpy
    # scalar type classes.
    valid_len_types = [
        t for t in self.int_types if t in (np.int32, np.int64)
    ]
    for dtype in self.all_types:
      for seq_dtype in valid_len_types:
        self._testBasic(dtype, seq_dtype)
if __name__ == "__main__":
  # Run this module's test cases via the TensorFlow test runner.
  test.main()