# blob: 7ce1b7c4c5ac34cdc69837e55e850d44f70e2236 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for tests in this module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import imp
import sys
import six
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
def imported_decorator(f):
  """Decorator that adds one to the result of the decorated callable.

  Args:
    f: A callable taking a single argument whose result supports `+ 1`.

  Returns:
    A one-argument lambda that evaluates `f(a) + 1`.
  """
  # Deliberately a lambda rather than a nested def, so the returned wrapper
  # keeps the same identity characteristics (e.g. __name__ == '<lambda>').
  return lambda a: f(a) + 1
# TODO(mdan): We might be able to use the real namer here.
class FakeNamer(object):
  """Test double for a namer; derives fresh names from a per-instance counter."""

  def __init__(self):
    # Monotonically increasing suffix shared by every name this namer issues.
    self.i = 0

  def new_symbol(self, name_root, used):
    """Returns `name_root` plus a numeric suffix that is absent from `used`.

    Args:
      name_root: Prefix of the generated name.
      used: Container of names that must not be returned.

    Returns:
      The first candidate `'<name_root><counter>'` not found in `used`. The
      counter is advanced for every candidate tried, including rejected ones.
    """
    candidate = None
    while candidate is None or candidate in used:
      self.i += 1
      candidate = '%s%d' % (name_root, self.i)
    return candidate

  def compiled_function_name(self,
                             original_fqn,
                             live_entity=None,
                             owner_type=None):
    """Returns a `(new_name, do_rename)` pair for a function being compiled.

    Args:
      original_fqn: Iterable of name parts; joined with underscores.
      live_entity: Ignored.
      owner_type: When not None, no new name is produced.

    Returns:
      `('renamed_<parts joined by _>', True)` when `owner_type` is None,
      otherwise `(None, False)`.
    """
    del live_entity  # Unused.
    if owner_type is None:
      return 'renamed_%s' % '_'.join(original_fqn), True
    return None, False
class FakeNoRenameNamer(FakeNamer):
  """Variant of FakeNamer that never asks for compiled functions to be renamed."""

  def compiled_function_name(self, original_fqn, **_):
    """Returns `(str(original_fqn), False)`; extra keyword arguments are ignored."""
    unchanged_name = str(original_fqn)
    return unchanged_name, False
class TestCase(test.TestCase):
  """Base class for unit tests in this module. Contains relevant utilities."""

  @contextlib.contextmanager
  def assertPrints(self, expected_result):
    """Asserts that the managed block writes `expected_result` to stdout.

    `sys.stdout` is replaced by an in-memory buffer while the block runs and
    is put back afterwards, whether or not the block (or the assertion)
    raised.

    Args:
      expected_result: Exact text the block is expected to print.

    Yields:
      Nothing.
    """
    # Remember the stream installed on entry and restore that one, not
    # sys.__stdout__: otherwise an enclosing redirection (a nested
    # assertPrints, a test runner capturing output) would be clobbered.
    previous_stdout = sys.stdout
    out_capturer = six.StringIO()
    sys.stdout = out_capturer
    try:
      yield
      self.assertEqual(out_capturer.getvalue(), expected_result)
    finally:
      sys.stdout = previous_stdout

  @contextlib.contextmanager
  def compiled(self, node, namespace, *symbols):
    """Compiles `node` into a module wired to fake `tf` and `ag__` modules.

    The compiled module gets a `tf` attribute built from `symbols` and an
    `ag__` entry exposing mock `converted_call` / `ConversionOptions`, the
    contents of `operators`, `utils` and
    `rewrite_graph_construction_error`. Every call to the mock
    `converted_call` is recorded in `self.dynamic_calls` and returns 7.

    Args:
      node: AST to compile; handed to `compiler.ast_to_object`.
      namespace: Dict whose items are copied into the compiled module.
      *symbols: Objects attached to the fake `tf` module (see
        `make_fake_mod` for how their attribute names are chosen).

    Yields:
      The compiled module.

    Raises:
      Exception: Anything raised during compilation or inside the managed
        block is re-raised after the offending AST (if compilation did not
        produce source) or compiled source has been printed.
    """
    source = None
    self.dynamic_calls = []

    def converted_call(*args):
      """Mock version of api.converted_call."""
      self.dynamic_calls.append(args)
      return 7

    class ConversionOptions(object):
      """Mock version of api.ConversionOptions."""

      def __init__(self, recursive):
        self.recursive = recursive

      @classmethod
      def new(cls, recursive):
        # The factory must hand back the instance it builds; previously the
        # result was discarded and callers always received None.
        return cls(recursive)

    try:
      result, source = compiler.ast_to_object(node, include_source_map=True)
      result.tf = self.make_fake_mod('fake_tf', *symbols)
      fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
      fake_ag.__dict__.update(operators.__dict__)
      fake_ag.__dict__['utils'] = utils
      fake_ag.__dict__['rewrite_graph_construction_error'] = (
          errors.rewrite_graph_construction_error)
      result.__dict__['ag__'] = fake_ag
      for k, v in namespace.items():
        result.__dict__[k] = v
      yield result
    except Exception:  # pylint:disable=broad-except
      # Print debugging context, then let the original error propagate.
      if source is None:
        print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
      else:
        print('Offending compiled code:\n%s' % source)
      raise

  @contextlib.contextmanager
  def converted(self, entity, converter_module, namespace, *tf_symbols):
    """Prepares, transforms and compiles `entity` in one step.

    Args:
      entity: Object to convert; passed to `prepare`.
      converter_module: Object exposing `transform(node, ctx)`.
      namespace: Dict used both for analysis and for the compiled module.
      *tf_symbols: Objects attached to the fake `tf` module.

    Yields:
      The compiled module produced by `compiled`.
    """
    node, ctx = self.prepare(entity, namespace)
    node = converter_module.transform(node, ctx)
    with self.compiled(node, namespace, *tf_symbols) as result:
      yield result

  def make_fake_mod(self, name, *symbols):
    """Creates a new module named `name` holding `symbols` as attributes.

    Args:
      name: Name of the module to create.
      *symbols: Objects to attach. Each is stored under its `__name__` when
        it has one, otherwise under its `name` attribute.

    Returns:
      The newly created module.

    Raises:
      ValueError: If a symbol has neither `__name__` nor `name`.
    """
    fake_mod = imp.new_module(name)
    for s in symbols:
      if hasattr(s, '__name__'):
        setattr(fake_mod, s.__name__, s)
      elif hasattr(s, 'name'):
        # This is a bit of a hack, but works for things like tf.int32
        setattr(fake_mod, s.name, s)
      else:
        raise ValueError('can not attach %s - what should be its name?' % s)
    return fake_mod

  def attach_namespace(self, module, **ns):
    """Sets every keyword argument in `ns` as an attribute on `module`."""
    for k, v in ns.items():
      setattr(module, k, v)

  def prepare(self,
              test_fn,
              namespace,
              namer=None,
              arg_types=None,
              owner_type=None,
              recursive=True,
              autograph_decorators=()):
    """Parses `test_fn` and runs the standard analysis passes on it.

    Args:
      test_fn: Entity to parse with `parser.parse_entity`.
      namespace: Namespace recorded in the entity info.
      namer: Namer for the entity context; a new `FakeNamer` when None.
      arg_types: Forwarded to `transformer.EntityInfo`.
      owner_type: Forwarded to `transformer.EntityInfo`.
      recursive: Forwarded to `converter.ProgramContext`.
      autograph_decorators: Forwarded to `converter.ProgramContext`.

    Returns:
      A `(node, ctx)` tuple: the node returned by
      `converter.standard_analysis` and the `converter.EntityContext` used.
    """
    node, source = parser.parse_entity(test_fn)
    # Only the first statement of the parsed body is analyzed.
    node = node.body[0]
    if namer is None:
      namer = FakeNamer()
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=autograph_decorators,
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)
    node = converter.standard_analysis(node, ctx, is_initial=True)
    return node, ctx