blob: 345320dc931fdb39d45e604ea941cccede8abae8 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code transformation exceptions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.python.autograph.pyct import origin_info
class FrameInfo(
    collections.namedtuple(
        'FrameInfo',
        'filename lineno function_name code converted')):
  """A single frame of a translated stack trace.

  Fields:
    filename: path of the file the frame refers to.
    lineno: line number inside `filename`.
    function_name: name of the function the frame belongs to.
    code: text of the source line; may be None when it is unavailable.
    converted: True when the frame was translated through a source map,
      False when it is reported exactly as found in the traceback.
  """
def _stack_trace_inside_mapped_code(tb, source_map):
"""Summarizes inner traceback frames up to the call to a given function.
This functions locates the innermost (i.e. most recent) frame that corresponds
to code that can be mapped by source_map originated from, and returns a
translated stack trace ending at that frame. If no such frame is found, the
entire stack trace is summarized.
For example, the following code:
def f():
for i in tf.range(1):
z = y + i # z only defined here
Would generate this traceback:
<converted code>
ag__.for_stmt(...)
<for_stmt>
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
<_known_len_tf_for_stmt>
_disallow_undefs_into_loop(*init_state)
<_disallow_undefs_into_loop>
raise ...
Which is then processed into:
<f>
for i in tf.range(1):
<for_stmt>
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
<_known_len_tf_for_stmt>
_disallow_undefs_into_loop(*init_state)
<_disallow_undefs_into_loop>
raise ...
Args:
tb: List[Tuple], the traceback corresponding to an error; typically,
the output of traceback.extract_tb.
source_map: Dict[LineLocation, OriginInfo], a source map as created by
origin_info.create_source_map.
Returns:
List[FrameInfo]
"""
result_frames = []
for filename, line_number, function_name, text in reversed(tb):
loc = origin_info.LineLocation(filename=filename, lineno=line_number)
if loc in source_map:
origin = source_map[loc]
origin_frame_info = FrameInfo(
filename=origin.loc.filename,
lineno=origin.loc.lineno,
function_name=origin.function_name,
code=origin.source_code_line,
converted=True)
result_frames.append(origin_frame_info)
break
fi = FrameInfo(
filename=filename,
lineno=line_number,
function_name=function_name,
code=text,
converted=False)
result_frames.append(fi)
return tuple(result_frames)
# Exception types that ErrorMetadataBase.create_exception re-creates by
# passing the rewritten error message as the only constructor argument.
KNOWN_STRING_CONSTRUCTOR_ERRORS = (
AssertionError,
AttributeError,
NameError,
NotImplementedError,
RuntimeError,
StopIteration,
TypeError,
ValueError,
)
class MultilineMessageKeyError(KeyError):
  """KeyError variant whose string form keeps multi-line messages readable.

  KeyError escapes newlines in strings. This special subclass doesn't do that:
  the key is still handed to KeyError, while str() returns the message as is.
  """

  def __init__(self, message, original_key):
    super(MultilineMessageKeyError, self).__init__(original_key)
    self.__message = message

  def __str__(self):
    return self.__message


# Overriding the name for display purposes; hopefully that won't create too
# many surprises.
MultilineMessageKeyError.__name__ = KeyError.__name__
class ErrorMetadataBase(object):
  """Container objects attached to exceptions in converted code.

  This metadata allows re-raising exceptions that occur in generated code, with
  a custom error message that includes a stack trace relative to user-readable
  code from which the generated code originated.

  Attributes:
    translated_stack: tuple of FrameInfo, in the order produced by
      _stack_trace_inside_mapped_code; get_message prints it in reverse.
    cause_message: str, the message of the error that started the chain.
  """

  def __init__(self, callsite_tb, cause_metadata, cause_message, source_map):
    """Builds the metadata for an error observed at a converted call site.

    Args:
      callsite_tb: the traceback of the error, in the form accepted by
        _stack_trace_inside_mapped_code.
      cause_metadata: ErrorMetadataBase or None, metadata from an earlier
        stage of the same error. When given, its stack and message are
        extended instead of starting a new chain.
      cause_message: str, the error message; only used when cause_metadata is
        None.
      source_map: the source map used to translate callsite_tb.
    """
    translated_stack = _stack_trace_inside_mapped_code(callsite_tb, source_map)

    if cause_metadata is None:
      self.translated_stack = translated_stack
      self.cause_message = cause_message
    else:
      # Daisy chain the translated stacks: only the last frame of the new
      # stack is appended. Slicing (rather than indexing) keeps this safe when
      # the call-site traceback produced no frames at all.
      self.translated_stack = (
          cause_metadata.translated_stack + tuple(translated_stack[-1:]))
      self.cause_message = cause_metadata.cause_message

  def get_message(self):
    """Returns the message for the underlying exception.

    The message lists the frames of translated_stack in reverse order, with
    file names shown relative to their common directory when there is one,
    followed by the cause message.
    """
    all_paths = tuple(fi.filename for fi in self.translated_stack)

    common_path = ''
    path_idx = 0
    if len(all_paths) > 1:
      # commonprefix works character by character, so trim the result back to
      # a directory boundary.
      common_path = os.path.dirname(os.path.commonprefix(all_paths))
      if common_path == os.path.sep:
        common_path = ''
      if common_path:
        path_idx = len(common_path)
        # Also skip the separator that follows the common directory, unless
        # the directory already ends with one (e.g. '//' or a drive root);
        # skipping unconditionally would eat the first character of the name.
        separators = tuple(
            sep for sep in (os.path.sep, os.path.altsep) if sep)
        if not common_path.endswith(separators):
          path_idx += 1

    lines = ['in converted code:']
    if common_path:
      lines.append(' relative to {}:'.format(common_path))
    lines.append('')

    for frame_info in reversed(self.translated_stack):
      lines.append(' {}:{} {}{}'.format(
          frame_info.filename[path_idx:],
          frame_info.lineno,
          frame_info.function_name,
          # Frames translated through the source map are flagged.
          ' *' if frame_info.converted else '',
      ))
      if frame_info.code is None:
        code_snippet = '<source unavailable>'
      else:
        code_snippet = frame_info.code.strip()
      lines.append(' {}'.format(code_snippet))
    lines.append('')

    lines.extend(' ' + line for line in self.cause_message.split('\n'))
    lines.append('')

    return '\n'.join(lines)

  def create_exception(self, source_error):
    """Creates an exception of source_error's type carrying get_message().

    Args:
      source_error: the exception being rewritten.

    Returns:
      A new exception, or None when the type of source_error is not one that
      is known to be constructible from a single message string.
    """
    preferred_type = type(source_error)
    # Types that inherit Exception's constructor accept a single message.
    if preferred_type.__init__ is Exception.__init__:
      return preferred_type(self.get_message())
    if preferred_type in KNOWN_STRING_CONSTRUCTOR_ERRORS:
      return preferred_type(self.get_message())
    if preferred_type is KeyError:
      return MultilineMessageKeyError(self.get_message(), self.cause_message)
    return None

  def to_exception(self, source_error):
    """Returns the exception to raise in place of source_error.

    Args:
      source_error: the exception being rewritten.

    Returns:
      The exception built by create_exception, with its context suppressed
      and this metadata attached as `ag_error_metadata`. If create_exception
      cannot build one (returns None), source_error itself is returned
      unchanged so that the original error is not masked by a failure here.
    """
    exc = self.create_exception(source_error)
    if exc is None:
      return source_error
    exc.__suppress_context__ = True
    exc.ag_error_metadata = self
    return exc