# blob: fe2e9b20361bad3093879d7286653c6a4884741f [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for anf module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import textwrap
import gast
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.common_transformers import anf
from tensorflow.python.platform import test
class DummyGensym(object):
  """Deterministic stand-in for a gensym, used only by these tests.

  Every call to `new_name` returns the requested stem followed by an
  underscore and the next value of a per-instance counter that starts at
  1001, e.g. `tmp_1001`, `tmp_1002`, ...
  """

  def __init__(self, ctx):
    # The context is deliberately ignored. A real implementation would have to
    # avoid collisions with:
    #   * ctx.info.namespace
    #   * every symbol already defined in the AST
    #   * every symbol handed out so far
    del ctx
    self._counter = 1000

  def new_name(self, stem='tmp'):
    self._counter += 1
    suffix = str(self._counter)
    return stem + '_' + suffix
# These two test functions have to be top-level, not nested, for compatibility
# with some unknown version of Python 2.7 preceding 2.7.15. Why? Because
# `exec` and nested function definitions _incompatibly_ change the
# representation of local variables, such that `exec` inside a nested function
# definition is a syntax error in that version. The tuple form of `exec` fixes
# this problem, but apparently that was introduced in some unknown version of
# Python that's more recent than at least one version that we wish to be
# compatible with.
def exec_test_function():
  # Input for `AnfTransformerTest.test_exec`. The tests only parse this source
  # and never call it. Calling it would raise TypeError, because
  # `'computed' + 5` adds a str and an int.
  # The point is to test A-normal form conversion of exec
  # pylint: disable=exec-used
  exec('computed' + 5 + 'stuff', globals(), locals())
def exec_expected_result():
  # Expected ANF form of `exec_test_function`: each non-trivial argument of the
  # `exec` call is hoisted into a numbered temporary (names come from
  # `DummyGensym`). Only parsed by the tests, never called.
  # pylint: disable=exec-used
  tmp_1001 = 'computed' + 5
  tmp_1002 = tmp_1001 + 'stuff'
  tmp_1003 = globals()
  tmp_1004 = locals()
  exec(tmp_1002, tmp_1003, tmp_1004)
class AnfTestBase(test.TestCase):
  """Base class with helpers for checking the output of the ANF transformer."""

  def _simple_context(self):
    """Returns a minimal `transformer.Context` with no source or namespace."""
    entity_info = transformer.EntityInfo(
        source_code=None, source_file=None, future_features=(), namespace=None)
    return transformer.Context(entity_info)

  def _normalized_source(self, node):
    """Renders `node` as dedented, stripped source, for stable comparison."""
    source = compiler.ast_to_source(node, indentation=' ')
    return textwrap.dedent(source).strip()

  def assert_same_ast(self, expected_node, node, msg=None):
    """Asserts that two ASTs render to the same normalized source code."""
    self.assertEqual(
        self._normalized_source(expected_node),
        self._normalized_source(node),
        msg=msg)

  def assert_body_anfs_as_expected(self, expected_fn, test_fn, config=None):
    """Asserts that ANF-converting `test_fn` yields the code of `expected_fn`.

    Also checks that the conversion is idempotent: running the transformer a
    second time, with the same `config`, must leave its output unchanged.

    Args:
      expected_fn: function whose source is the expected transformer output.
      test_fn: function whose source is fed to the transformer.
      config: optional configuration forwarded to `anf.transform`; `None`
        leaves the choice of configuration to `anf.transform`.
    """
    # Testing the code bodies only. Wrapping them in functions so the
    # syntax highlights nicely, but Python doesn't try to execute the
    # statements.
    exp_node, _ = parser.parse_entity(expected_fn, future_features=())
    node, _ = parser.parse_entity(test_fn, future_features=())
    node = anf.transform(
        node, self._simple_context(),
        config=config, gensym_source=DummyGensym)
    exp_name = exp_node.name
    # Ignoring the function names in the result because they can't be
    # the same (because both functions have to exist in the same scope
    # at the same time).
    node.name = exp_name
    self.assert_same_ast(exp_node, node)
    # Check that ANF is idempotent. The transformer may rewrite `node` in place
    # and return the very same object, in which case comparing `node` against
    # the second pass's return value could never fail. So snapshot the source
    # *before* the second pass. The second pass must use the same `config`,
    # because the output of a custom configuration is generally not a fixed
    # point of the default configuration.
    first_pass_str = self._normalized_source(node)
    node_repeated = anf.transform(
        node, self._simple_context(),
        config=config, gensym_source=DummyGensym)
    self.assertEqual(
        first_pass_str,
        self._normalized_source(node_repeated),
        msg='ANF transformation is not idempotent')
class AnfTransformerTest(AnfTestBase):
  # Each test pairs a `test_function` (input to the ANF transformer) with an
  # `expected_result` whose code is the expected output. Apart from
  # `test_basic`, both are only parsed, never executed. Their bodies are
  # therefore test data. Do not edit their statements casually, and do not
  # give them docstrings: a docstring adds a statement to the parsed AST.
  # Temporary names (`tmp_1001`, ...) come from `DummyGensym`.

  def test_basic(self):
    # Round trip: transform with the default configuration, compile the result
    # back into a Python object, and check that it computes the same value.

    def test_function():
      a = 0
      return a

    node, _ = parser.parse_entity(test_function, future_features=())
    node = anf.transform(node, self._simple_context())
    result, _, _ = compiler.ast_to_object(node)
    self.assertEqual(test_function(), result.test_function())

  def test_binop_basic(self):
    # Only the inner `x + y` of a chained binop is hoisted. The outermost
    # expression stays on the right-hand side of the assignment.

    def test_function(x, y, z):
      a = x + y + z
      return a

    def expected_result(x, y, z):
      tmp_1001 = x + y
      a = tmp_1001 + z
      return a

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_if_basic(self):
    # An `if` test is fully hoisted into a temporary. The body of the `if` is
    # converted in place.

    def test_function(a, b, c, e, f, g):
      if a + b + c:
        d = e + f + g
        return d

    def expected_result(a, b, c, e, f, g):
      tmp_1001 = a + b
      tmp_1002 = tmp_1001 + c
      if tmp_1002:
        tmp_1003 = e + f
        d = tmp_1003 + g
        return d

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_nested_binop_and_return(self):
    # A non-trivial `return` value is hoisted into a temporary.

    def test_function(b, c, d, e):
      return (2 * b + c) + (d + e)

    def expected_result(b, c, d, e):
      tmp_1001 = 2 * b
      tmp_1002 = tmp_1001 + c
      tmp_1003 = d + e
      tmp_1004 = tmp_1002 + tmp_1003
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_function_call_and_expr(self):
    # Positional, keyword, `*` and `**` arguments are all hoisted. The order
    # of the temporaries (starting with `g + h`) is not source order;
    # presumably it reflects the order in which the transformer visits the
    # call's fields.

    def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i):
      call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i))

    def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i):
      tmp_1001 = g + h
      tmp_1002 = a + b
      tmp_1003 = y * z
      tmp_1004 = e + f
      tmp_1005 = c + d
      tmp_1006 = tmp_1001 + i
      call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_with_and_print(self):
    # The context expression of a `with` is hoisted before the statement.
    # Temporaries for the body stay inside the `with` block.

    def test_function(a, b, c):
      with a + b + c as d:
        print(2 * d + 1)

    def expected_result(a, b, c):
      tmp_1001 = a + b
      tmp_1002 = tmp_1001 + c
      with tmp_1002 as d:
        tmp_1003 = 2 * d
        tmp_1004 = tmp_1003 + 1
        print(tmp_1004)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_nested_multi_value_assign(self):
    # Tuple targets are left alone. Nested tuple values and non-trivial
    # elements are hoisted.

    def test_function(a, b, c):
      x, y = a, a + b
      (z, y), x = (c, y + b), x + a
      return z, (y, x)

    def expected_result(a, b, c):
      tmp_1001 = a + b
      x, y = a, tmp_1001
      tmp_1002 = y + b
      tmp_1003 = (c, tmp_1002)
      tmp_1004 = x + a
      (z, y), x = tmp_1003, tmp_1004
      tmp_1005 = y, x
      tmp_1006 = z, tmp_1005
      return tmp_1006

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_deeply_nested_multi_value_assign(self):
    # A deeply nested destructuring target is kept as-is. The equally nested
    # returned literal is flattened into one temporary per container.

    def test_function(a):
      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
      return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]

    def expected_result(a):
      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
      tmp_1001 = b, c
      tmp_1002 = [d, e]
      tmp_1003 = [tmp_1001, tmp_1002]
      tmp_1004 = f, g
      tmp_1005 = h, i, j
      tmp_1006 = tmp_1003, tmp_1004
      tmp_1007 = [tmp_1005, k]
      tmp_1008 = [tmp_1006, tmp_1007]
      return tmp_1008

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_local_definition_and_binary_compare(self):
    # The transformer descends into nested function definitions, and it
    # hoists comparisons like other expressions.

    def test_function():

      def foo(a, b):
        return 2 * a < b

      return foo

    def expected_result():

      def foo(a, b):
        tmp_1001 = 2 * a
        tmp_1002 = tmp_1001 < b
        return tmp_1002

      return foo

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_list_literal(self):

    def test_function(a, b, c, d, e, f):
      return [a + b, c + d, e + f]

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a + b
      tmp_1002 = c + d
      tmp_1003 = e + f
      tmp_1004 = [tmp_1001, tmp_1002, tmp_1003]
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_tuple_literal_and_unary(self):
    # Note that the operand of the unary minus (`c + d`) receives the first
    # temporary, ahead of the tuple's first element.

    def test_function(a, b, c, d, e, f):
      return (a + b, -(c + d), e + f)

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = c + d
      tmp_1002 = a + b
      tmp_1003 = -tmp_1001
      tmp_1004 = e + f
      tmp_1005 = (tmp_1002, tmp_1003, tmp_1004)
      return tmp_1005

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_set_literal(self):
    # NOTE(review): despite the name, this exercises a call to `set(...)`,
    # not a `{...}` set display.

    def test_function(a, b, c, d, e, f):
      return set(a + b, c + d, e + f)

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a + b
      tmp_1002 = c + d
      tmp_1003 = e + f
      tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003)
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_dict_literal_and_repr(self):
    # Both dict keys and dict values are hoisted.

    def test_function(foo, bar, baz):
      return repr({foo + bar + baz: 7 | 8})

    def expected_result(foo, bar, baz):
      tmp_1001 = foo + bar
      tmp_1002 = tmp_1001 + baz
      tmp_1003 = 7 | 8
      tmp_1004 = {tmp_1002: tmp_1003}
      tmp_1005 = repr(tmp_1004)
      return tmp_1005

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_field_read_and_write(self):
    # Attribute chains are split on both sides of an assignment. The final
    # attribute of the target stays in the store.

    def test_function(a, d):
      a.b.c = d.e.f + 3

    def expected_result(a, d):
      tmp_1001 = a.b
      tmp_1002 = d.e
      tmp_1003 = tmp_1002.f
      tmp_1001.c = tmp_1003 + 3

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_subscript_read_and_write(self):
    # Same as `test_field_read_and_write`, but for subscripts.

    def test_function(a, b, c, d, e, f):
      a[b][c] = d[e][f] + 3

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a[b]
      tmp_1002 = d[e]
      tmp_1003 = tmp_1002[f]
      tmp_1001[c] = tmp_1003 + 3

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_augassign_and_delete(self):
    # The top-level value of `+=` stays in place, and `del` targets have
    # their nested subscripts hoisted.

    def test_function(a, x, y, z):
      a += x + y + z
      del a
      del z[y][x]

    def expected_result(a, x, y, z):
      tmp_1001 = x + y
      a += tmp_1001 + z
      del a
      tmp_1002 = z[y]
      del tmp_1002[x]

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_raise_yield_and_raise(self):
    # The operands of `yield` and `raise` are fully hoisted.

    def test_function(a, c, some_computed, exception):
      yield a ** c
      raise some_computed('complicated' + exception)

    def expected_result(a, c, some_computed, exception):
      tmp_1001 = a ** c
      yield tmp_1001
      tmp_1002 = 'complicated' + exception
      tmp_1003 = some_computed(tmp_1002)
      raise tmp_1003

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_with_and_if_with_expressions(self):

    def test_function(foo, bar, function, quux, quozzle, w, x, y, z):
      with foo + bar:
        function(x + y)
      if quux + quozzle:
        function(z / w)

    def expected_result(foo, bar, function, quux, quozzle, w, x, y, z):
      tmp_1001 = foo + bar
      with tmp_1001:
        tmp_1002 = x + y
        function(tmp_1002)
      tmp_1003 = quux + quozzle
      if tmp_1003:
        tmp_1004 = z / w
        function(tmp_1004)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_exec(self):
    # The input and expected functions live at module level:
    # `exec_test_function` and `exec_expected_result`.
    self.assert_body_anfs_as_expected(exec_expected_result, exec_test_function)

  def test_simple_while_and_assert(self):
    # A trivial `while` test and `assert` operand are left alone. Inside the
    # loop, only the non-trivial operand `1 * 3` is hoisted.

    def test_function(foo, quux):
      while foo:
        assert quux
        foo = foo + 1 * 3

    def expected_result(foo, quux):
      while foo:
        assert quux
        tmp_1001 = 1 * 3
        foo = foo + tmp_1001

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_for(self):
    # The iterable of a `for` loop is hoisted before the loop. Temporaries for
    # the loop body stay inside it.

    def test_function(compute, something, complicated, foo):
      for foo in compute(something + complicated):
        bar = foo + 1 * 3
      return bar

    def expected_result(compute, something, complicated, foo):
      tmp_1001 = something + complicated
      tmp_1002 = compute(tmp_1001)
      for foo in tmp_1002:
        tmp_1003 = 1 * 3
        bar = foo + tmp_1003
      return bar

    self.assert_body_anfs_as_expected(expected_result, test_function)

  # This test collects several examples where the definition of A-normal form
  # implemented by this transformer is questionable. It is mainly here to spell
  # out what that definition is in these cases.
  def test_controversial(self):

    def test_function(b, c, d, f):
      a = c + d
      a.b = c + d
      a[b] = c + d
      a += c + d
      a, b = c
      a, b = c, d
      a = f(c)
      a = f(c + d)
      a[b + d] = f.e(c + d)

    def expected_result(b, c, d, f):
      a = c + d
      a.b = c + d  # Should be a.b = tmp? (Definitely not tmp = c + d)
      a[b] = c + d  # Should be a[b] = tmp? (Definitely not tmp = c + d)
      a += c + d  # Should be a += tmp? (Definitely not tmp = c + d)
      a, b = c  # Should be a = c[0], b = c[1]? Or not?
      a, b = c, d  # Should be a = c, b = d? Or not?
      a = f(c)
      tmp_1001 = c + d
      a = f(tmp_1001)
      tmp_1002 = b + d
      tmp_1003 = f.e
      tmp_1004 = c + d
      a[tmp_1002] = tmp_1003(tmp_1004)  # Or should be a[tmp1] = tmp2?

    self.assert_body_anfs_as_expected(expected_result, test_function)
class AnfNonTransformationTest(AnfTransformerTest):
  """Test that specifying "no transformation" does nothing.

  Reuses all the examples of AnfTransformerTest by overriding
  `assert_body_anfs_as_expected`.
  """

  def assert_body_anfs_as_expected(self, expected_fn, test_fn, config=None):
    """Asserts that `test_fn` is unchanged by a leave-everything config.

    Args:
      expected_fn: ignored. The expected output here is `test_fn` itself.
      test_fn: function whose source is run through the transformer.
      config: ignored. This override always uses a configuration that leaves
        every node in place. The parameter is accepted only so the override
        stays call-compatible with `AnfTestBase.assert_body_anfs_as_expected`.
    """
    del expected_fn, config  # Unused; see the docstring.
    # Testing the code bodies only. Wrapping them in functions so the
    # syntax highlights nicely, but Python doesn't try to execute the
    # statements.
    node, _ = parser.parse_entity(test_fn, future_features=())
    orig_source = compiler.ast_to_source(node, indentation=' ')
    orig_str = textwrap.dedent(orig_source).strip()
    # Render the source *before* transforming, because the transformer may
    # rewrite `node` in place.
    leave_everything = [(anf.ANY, anf.LEAVE)]  # Configuration to transform nothing
    node = anf.transform(
        node, self._simple_context(),
        config=leave_everything, gensym_source=DummyGensym)
    new_source = compiler.ast_to_source(node, indentation=' ')
    new_str = textwrap.dedent(new_source).strip()
    self.assertEqual(orig_str, new_str)
class AnfConfiguredTest(AnfTestBase):
  # Tests that pass an explicit `config` to the ANF transformer. As used here,
  # a config is a list of (pattern, action) pairs. The action is either a
  # constant such as `anf.REPLACE` or a callable taking (parent, field, child)
  # and returning whether to transform.

  def test_constants_in_function_calls(self):
    # An example specific configuration that differs from the default: Moving
    # literals out of being directly passed to functions, but nothing else.
    # `x` is a `gast.Name` yet is not hoisted. NOTE(review): presumably the
    # transformer never replaces plain variable names; confirm in anf.py.
    literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant, gast.Name)
    config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, literals), anf.REPLACE)]

    def test_function(x, frob):
      return frob(x, x+1, 2)

    def expected_result(x, frob):
      tmp_1001 = 2
      return frob(x, x+1, tmp_1001)

    self.assert_body_anfs_as_expected(expected_result, test_function, config)

  def test_anf_some_function_calls(self):
    # Another example specific configuration that differs from the default:
    # Moving all arguments out of some function calls but leaving others be.
    whitelist = ['foo']

    # Decides per call edge: transform only when the callee's name is in
    # `whitelist`. Reads `parent.func.id`, so it assumes the callee is a plain
    # name. NOTE(review): an attribute call such as `a.b(...)` would
    # presumably raise AttributeError here.
    def transform(parent, field, child):
      del field
      del child
      func_name = parent.func.id
      return str(func_name) in whitelist

    config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), transform)]

    def test_function(x, foo, bar):
      y = foo(x, x+1, 2)
      return bar(y, y+1, 2)

    def expected_result(x, foo, bar):
      tmp_1001 = x+1
      tmp_1002 = 2
      y = foo(x, tmp_1001, tmp_1002)
      return bar(y, y+1, 2)

    self.assert_body_anfs_as_expected(expected_result, test_function, config)

  def test_touching_name_constant(self):
    # Checking that the nodes for `True`, `False`, and `None` can be manipulated
    # by a configuration. This is non-trivial, because in Python 2 those are
    # represented as `Name`, which is the same node type as variable references.
    specials = (gast.Name, gast.NameConstant)
    config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, specials), anf.REPLACE)]

    def test_function(f):
      return f(True, False, None)

    def expected_result(f):
      tmp_1001 = True
      tmp_1002 = False
      tmp_1003 = None
      return f(tmp_1001, tmp_1002, tmp_1003)

    self.assert_body_anfs_as_expected(expected_result, test_function, config)
# Run the tests in this module through TensorFlow's test entry point when the
# file is executed directly.
if __name__ == '__main__':
  test.main()