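"""
Debug utilities for torchinductor: drawing the scheduler graph as an FX graph
and the DebugContext / DebugFormatter machinery behind config.trace.
"""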
import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
from typing import Any, List
from functorch.compile import draw_graph, get_graph_being_compiled
import torch
from torch import fx as fx
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from . import config, ir
from .scheduler import (
BaseSchedulerNode,
ExternKernelSchedulerNode,
FusedSchedulerNode,
NopKernelSchedulerNode,
OutputNode,
SchedulerNode,
TemplateSchedulerNode,
)
from .utils import dynamo_config, dynamo_debug_utils, dynamo_utils
from .virtualized import V
log = logging.getLogger(__name__)
@functools.lru_cache(None)
def has_dot():
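    """Return True if the graphviz `dot` binary is available on PATH."""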
try:
subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
return True
except subprocess.SubprocessError:
return False
def draw_buffers(nodes, print_graph=False, fname=None):
"""
    Draw the graph of scheduler nodes to fname.svg.

    nodes is a list of SchedulerNode objects.
"""
if not has_dot():
log.warning("draw_buffers() requires `graphviz` package")
return
if fname is None:
fname = get_graph_being_compiled()
graph = create_fx_from_snodes(nodes)
for node in graph.nodes:
if "fusion_meta" not in node.meta:
continue
group = node.meta["fusion_meta"].group
if isinstance(group, tuple):
group = group[1]
        # gather metadata
dtype = None
if isinstance(node, ir.ComputedBuffer):
dtype = node.data.dtype
metadata = TensorMetadata(group, dtype, None, None, None, None, None)
node.meta["tensor_meta"] = metadata
if print_graph:
print(graph)
gm = GraphModule({}, graph)
legalize_graph(gm)
gm.graph.lint()
draw_graph(gm, fname, clear_meta=False)
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
"""
    Creates an FX Graph from a list of SchedulerNode objects.
"""
def get_fake_func(name):
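        """Return a placeholder callable named `name` to serve as an FX call_function target."""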
def func1(*args):
return 0
func1.__name__ = name
return func1
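    # FusionMeta records, for each drawn node, its fusion group, the scheduler
    # node(s) it represents, and the kind of kernel ("extern", "compute", ...).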
FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])
    func_dict = {
        s: get_fake_func(s)
        for s in ["extern", "template", "nop", "compute", "fused"]
    }
buf_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None
outputs = []
group: Any = None
# create call_function node for each Buffer and Kernel
for snode in snodes:
if isinstance(snode, ExternKernelSchedulerNode):
node_type = "extern"
group = node_type
elif isinstance(snode, TemplateSchedulerNode):
node_type = "template"
group = node_type
elif isinstance(snode, NopKernelSchedulerNode):
node_type = "nop"
group = node_type
elif isinstance(snode, SchedulerNode):
node_type = "compute"
group = snode.group
elif isinstance(snode, FusedSchedulerNode):
node_type = "fused"
group = snode.group
else:
raise RuntimeError("Unknown node type")
node_func = func_dict[node_type]
fx_node = graph.call_function(node_func, args=(), kwargs=None)
def in_output(snode):
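            """Return True if `snode` (or any member of a fused group) feeds a graph output."""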
if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)
if in_output(snode):
outputs.append(fx_node)
name = snode.get_name()
fx_node.name = name
fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)
if isinstance(snode, FusedSchedulerNode):
for x in snode.snodes:
buf_to_fx_node[x.get_name()] = fx_node
buf_to_fx_node[name] = fx_node
if first_node is None:
first_node = fx_node
# create edges between nodes
for snode in snodes:
name = snode.get_name()
deps = snode.read_writes.reads
fx_node = buf_to_fx_node[name]
new_args = []
for dep in deps:
if dep.name in buf_to_fx_node:
dep_node = buf_to_fx_node[dep.name]
else:
with graph.inserting_before(first_node):
dep_node = graph.placeholder(dep.name)
buf_to_fx_node[dep.name] = dep_node
new_args.append(dep_node)
fx_node.args = tuple(new_args)
graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
return graph
class DebugContext:
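    """
    Per-compilation debug handler for torchinductor.

    When config.trace is enabled it creates a debug directory and collects
    artifacts there: captured logs, cProfile data, IR dumps, and graph diagrams.
    A minimal usage sketch (the compile function name is illustrative):

        @DebugContext.wrap
        def compile_fn(gm, example_inputs):
            V.debug.fx_graph(gm, example_inputs)
            ...
    """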
_counter = itertools.count()
@staticmethod
def wrap(fn):
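        """Run `fn` inside a fresh DebugContext, then hand it to dynamo's compiler-debug wrapper."""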
@functools.wraps(fn)
def inner(*args, **kwargs):
with DebugContext():
return fn(*args, **kwargs)
return dynamo_debug_utils.wrap_compiler_debug(inner, compiler_name="inductor")
@staticmethod
def create_debug_dir():
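        """Create and return a fresh, unused debug directory under the dynamo debug root."""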
for n in DebugContext._counter:
dirname = os.path.join(
dynamo_utils.get_debug_dir(),
"torchinductor",
f"debug.{os.getpid()}.{n}",
)
if not os.path.exists(dirname):
os.makedirs(dirname)
return dirname
def __init__(self):
self._prof = None
self._path = None
self._stack = contextlib.ExitStack()
def rename(self, new_path: str):
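        """Move the debug directory to `new_path`; best effort, skipped if the OS refuses."""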
if not self._path:
return
assert new_path.endswith(".debug"), new_path
if os.path.exists(new_path):
shutil.rmtree(new_path)
try:
os.rename(self._path, new_path)
self._path = new_path
except OSError:
            # other OSes might have trouble renaming a dir with open files
pass
def fopen(self, filename):
assert self._path
return open(os.path.join(self._path, filename), "w")
def filename(self, suffix):
return os.path.join(self._path, suffix)
def upload_tar(self):
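        """Tar up the debug directory and pass the archive to the config.trace.upload_tar hook."""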
if config.trace.upload_tar is not None:
import tarfile
assert self._path
tar_file = os.path.join(
self._path, f"{os.path.basename(self._path)}.tar.gz"
)
with tarfile.open(tar_file, "w:gz") as tar:
tar.add(self._path, arcname=os.path.basename(self._path))
config.trace.upload_tar(tar_file)
def __enter__(self):
log = logging.getLogger(config.inductor_import)
if not log.handlers:
dynamo_utils.init_logging()
if config.debug:
dynamo_config.log_level = logging.DEBUG
self._stack.enter_context(V.set_debug_handler(self))
if not config.trace.enabled:
return
self._path = self.create_debug_dir()
if config.trace.debug_log:
self._setup_log_capture("debug.log", logging.DEBUG)
if config.trace.info_log:
self._setup_log_capture("info.log", logging.INFO)
if config.trace.compile_profile:
self._prof = cProfile.Profile()
self._prof.enable()
def _setup_log_capture(self, filename, level):
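        """Mirror inductor log records at `level` or above into `filename` in the debug directory."""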
log = logging.getLogger(config.inductor_import)
fd = self._stack.enter_context(self.fopen(filename))
ch = logging.StreamHandler(fd)
ch.setLevel(level)
ch.setFormatter(
logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
)
log.addHandler(ch)
log.setLevel(min(log.level, level))
self._stack.callback(log.removeHandler, ch)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._prof:
self._prof.disable()
self._save_profile_data()
if self._path:
self.upload_tar()
log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
self._stack.close()
def _save_profile_data(self):
self._prof.dump_stats(self.filename("compile.prof"))
with self.fopen("compile.stats") as fd:
stats = pstats.Stats(self._prof, stream=fd)
stats.strip_dirs()
stats.sort_stats("cumtime")
stats.print_stats(100)
stats.sort_stats("tottime")
stats.print_stats(100)
def __getattr__(self, name):
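        """
        Route debug calls (e.g. self.fx_graph) to DebugFormatter when the matching
        config.trace option is enabled; otherwise return a no-op callable.
        """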
if config.trace.enabled and getattr(config.trace, name):
try:
return getattr(DebugFormatter(self), name)
except Exception:
log.warning("Ignoring exception in debug code", exc_info=True)
else:
def ignored(*args, **kwargs):
pass
return ignored
SchedulerNodeList = List[Any]
class DebugFormatter:
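    """
    Writes the individual debug artifacts (FX graphs, IR dumps, graph diagrams,
    output code) into the DebugContext's debug directory.
    """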
def __init__(self, handler):
self.fopen = handler.fopen
self.filename = handler.filename
self.handler = handler
def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
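        """Save a runnable repro script and a readable print-out of the incoming FX graph."""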
with self.fopen("fx_graph_runnable.py") as fd:
dynamo_debug_utils.save_graph_repro(fd, gm, inputs, "inductor")
with self.fopen("fx_graph_readable.py") as fd:
fd.write(gm.print_readable(print_output=False))
def ir_pre_fusion(self, nodes: SchedulerNodeList):
self._write_ir("ir_pre_fusion.txt", nodes)
def ir_post_fusion(self, nodes: SchedulerNodeList):
self._write_ir("ir_post_fusion.txt", nodes)
def _write_ir(self, filename: str, nodes: SchedulerNodeList):
with self.fopen(filename) as fd:
for node in nodes:
fd.write(node.debug_str())
fd.write("\n\n\n")
def graph_diagram(self, nodes: SchedulerNodeList):
draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))
def output_code(self, filename):
shutil.copy(filename, self.filename("output_code.py"))