blob: 43330f707f109f8a8ab2be1056c1c9ce614e6469 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for api module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import gc
import imp
import os
import re
import textwrap
import types
import numpy as np
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
global_n = 2
class TestResource(object):
def __init__(self):
self.x = 3
class ApiTest(test.TestCase):
@test_util.run_deprecated_v1
def test_decorator_recursive(self):
class TestClass(object):
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_decorator_not_recursive(self):
class TestClass(object):
def called_member(self, a):
return tf.negative(a)
@api.convert(recursive=False)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_convert_then_do_not_convert_graph(self):
class TestClass(object):
@api.do_not_convert(run_as=api.RunMode.GRAPH)
def called_member(self, a):
return tf.negative(a)
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_convert_then_do_not_convert_py_func(self):
class TestClass(object):
@api.do_not_convert(
run_as=api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1))
def called_member(self, a):
return np.negative(a)
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
y = self.called_member(a)
# set_shape works around while_loop's limitations.
# TODO(mdan): Allow specifying shapes (or ShapeLike) instead.
y.set_shape(a.shape)
x //= y
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_decorated(self):
class TestClass(object):
@api.convert()
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_preserves_argspec(self):
class TestClass(object):
def test_method(self, a):
if a < 0:
a = -a
return a
test_method_converted = api.convert()(test_method)
tc = TestClass()
self.assertListEqual(
list(tf_inspect.getfullargspec(tc.test_method)),
list(tf_inspect.getfullargspec(tc.test_method_converted)))
def test_do_not_convert_argspec(self):
class TestClass(object):
def test_method(self, x, y):
z = x + y
return z
test_method_whitelisted = api.do_not_convert(test_method)
tc = TestClass()
self.assertTrue(tf_inspect.ismethod(tc.test_method_whitelisted))
# Because the wrapped function is not generated, we can't preserve its
# arg spec.
self.assertEqual((),
tuple(function_utils.fn_args(tc.test_method_whitelisted)))
def test_do_not_convert_callable_object(self):
class TestClass(object):
def __call__(self):
return 1
tc = TestClass()
self.assertEqual(1, api.do_not_convert(tc)())
@test_util.run_deprecated_v1
def test_convert_call_site_decorator(self):
class TestClass(object):
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
x //= api.converted_call(self.called_member,
converter.ConversionOptions(recursive=True),
(a,), {})
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_converted_call_builtin(self):
x = api.converted_call(range, converter.ConversionOptions(recursive=True),
(3,), {})
self.assertEqual((0, 1, 2), tuple(x))
x = api.converted_call(re.compile,
converter.ConversionOptions(recursive=True),
('mnas_v4_a.*\\/.*(weights|kernel):0$',), {})
self.assertIsNotNone(x.match('mnas_v4_a/weights:0'))
def test_converted_call_function(self):
def test_fn(x):
if x < 0:
return -x
return x
x = api.converted_call(test_fn, converter.ConversionOptions(recursive=True),
(constant_op.constant(-1),), {})
self.assertEqual(1, self.evaluate(x))
@test_util.run_v1_only('b/120545219')
def test_converted_call_functools_partial(self):
def test_fn(x, y, z):
if x < 0:
return -x, -y, -z
return x, y, z
x = api.converted_call(
functools.partial(test_fn, constant_op.constant(-1), z=-3),
converter.ConversionOptions(recursive=True),
(constant_op.constant(-2),), {})
self.assertEqual((1, 2, 3), self.evaluate(x))
x = api.converted_call(
functools.partial(
functools.partial(test_fn, constant_op.constant(-1)), z=-3),
converter.ConversionOptions(recursive=True),
(constant_op.constant(-2),), {})
self.assertEqual((1, 2, 3), self.evaluate(x))
def test_converted_call_method(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_method(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method,
converter.ConversionOptions(recursive=True), (), {})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_synthetic_method(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_function(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
test_method = types.MethodType(test_function, tc)
x = api.converted_call(test_method,
converter.ConversionOptions(recursive=True), (), {})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_wrapper(self):
class TestClass(object):
def foo(self):
pass
tc = TestClass()
# `method.__get__()` returns a so-called method-wrapper.
wrapper = api.converted_call(tc.foo.__get__,
converter.ConversionOptions(recursive=True),
(tc,), {})
self.assertEqual(wrapper, tc.foo)
def test_converted_call_method_as_object_attribute(self):
class AnotherClass(object):
def __init__(self):
self.another_class_attr = constant_op.constant(1)
def method(self):
if self.another_class_attr > 0:
return self.another_class_attr + 1
return self.another_class_attr + 10
class TestClass(object):
def __init__(self, another_obj_method):
self.another_obj_method = another_obj_method
obj = AnotherClass()
tc = TestClass(obj.method)
x = api.converted_call(tc.another_obj_method,
converter.ConversionOptions(recursive=True), (), {})
self.assertEqual(self.evaluate(x), 2)
def test_converted_call_method_converts_recursively(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def other_method(self):
if self.x < 0:
return -self.x
return self.x
def test_method(self):
return self.other_method()
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method,
converter.ConversionOptions(recursive=True), (), {})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_by_class(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_method(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(TestClass.test_method,
converter.ConversionOptions(recursive=True), (tc,),
{})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_callable_object(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def __call__(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc, converter.ConversionOptions(recursive=True), (),
{})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_callable_metaclass(self):
class TestMetaclass(type):
x = constant_op.constant(-1)
def __call__(cls):
if cls.x < 0:
cls.x = -cls.x
return cls
tc = TestMetaclass('TestClass', (), {})
# This functools.partial will hide the class form the constructor
# check. Not ideal. See b/120224672.
tc = functools.partial(tc)
converted_tc = api.converted_call(
tc, converter.ConversionOptions(recursive=True), (), {})
self.assertIsInstance(converted_tc, TestMetaclass)
self.assertEqual(1, self.evaluate(converted_tc.x))
@test_util.run_deprecated_v1
def test_converted_call_constructor(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_method(self):
if self.x < 0:
return -self.x
return self.x
tc = api.converted_call(TestClass,
converter.ConversionOptions(recursive=True),
(constant_op.constant(-1),), {})
# tc is still a TestClass - constructors are whitelisted.
# TODO(b/124016764): Support this use case.
# The error below is specific to the `if` statement not being converted.
with self.assertRaisesRegex(TypeError,
'Using a `tf.Tensor` as a Python `bool`'):
tc.test_method()
def test_converted_call_mangled_properties(self):
class TestClass(object):
def __init__(self, x):
self.__private = x
def test_method(self):
if self.__private < 0:
return self.__private
return self.__private
tc = TestClass(constant_op.constant(-1))
# The error below is specific to the `if` statement not being converted.
with self.assertRaisesRegex(NotImplementedError, 'Mangled names'):
api.converted_call(tc.test_method,
converter.ConversionOptions(recursive=True), (), {})
tc.test_method()
def test_converted_call_already_converted(self):
def f(x):
return x == 0
x = api.converted_call(f, converter.ConversionOptions(recursive=True),
(constant_op.constant(0),), {})
self.assertTrue(self.evaluate(x))
converted_f = api.to_graph(
f, experimental_optional_features=converter.Feature.ALL)
x = api.converted_call(converted_f,
converter.ConversionOptions(recursive=True),
(constant_op.constant(0),), {})
self.assertTrue(self.evaluate(x))
def test_converted_call_then_already_converted_dynamic(self):
@api.convert()
def g(x):
if x > 0:
return x
else:
return -x
def f(g, x):
return g(x)
x = api.converted_call(f, converter.ConversionOptions(recursive=True),
(g, constant_op.constant(1)), {})
self.assertEqual(self.evaluate(x), 1)
def test_converted_call_forced_when_explicitly_whitelisted(self):
@api.do_not_convert()
def f(x):
return x + 1
x = api.converted_call(
f, converter.ConversionOptions(recursive=True, force_conversion=True),
(constant_op.constant(0),), {})
self.assertTrue(self.evaluate(x))
converted_f = api.to_graph(
f, experimental_optional_features=converter.Feature.ALL)
x = api.converted_call(converted_f,
converter.ConversionOptions(recursive=True), (0,),
{})
self.assertEqual(x, 1)
@test_util.run_deprecated_v1
def test_converted_call_no_user_code(self):
def f(x):
return len(x)
opts = converter.ConversionOptions(internal_convert_user_code=False)
# f should not be converted, causing len to error out.
with self.assertRaisesRegexp(Exception,
'object of type \'Tensor\' has no len()'):
api.converted_call(f, opts, (constant_op.constant([0]),), {})
# len on the other hand should work fine.
x = api.converted_call(len, opts, (constant_op.constant([0]),), {})
# The constant has static shape so the result is a primitive not a Tensor.
self.assertEqual(x, 1)
def test_converted_call_no_kwargs_allowed(self):
def f(*args):
# Note: np.broadcast rejects any **kwargs, even *{}
return np.broadcast(args[:1])
opts = converter.ConversionOptions(internal_convert_user_code=False)
self.assertIsNotNone(api.converted_call(f, opts, (1, 2, 3, 4), None))
def test_converted_call_whitelisted_method(self):
opts = converter.ConversionOptions(recursive=True)
model = sequential.Sequential([core.Dense(2)])
x = api.converted_call(model.call, opts, (constant_op.constant([[0.0]]),),
{'training': True})
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_whitelisted_method_via_owner(self):
opts = converter.ConversionOptions(recursive=True)
model = sequential.Sequential([core.Dense(2)])
x = api.converted_call(model.call, opts, (constant_op.constant([[0.0]]),),
{'training': True})
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_numpy(self):
opts = converter.ConversionOptions(recursive=True)
x = api.converted_call(np.arange, opts, (5,), {})
self.assertAllEqual(x, list(range(5)))
def test_converted_call_tf_op_forced(self):
# TODO(mdan): Add the missing level of support to LOGICAL_EXPRESSIONS.
opts = converter.ConversionOptions(
force_conversion=True, optional_features=None)
x = api.converted_call(gen_math_ops.add, opts, (1, 1), {})
self.assertAllEqual(self.evaluate(x), 2)
def test_converted_call_exec_generated_code(self):
temp_mod = imp.new_module('test_module')
dynamic_code = """
def foo(x):
return x + 1
"""
exec(textwrap.dedent(dynamic_code), temp_mod.__dict__) # pylint:disable=exec-used
opts = converter.ConversionOptions(optional_features=None)
x = api.converted_call(temp_mod.foo, opts, (1,), {})
self.assertAllEqual(x, 2)
def test_converted_call_namedtuple(self):
opts = converter.ConversionOptions(recursive=True)
x = api.converted_call(collections.namedtuple, opts,
('TestNamedtuple', ('a', 'b')), {})
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_namedtuple_via_collections(self):
opts = converter.ConversionOptions(recursive=True)
x = api.converted_call(collections.namedtuple, opts,
('TestNamedtuple', ('a', 'b')), {})
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_lambda(self):
opts = converter.ConversionOptions(recursive=True)
l = lambda x: x == 0
x = api.converted_call(l, opts, (constant_op.constant(0),), {})
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(True, self.evaluate(x))
def test_converted_call_defun_object_method(self):
opts = converter.ConversionOptions(recursive=True)
# pylint:disable=method-hidden
class TestClass(object):
def method(self):
return 1
def prepare(self):
self.method = function.defun(self.method)
# pylint:enable=method-hidden
tc = TestClass()
tc.prepare()
x = api.converted_call(tc.method, opts, (), {})
self.assertAllEqual(1, self.evaluate(x))
def test_converted_call_through_tf_dataset(self):
def other_fn(x):
if x > 0:
return x
return -x
def f():
return dataset_ops.Dataset.range(-3, 3).map(other_fn)
# Dataset iteration only works inside tf.function.
@def_function.function
def graph_fn():
opts = converter.ConversionOptions(recursive=True)
ds = api.converted_call(f, opts, (), {})
itr = iter(ds)
return next(itr), next(itr), next(itr)
self.assertAllEqual(self.evaluate(graph_fn()), (3, 2, 1))
def assertNoMemoryLeaks(self, f):
object_ids_before = {id(o) for o in gc.get_objects()}
f()
gc.collect()
objects_after = tuple(
o for o in gc.get_objects() if id(o) not in object_ids_before)
self.assertEmpty(
tuple(o for o in objects_after if isinstance(o, TestResource)))
def test_converted_call_no_leaks_via_closure(self):
def test_fn():
res = TestResource()
def f(y):
return res.x + y
opts = converter.ConversionOptions(recursive=True)
api.converted_call(f, opts, (1,), {})
self.assertNoMemoryLeaks(test_fn)
def test_converted_call_no_leaks_via_inner_function_closure(self):
def test_fn():
res = TestResource()
def f(y):
def inner_f():
return res.x + y
return inner_f
opts = converter.ConversionOptions(recursive=True)
api.converted_call(f, opts, (1,), {})()
self.assertNoMemoryLeaks(test_fn)
def test_context_tracking_direct_calls(self):
@api.do_not_convert()
def unconverted_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.DISABLED)
@api.convert()
def converted_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
unconverted_fn()
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
converted_fn()
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
@api.call_with_unspecified_conversion_status
def unspecified_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
unspecified_fn()
def test_to_graph_basic(self):
def test_fn(x, s):
while tf.reduce_sum(x) > s:
x //= 2
return x
compiled_fn = api.to_graph(test_fn)
with tf.Graph().as_default():
x = compiled_fn(constant_op.constant((4, 8)), 4)
self.assertAllEqual(self.evaluate(x), (1, 2))
@test_util.run_deprecated_v1
def test_to_graph_with_defaults(self):
foo = 4
def test_fn(x, s=foo):
while tf.reduce_sum(x) > s:
x //= 2
return x
compiled_fn = api.to_graph(test_fn)
with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]))
self.assertListEqual([1, 2], self.evaluate(x).tolist())
def test_to_graph_with_globals(self):
def test_fn(x):
global global_n
global_n = x + global_n
return global_n
converted_fn = api.to_graph(test_fn)
prev_val = global_n
converted_fn(10)
self.assertGreater(global_n, prev_val)
def test_to_graph_with_kwargs_clashing_converted_call(self):
def called_fn(**kwargs):
return kwargs['f'] + kwargs['owner']
def test_fn():
# These arg names intentionally match converted_call's
return called_fn(f=1, owner=2)
compiled_fn = api.to_graph(test_fn)
self.assertEqual(compiled_fn(), 3)
def test_to_graph_with_kwargs_clashing_unconverted_call(self):
@api.do_not_convert
def called_fn(**kwargs):
return kwargs['f'] + kwargs['owner']
def test_fn():
# These arg names intentionally match _call_unconverted's
return called_fn(f=1, owner=2)
compiled_fn = api.to_graph(test_fn)
self.assertEqual(compiled_fn(), 3)
def test_to_graph_caching(self):
def test_fn(x):
if x > 0:
return x
else:
return -x
converted_functions = tuple(api.to_graph(test_fn) for _ in (-1, 0, 1))
# All outputs are from the same module. We can't use __module__ because
# that's reset when we instantiate the function (see conversion.py).
# TODO(mdan): Can and should we overwrite __module__ instead?
module_names = frozenset(f.ag_module for f in converted_functions)
self.assertEqual(len(module_names), 1)
self.assertNotIn('__main__', module_names)
self.assertEqual(len(frozenset(id(f) for f in converted_functions)), 3)
def test_to_graph_caching_different_options(self):
def called_fn():
pass
def test_fn():
return called_fn()
converted_recursive = api.to_graph(test_fn, recursive=True)
converted_non_recursive = api.to_graph(test_fn, recursive=False)
self.assertNotEqual(converted_recursive.ag_module,
converted_non_recursive.ag_module)
self.assertIn('ag__.STD', tf_inspect.getsource(converted_recursive))
self.assertNotIn('internal_convert_user_code=False',
tf_inspect.getsource(converted_recursive))
self.assertIn('internal_convert_user_code=False',
tf_inspect.getsource(converted_non_recursive))
self.assertNotIn('internal_convert_user_code=True',
tf_inspect.getsource(converted_non_recursive))
def test_to_graph_preserves_bindings(self):
y = 3
def test_fn():
return y
converted = api.to_graph(test_fn)
self.assertEqual(converted(), 3)
y = 7
self.assertEqual(converted(), 7)
def test_to_graph_source_map(self):
def test_fn(y):
return y**2
self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
def test_to_code_basic(self):
def test_fn(x, s):
while tf.reduce_sum(x) > s:
x /= 2
return x
# Just check that the output is parseable Python code.
self.assertIsNotNone(parser.parse_str(api.to_code(test_fn)))
def test_tf_convert_direct(self):
def f():
if tf.reduce_sum([1, 2]) > 0:
return -1
return 1
# Note: the autograph setting of tf.function has nothing to do with the
# test case. We just disable it to avoid confusion.
@def_function.function(autograph=False)
def test_fn(ctx):
return api.tf_convert(f, ctx)()
self.assertEqual(
self.evaluate(
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))), -1)
with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
def test_tf_convert_unspecified_not_converted_by_default(self):
def f():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
if tf.reduce_sum([1, 2]) > 0:
return -1
return 1
@def_function.function
def test_fn(ctx):
return api.tf_convert(f, ctx, convert_by_default=False)()
with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
def test_tf_convert_whitelisted_method(self):
model = sequential.Sequential([core.Dense(2)])
converted_call = api.tf_convert(
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
_, converted_target = tf_decorator.unwrap(converted_call)
self.assertIs(converted_target.__func__, model.call.__func__)
def test_tf_convert_wrapped(self):
def f():
if tf.reduce_sum([1, 2]) > 0:
return -1
return 1
@functools.wraps(f)
def wrapper(*args, **kwargs):
return wrapper.__wrapped__(*args, **kwargs)
decorated_f = tf_decorator.make_decorator(f, wrapper)
# Note: the autograph setting of tf.function has nothing to do with the
# test case. We just disable it to avoid confusion.
@def_function.function(autograph=False)
def test_fn(ctx):
return api.tf_convert(decorated_f, ctx)()
self.assertEqual(
self.evaluate(
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))), -1)
# tf_convert mutates the decorator, so we need to create a new one for
# another test.
decorated_f = tf_decorator.make_decorator(f, wrapper)
with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
test.main()