import functools
import itertools
import logging
import os
import traceback
import types
import typing
import weakref
from typing import Callable
import torch
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from . import config, exc, logging as torchdynamo_logging
from .allowed_functions import is_allowed
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import is_generator, transform_code_object
from .eval_frame import (
    always_optimize_code_objects,
    skip_code,
    TorchPatcher,
    WrapperBackend,
)
from .exc import (
    BackendCompilerFailed,
    InternalTorchDynamoError,
    TorchRuntimeError,
    unimplemented,
    Unsupported,
)
from .guards import CheckFunctionManager, GuardedCode
from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator
from .utils import (
    CleanupManager,
    counters,
    dynamo_timed,
    filter_stack,
    format_bytecode,
    gen_record_file_name,
    guard_failures,
    init_logging,
    is_namedtuple,
    istype,
    orig_code_map,
    troubleshooting_url,
    write_record_to_file,
)
log = logging.getLogger(__name__)
class Tracker:
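    """De-duplicating registry of code objects: entries are held via weakrefs,
    and an object's id is dropped from seen_ids once the object is collected."""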
    def __init__(self):
        self.seen = []
        self.seen_ids = set()

    def add(self, strong_obj):
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item):
        return id(item) in self.seen_ids

    def clear(self):
        self.seen.clear()
        self.seen_ids.clear()
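
# input_codes tracks the code objects of frames handed to the converter;
# output_codes tracks the code objects Dynamo itself generated, so that
# _convert_frame_assert never tries to re-convert its own output.
# initial_grad_state records torch.is_grad_enabled() at the start of the most
# recent conversion.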
input_codes = Tracker()
output_codes = Tracker()
initial_grad_state = None
@functools.wraps(original_forward_from_src)
def fx_forward_from_src_skip_result(*args, **kwargs):
    # we monkey patch FX to prevent an infinite loop of trying to convert
    # our generated code
    result: types.FunctionType = original_forward_from_src(*args, **kwargs)
    skip_code(result.__code__)
    return result
def wrap_compiler_fn(compiler_fn):
"""WrapperBackend if config.verify_correctness is True"""
if config.verify_correctness:
# wrap backend if verify_correctness is True
wrapper_backend_compiler_fn = WrapperBackend(compiler_fn)
wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn
return wrapper_backend_compiler_fn
return compiler_fn
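
# Illustrative sketch only (my_compiler is a hypothetical user backend):
#
#     def my_compiler(gm, example_inputs):
#         return gm.forward
#
#     backend = wrap_compiler_fn(my_compiler)
#
# With config.verify_correctness=True, `backend` is a WrapperBackend wrapping
# my_compiler, which is expected to cross-check the candidate backend's output
# for correctness; otherwise my_compiler is returned unchanged.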
def wrap_convert_context(fn):
"""
Context manager to:
1) Save/restore torch random state
2) Save/restore torch.is_grad_enabled() state
3) Monkey patch torch.fx.graph_module._forward_from_src
"""
@functools.wraps(fn)
def _fn(*args, **kwargs):
prior_grad_mode = torch.is_grad_enabled()
rng_state = torch.random.get_rng_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
try:
return fn(*args, **kwargs)
finally:
torch._C._set_grad_enabled(prior_grad_mode)
torch.random.set_rng_state(rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
_fn._torchdynamo_orig_callable = fn
return _fn
@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame):
"""Check if the frame has torch.* related bits"""
# Check if the function was decorated using torchdynamo.optimize
if frame.f_code in always_optimize_code_objects:
return True
# Check if there is global import of torch.*
for co_name in frame.f_code.co_names:
if co_name in frame.f_globals:
if is_allowed(frame.f_globals[co_name]):
return True
seen_ids = dict()
def has_tensor(obj):
"""Recursively check if the obj has a tensor"""
obj_id = id(obj)
if obj_id in seen_ids:
return seen_ids[obj_id]
seen_ids[obj_id] = False
if isinstance(obj, (torch.Tensor, torch.nn.Module)):
seen_ids[obj_id] = True
return seen_ids[obj_id]
elif istype(obj, (list, tuple)):
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
return seen_ids[obj_id]
elif istype(obj, dict):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
return seen_ids[obj_id]
elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False
return seen_ids[obj_id]
elif is_namedtuple(obj):
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
return seen_ids[obj_id]
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
return seen_ids[obj_id]
else:
# if config.debug:
# print(
# f"Assuming that object of type {type(obj)} does not have a tensor"
# )
return False
# Check if the passed arguments are of type Tensor
for value in frame.f_locals.values():
if has_tensor(value):
return True
log.debug(
f"skipping because no torch.* {frame.f_code.co_name} \
{frame.f_code.co_filename} {frame.f_code.co_firstlineno}"
)
return False
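
# Illustrative only (hypothetical frame contents): a frame whose locals look
# like {"x": torch.randn(3), "n": 7} contains a tensor and is considered for
# conversion, while one holding only {"n": 7, "s": "hi"} is skipped.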
def format_error_msg(exc, code, record_filename=None, frame=None):
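    """Format the "WON'T CONVERT" message for a failed frame conversion; the
    verbose form adds the original bytecode, the TorchDynamo stack trace, and
    the relevant user stack, while the default form shows only the last cause."""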
    msg = os.linesep * 2

    def replay_record_msg():
        if (
            config.replay_record_enabled
            and hasattr(exc, "exec_record")
            and record_filename is not None
        ):
            return (
                f"\nLast frame execution written to {record_filename}. To run only this "
                f"frame while debugging, run {config.dynamo_import}.replay('{record_filename}').\n"
            )
        else:
            return ""

    if config.verbose:
        msg = format_bytecode(
            "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code
        )
        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
        msg += traceback.format_exc()
        if hasattr(exc, "real_stack"):
            msg += (
                "\n"
                + "=" * 10
                + " The above exception occurred while processing the following code "
                + "=" * 10
                + "\n\n"
            )
            stack_above_dynamo = []
            if frame is not None:
                stack_above_dynamo = filter_stack(traceback.extract_stack(frame))
            msg += "".join(
                traceback.format_list(
                    stack_above_dynamo + list(reversed(exc.real_stack))
                )
            )
        msg += replay_record_msg()
    else:
        msg = (
            f"WON'T CONVERT {code.co_name} {code.co_filename} "
            f"line {code.co_firstlineno} \ndue to: \n{traceback.format_exc(limit=-1)}"
        )
        if hasattr(exc, "real_stack"):
            msg += f"\nfrom user code:\n {''.join(traceback.format_list([exc.real_stack[-1]]))}"
        msg += replay_record_msg()
        msg += (
            f"\nSet {config.dynamo_import}.config.verbose=True for more information\n"
        )

    msg += "=" * 10
    return msg
def exception_handler(e, code, frame=None):
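    """If the exception carries an execution record, write it to disk, then log
    the formatted "WON'T CONVERT" error message."""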
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)

    log.error(format_error_msg(e, code, record_filename, frame))
def convert_frame_assert(
    compiler_fn: Callable, guard_export_fn=None, one_graph=True, export=False
):
"""Fully convert a frame into an FX graph"""
init_logging()
compiler_fn = wrap_compiler_fn(compiler_fn)
@dynamo_timed
def _convert_frame_assert(frame: types.FrameType, cache_size: int):
code = frame.f_code
input_codes.add(code)
if code in output_codes:
return None
if (
os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
):
return None
if code.co_name == "<genexpr>" and code.co_filename.endswith(
("transformers/file_utils.py", "transformers/utils/generic.py")
):
# not needed, but cleans up torchbench error stats
return None
if code.co_name == "__setattr__":
# setattr could be tricky to handle generally,
# but also not likely useful to compile- skip the whole frame
return None
# Check if the frame is generated by an exec builtin call
# TODO - Running exec generated frame seems propagates f_globals to the
# next frames.
if code.co_name == "<module>" and code.co_filename == "<string>":
return None
if (
code.co_name == "<lambda>"
and code.co_filename == "<string>"
and not bool(frame.f_builtins)
):
# namedtuple subclass constructor. Empty builtins cause issue with
# len keyword in LIST_LEN guard.
return None
if is_generator(code):
unimplemented("generator")
if cache_size >= config.cache_size_limit:
def format_func_info(code):
return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"
def format_guard_failures(code):
# For the common case, it's sufficient to see just the most recent failure.
# We could add a verbose mode if needed
return f"{str(guard_failures[code][-1])}"
assert code in guard_failures, "TODO(whc) any other recompile reasons?"
log.warning(
f"{config.dynamo_import} hit config.cache_size_limit ({config.cache_size_limit})\n"
+ f" function: {format_func_info(code)}\n"
+ f" reasons: {format_guard_failures(code)}\n"
+ f"to diagnose recompilation issues, see {troubleshooting_url}."
)
unimplemented("cache_size_limit reached")
if not has_tensor_in_frame(frame):
return None
global initial_grad_state
initial_grad_state = torch.is_grad_enabled()
return _compile(
frame.f_code,
frame.f_globals,
frame.f_locals,
frame.f_builtins,
compiler_fn,
one_graph,
export,
guard_export_fn,
frame,
)
_convert_frame_assert._torchdynamo_orig_callable = compiler_fn
return wrap_convert_context(_convert_frame_assert)
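
# Illustrative only: convert_frame_assert(my_compiler) (my_compiler being a
# hypothetical backend) returns a callback taking (frame, cache_size) that
# yields a GuardedCode for the frame, returns None when the frame is skipped,
# or raises (e.g. Unsupported) when the frame cannot be fully converted.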
def _compile(
    code,
    globals,
    locals,
    builtins,
    compiler_fn,
    one_graph,
    export,
    guard_export_fn=None,
    frame=None,
):
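    """Rewrite `code` into Dynamo-generated bytecode via InstructionTranslator and
    transform_code_object, build guards for the result with CheckFunctionManager,
    and return a GuardedCode (or None when the frame is skipped)."""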
    output = None
    # from .utils import print_once; print_once(code.co_filename)

    def transform(instructions, code_options):
        nonlocal output
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            code_options,
            compiler_fn,
            one_graph,
            export,
        )
        tracer.run()
        output = tracer.output
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)

        if config.dead_code_elimination:
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

    try:
        for attempt in itertools.count():
            try:
                out_code = transform_code_object(code, transform)
                orig_code_map[out_code] = code
                break
            except exc.RestartAnalysis:
                log.debug("Restarting analysis ...")
                if attempt > 100:
                    unimplemented("100+ RestartAnalysis() calls")
            except exc.SkipFrame:
                log.debug(
                    f"Skipping frame {code.co_name} "
                    f"{code.co_filename} {code.co_firstlineno}"
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return None
        output_codes.add(out_code)

        log.log(
            torchdynamo_logging.CODE,
            format_bytecode(
                "ORIGINAL BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                code,
            ),
        )
        log.log(
            torchdynamo_logging.CODE,
            format_bytecode(
                "MODIFIED BYTECODE",
                code.co_name,
                code.co_filename,
                code.co_firstlineno,
                out_code,
            ),
        )

        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        check_fn = CheckFunctionManager(output.guards, locals, globals)
        guarded_code = GuardedCode(out_code, check_fn.check_fn)

        guard_str = "GUARDS:\n"
        guard_str += "\n".join([f" - {str(guard)}" for guard in sorted(output.guards)])
        log.log(torchdynamo_logging.CODE, guard_str)

        if guard_export_fn is not None:
            guard_export_fn(output.guards)

        return guarded_code
    except (
        Unsupported,
        TorchRuntimeError,
        BackendCompilerFailed,
        AssertionError,
    ) as e:
        exception_handler(e, code, frame)
        raise
    except Exception as e:
        exception_handler(e, code, frame)
        raise InternalTorchDynamoError()
def convert_frame(compiler_fn: typing.Callable, guard_export_fn=None):
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
inner_convert = convert_frame_assert(compiler_fn, guard_export_fn, one_graph=False)
def _convert_frame(frame: types.FrameType, cache_size: int):
counters["frames"]["total"] += 1
try:
result = inner_convert(frame, cache_size)
counters["frames"]["ok"] += 1
return result
except AssertionError:
if config.raise_on_assertion_error:
raise
except BackendCompilerFailed:
raise
except Exception:
pass
return None
_convert_frame._torchdynamo_orig_callable = compiler_fn
return _convert_frame
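
# Illustrative only: unlike convert_frame_assert, the callback returned here
# swallows most conversion errors (everything except BackendCompilerFailed and,
# depending on config, AssertionError) and returns None, so the caller can fall
# back to running the original Python frame.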
# TODO mlazos: add support for same args, or record them
def replay(filename):
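    """Debug helper: re-run a frame captured in an ExecutionRecord file (written
    by exception_handler) through _compile using the eager backend."""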
    from .optimizations.backends import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    init_logging()
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = {
        k: v for k, v in itertools.chain(record.globals.items(), globals().items())
    }

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            eager,
            one_graph=False,
            export=False,
            guard_export_fn=None,
            frame=None,
        )
    except Exception:
        pass
    finally:
        config.replay_record_enabled = original_replay_val
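
# Illustrative only (hypothetical path): after a failure with
# config.replay_record_enabled=True, the logged error message names a record
# file that can be re-executed in isolation, e.g.
#     replay("/tmp/torchdynamo_records/frame_0")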