blob: e706a281ad7eb5a00af7c4b4da2cd70bd983c866 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for py_builtins module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
class TestBase(object):
def plus_twenty(self, x):
return x + 20
class PyBuiltinsTest(test.TestCase):
def test_abs(self):
self.assertEqual(py_builtins.abs_(-1), 1)
with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(self.evaluate(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
self.assertAllEqual(self.evaluate(t), [1, 2, 3])
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(self.evaluate(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
self.assertEqual(self.evaluate(st), 1.0)
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(self.evaluate(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
self.assertEqual(self.evaluate(st), 1)
st = py_builtins.int_(constant_op.constant('1'), 10)
self.assertEqual(self.evaluate(st), 1)
def test_int_unsupported_base(self):
t = constant_op.constant(1, dtype=dtypes.float64)
with self.assertRaises(NotImplementedError):
py_builtins.int_(t, 2)
def test_len(self):
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
with self.cached_session() as sess:
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
self.assertEqual(self.evaluate(ta), 5)
tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
self.assertEqual(self.evaluate(tl), 3)
def test_len_scalar(self):
with self.assertRaises(ValueError):
py_builtins.len_(constant_op.constant(1))
@test_util.run_deprecated_v1
def test_len_dynamic_shape(self):
with self.cached_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
t = py_builtins.len_(p)
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
with self.assertRaises(errors_impl.InvalidArgumentError):
t = py_builtins.len_(p)
sess.run(t, {p: 1})
@test_util.run_deprecated_v1
def test_print_tensors(self):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
with self.cached_session() as sess:
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
finally:
sys.stdout = sys.__stdout__
@test_util.run_deprecated_v1
def test_print_complex(self):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
with self.cached_session() as sess:
sess.run(
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
finally:
sys.stdout = sys.__stdout__
def test_range(self):
self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
def test_range_tensor(self):
with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(self.evaluate(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
self.assertAllEqual(self.evaluate(r), [1, 2])
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(self.evaluate(r), [2, 1])
def test_range_tensor_empty_range(self):
with self.session() as sess:
r = py_builtins.range_(constant_op.constant(-3))
self.assertAllEqual(self.evaluate(r), [])
r = py_builtins.range_(5, constant_op.constant(2))
self.assertAllEqual(self.evaluate(r), [])
def test_enumerate(self):
self.assertListEqual(
list(py_builtins.enumerate_([3, 2, 1])), [(0, 3), (1, 2), (2, 1)])
self.assertListEqual(
list(py_builtins.enumerate_([3, 2, 1], 5)), [(5, 3), (6, 2), (7, 1)])
self.assertListEqual(list(py_builtins.enumerate_([-8], -3)), [(-3, -8)])
def test_enumerate_dataset(self):
dataset = dataset_ops.DatasetV2.from_tensor_slices(['a', 'c'])
start = constant_op.constant(20, dtype=dtypes.int64)
dataset = py_builtins.enumerate_(dataset, start)
iterator = dataset_ops.make_one_shot_iterator(dataset)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(iterator.get_next()), (20, b'a'))
self.assertAllEqual(self.evaluate(iterator.get_next()), (21, b'c'))
def _basic_function_scope(self):
return function_wrappers.FunctionScope(
'test_function_name',
'test_scope', # Note: this must match the name in the `with` statement.
converter.ConversionOptions())
def test_eval_in_original_context(self):
def test_fn():
l = 1 # pylint:disable=unused-variable
with self._basic_function_scope() as test_scope:
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
self.assertEqual(test_fn(), 1)
def test_eval_in_original_context_inner_function(self):
def test_fn():
l = 1 # pylint:disable=unused-variable
with self._basic_function_scope() as test_scope:
def inner_fn():
# Note: a user function without a top-level function scope should
# never be found in user code; it's only possible in generated code.
l = 2 # pylint:disable=unused-variable
return py_builtins.eval_in_original_context(eval, ('l',), test_scope)
return inner_fn()
self.assertEqual(test_fn(), 2)
def test_super_in_original_context_unary_call(self):
test_case_self = self
class TestSubclass(TestBase):
def plus_twenty(self, x):
test_case_self.fail('This should never be called.')
def test_method(self):
with test_case_self._basic_function_scope() as test_scope:
test_base_unbound = py_builtins.super_in_original_context(
super, (TestSubclass,), test_scope)
test_base = test_base_unbound.__get__(self, TestSubclass)
return test_base.plus_twenty(1)
tc = TestSubclass()
self.assertEqual(tc.test_method(), 21)
def test_super_in_original_context_binary_call(self):
test_case_self = self
class TestSubclass(TestBase):
def plus_twenty(self, x):
test_case_self.fail('This should never be called.')
def test_method(self):
with test_case_self._basic_function_scope() as test_scope:
test_base = py_builtins.super_in_original_context(
super, (TestSubclass, self), test_scope)
return test_base.plus_twenty(1)
tc = TestSubclass()
self.assertEqual(tc.test_method(), 21)
if __name__ == '__main__':
test.main()