Implement torch.utils.bottleneck (#5216)
* Implement torch.utils.bottleneck
This is a tool intended for initial, exploratory debugging of
bottlenecks in user scripts. Run it with
python -m torch.utils.bottleneck /path/to/source/script.py
* Refactor and address comments
* Fix tests
* Allow passing of args to the profiled script
* Replace Variable
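
For orientation, a run prints three sections in order (headings as emitted by
the summary templates in torch/utils/bottleneck/__main__.py below; all
measurements are elided from this sketch):

::

    --------------------------------------------------------------------------------
      Environment Summary
    --------------------------------------------------------------------------------
    ...
    --------------------------------------------------------------------------------
      cProfile output
    --------------------------------------------------------------------------------
    ...
    --------------------------------------------------------------------------------
      autograd profiler output (CPU mode)
    --------------------------------------------------------------------------------
    ...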
diff --git a/docs/source/bottleneck.rst b/docs/source/bottleneck.rst
new file mode 100644
index 0000000..f806d27
--- /dev/null
+++ b/docs/source/bottleneck.rst
@@ -0,0 +1,31 @@
+torch.utils.bottleneck
+======================
+
+.. currentmodule:: torch.utils.bottleneck
+
+`torch.utils.bottleneck` is a tool that can be used as an initial step for
+debugging bottlenecks in your program. It summarizes runs of your script with
+the Python profiler and PyTorch's autograd profiler.
+
+Run it on the command line with
+
+::
+
+ python -m torch.utils.bottleneck -- /path/to/source/script.py [args]
+
+where [args] are any number of arguments to `script.py`, or run
+``python -m torch.utils.bottleneck -h`` for more usage instructions.
+
+.. warning::
+ Because your script will be profiled, please ensure that it exits in a
+ finite amount of time.
+
+.. warning::
+ Due to the asynchronous nature of CUDA kernels, when running against
+ CUDA code, the cProfile output and CPU-mode autograd profilers may
+ not show correct timings. In this case, the CUDA-mode autograd
+ profiler is better at assigning blame to the relevant operator(s).
+
+For more complicated uses of the profilers (like in a multi-GPU case),
+please see https://docs.python.org/3/library/profile.html
+or :func:`torch.autograd.profiler.profile()`.
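
As a concrete illustration (`toy.py` is a hypothetical name; any self-contained
script works as a target), a profiled script and its invocation might look like:

::

    # toy.py -- hypothetical target script
    import torch

    x = torch.randn(64, 64, requires_grad=True)
    for _ in range(100):
        torch.mm(x, x).sum().backward()

::

    python -m torch.utils.bottleneck toy.py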
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 83a6c66..6c8aaf0 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -39,6 +39,7 @@
data
model_zoo
onnx
+ bottleneck
.. toctree::
:glob:
diff --git a/test/bottleneck/test.py b/test/bottleneck/test.py
new file mode 100644
index 0000000..30e2307
--- /dev/null
+++ b/test/bottleneck/test.py
@@ -0,0 +1,4 @@
+import torch
+
+x = torch.ones((3, 3), requires_grad=True)
+(3 * x).sum().backward()
diff --git a/test/bottleneck/test_args.py b/test/bottleneck/test_args.py
new file mode 100644
index 0000000..cddb6a6
--- /dev/null
+++ b/test/bottleneck/test_args.py
@@ -0,0 +1,13 @@
+import argparse
+import torch
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    # Required args. Raises error if they aren't passed.
+    parser.add_argument('--foo', help='foo', required=True)
+    parser.add_argument('--bar', help='bar', required=True)
+    _ = parser.parse_args()
+
+    x = torch.ones((3, 3), requires_grad=True)
+    (3 * x).sum().backward()
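
These fixture scripts are driven by TestBottleneck in test/test_utils.py below,
via invocations of the form (shown here with relative paths for brevity):

::

    python -m torch.utils.bottleneck bottleneck/test.py
    python -m torch.utils.bottleneck -- bottleneck/test_args.py --foo foo --bar bar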
diff --git a/test/bottleneck/test_cuda.py b/test/bottleneck/test_cuda.py
new file mode 100644
index 0000000..60d2f4b
--- /dev/null
+++ b/test/bottleneck/test_cuda.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.linear = nn.Linear(20, 20)
+
+    def forward(self, input):
+        out = self.linear(input[:, 10:30])
+        return out.sum()
+
+
+def main():
+    data = torch.randn(10, 50).cuda()
+    model = Model().cuda()
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+    for _ in range(10):
+        optimizer.zero_grad()
+        loss = model(data)
+        loss.backward()
+        optimizer.step()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/test/test_utils.py b/test/test_utils.py
index d2888f8..3853595 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,6 +1,7 @@
from __future__ import print_function
import sys
import os
+import re
import math
import shutil
import random
@@ -385,6 +386,105 @@
return input, target.sub(1)
+class TestBottleneck(TestCase):
+    def _run(self, command):
+        """Returns (return-code, stdout, stderr)"""
+        import subprocess
+        from common import PY3
+
+        p = subprocess.Popen(command, stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE, shell=True)
+        output, err = p.communicate()
+        rc = p.returncode
+        if PY3:
+            output = output.decode("ascii")
+            err = err.decode("ascii")
+        return (rc, output, err)
+
+    def _run_bottleneck(self, test_file, scriptargs=''):
+        curdir = os.path.dirname(os.path.abspath(__file__))
+        filepath = '{}/{}'.format(curdir, test_file)
+        if scriptargs != '':
+            mark = '-- '
+            scriptargs = ' {}'.format(scriptargs)
+        else:
+            mark = ''
+        rc, out, err = self._run(
+            'python -m torch.utils.bottleneck {}{}{}'.format(mark, filepath, scriptargs))
+        return rc, out, err
+
+    def _check_run_args(self):
+        # Check that this fails due to missing args
+        rc, out, err = self._run_bottleneck('bottleneck/test_args.py')
+        self.assertEqual(rc, 2, None, self._fail_msg('Missing args should error', out + err))
+
+        # This should succeed
+        rc, out, err = self._run_bottleneck('bottleneck/test_args.py', '--foo foo --bar bar')
+        self.assertEqual(rc, 0, None, self._fail_msg('Should pass args to script', out + err))
+
+    def _fail_msg(self, msg, output):
+        return '{}, output was:\n{}'.format(msg, output)
+
+    def _check_environment_summary(self, output):
+        results = re.search('Environment Summary', output)
+        self.assertIsNotNone(results, self._fail_msg('Should have Environment Summary', output))
+
+        # The version number should appear within five lines of the heading
+        results = re.search(r'Environment Summary.*(\n.*){,5}\nPyTorch \d+\.\d+', output)
+        self.assertIsNotNone(results, self._fail_msg('Should have PyTorch version', output))
+
+    def _check_cprof_summary(self, output):
+        results = re.search('cProfile output', output)
+        self.assertIsNotNone(results, self._fail_msg('Should have cProfile output', output))
+
+        # This assumes that after the cProfile output section we have
+        # the autograd profiler output
+        results = re.search(r'cProfile output.*(\n.*){6,50}\n.*autograd profiler output', output)
+        self.assertIsNotNone(results, self._fail_msg(
+            'Distance between cProfile and autograd prof out not in [6, 50] lines', output))
+
+    def _check_autograd_summary(self, output):
+        results = re.search('autograd profiler output', output)
+        self.assertIsNotNone(results, self._fail_msg('Should have autograd profiler output', output))
+
+        # This assumes the autograd profiler output is the last section
+        # of the output.
+        results = re.search(r'autograd profiler output.*(\n.*){6,100}', output)
+        self.assertIsNotNone(results, self._fail_msg(
+            'Distance between autograd prof output and end of output not in [6, 100] lines', output))
+
+    def _check_cuda(self, output):
+        results = re.search('CUDA mode', output)
+        if torch.cuda.is_available():
+            self.assertIsNotNone(results, self._fail_msg('Should tell users about CUDA', output))
+        else:
+            self.assertIsNone(results, self._fail_msg('Should not tell users about CUDA', output))
+
+    @unittest.skipIf(torch.cuda.is_available(), 'CPU-only test')
+    def test_cpu_only(self):
+        rc, out, err = self._run_bottleneck('bottleneck/test.py')
+        self.assertEqual(rc, 0, None, 'Run failed with\n{}'.format(err))
+
+        self._check_run_args()
+        self._check_environment_summary(out)
+        self._check_autograd_summary(out)
+        self._check_cprof_summary(out)
+        self._check_cuda(out)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'No CUDA')
+    def test_cuda(self):
+        rc, out, err = self._run_bottleneck('bottleneck/test_cuda.py')
+        self.assertEqual(rc, 0, None, 'Run failed with\n{}'.format(err))
+
+        self._check_run_args()
+        self._check_environment_summary(out)
+        self._check_autograd_summary(out)
+        self._check_cprof_summary(out)
+        self._check_cuda(out)
+
+
class TestONNXUtils(TestCase):
def test_prepare_onnx_paddings(self):
sizes = [2, 3, 4]
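
To run only the new tests, standard unittest class selection should work
(assuming test/test_utils.py dispatches through unittest.main, as the other
PyTorch test files do):

::

    python test/test_utils.py TestBottleneck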
diff --git a/torch/utils/bottleneck/__init__.py b/torch/utils/bottleneck/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/utils/bottleneck/__init__.py
diff --git a/torch/utils/bottleneck/__main__.py b/torch/utils/bottleneck/__main__.py
new file mode 100644
index 0000000..f821b72
--- /dev/null
+++ b/torch/utils/bottleneck/__main__.py
@@ -0,0 +1,280 @@
+import argparse
+import cProfile
+import pstats
+import subprocess
+import sys
+import os
+import re
+import contextlib
+
+import torch
+from torch.autograd import profiler
+
+PY3 = sys.version_info >= (3, 0)
+
+
+def run(command):
+ """Returns (return-code, stdout, stderr)"""
+ p = subprocess.Popen(command, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, shell=True)
+ output, err = p.communicate()
+ rc = p.returncode
+ if PY3:
+ output = output.decode("ascii")
+ err = err.decode("ascii")
+ return (rc, output, err)
+
+
+def redirect_argv(new_argv):
+    sys.argv[:] = new_argv[:]
+
+
+def check_running_cuda_version():
+    (rc, out, err) = run('nvcc --version')
+    if rc != 0:
+        return None
+    m = re.search(r'V(.*)$', out)
+    assert m is not None
+    return m.group(1)
+
+
+def check_pip_packages():
+    # People generally have pip installed as `pip` or `pip3`
+    def run_with_pip(pip):
+        rc, out, _ = run(pip + ' list --format=legacy | grep torch')
+        if rc == 0:
+            return out
+        return None
+
+    if not PY3:
+        return 'pip', run_with_pip('pip')
+
+    # Try to figure out if the user is running pip or pip3.
+    out2 = run_with_pip('pip')
+    out3 = run_with_pip('pip3')
+
+    num_pips = len([x for x in [out2, out3] if x is not None])
+    if num_pips == 0:
+        return 'pip', out2
+
+    if num_pips == 1:
+        if out2 is not None:
+            return 'pip', out2
+        return 'pip3', out3
+
+    # num_pips == 2. Return pip3 by default because it is most likely
+    # the one associated with Python 3.
+    return 'pip3', out3
+
+
+def compiled_with_cuda():
+    if torch.version.cuda:
+        return 'compiled w/ CUDA {}'.format(torch.version.cuda)
+    return 'not compiled w/ CUDA'
+
+
+env_summary = """
+--------------------------------------------------------------------------------
+ Environment Summary
+--------------------------------------------------------------------------------
+PyTorch {pytorch_version}{debug_str} {cuda_compiled}
+Running with Python {py_version} and {cuda_runtime}
+
+`{pip_version} list` truncated output:
+{pip_list_output}
+""".strip()
+
+
+def run_env_analysis():
+    print('Running environment analysis...')
+
+    debug_str = ''
+    if torch.version.debug:
+        debug_str = ' DEBUG'
+
+    cuda_avail = ''
+    if torch.cuda.is_available():
+        cuda = check_running_cuda_version()
+        if cuda is not None:
+            cuda_avail = 'CUDA ' + cuda
+    else:
+        cuda_avail = 'CUDA unavailable'
+
+    pip_version, pip_list_output = check_pip_packages()
+    if pip_list_output is None:
+        pip_list_output = 'Unable to fetch'
+
+    result = {
+        'debug_str': debug_str,
+        'pytorch_version': torch.__version__,
+        'cuda_compiled': compiled_with_cuda(),
+        'py_version': '{}.{}'.format(sys.version_info[0], sys.version_info[1]),
+        'cuda_runtime': cuda_avail,
+        'pip_version': pip_version,
+        'pip_list_output': pip_list_output,
+    }
+
+    return env_summary.format(**result)
+
+
+def run_cprofile(code, globs):
+    print('Running your script with cProfile...')
+    prof = cProfile.Profile()
+    prof.enable()
+    exec(code, globs, None)
+    prof.disable()
+    return prof
+
+
+cprof_summary = """
+--------------------------------------------------------------------------------
+ cProfile output
+--------------------------------------------------------------------------------
+""".strip()
+
+
+def print_cprofile_summary(prof, sortby='tottime', topk=15):
+    print(cprof_summary)
+
+    cprofile_stats = pstats.Stats(prof).sort_stats(sortby)
+    cprofile_stats.print_stats(topk)
+
+
+def run_autograd_prof(code, globs):
+    def run_prof(use_cuda=False):
+        with profiler.profile(use_cuda=use_cuda) as prof:
+            exec(code, globs, None)
+        return prof
+
+    print('Running your script with the autograd profiler...')
+    result = [run_prof(use_cuda=False)]
+    if torch.cuda.is_available():
+        result.append(run_prof(use_cuda=True))
+    else:
+        result.append(None)
+
+    return result
+
+
+autograd_prof_summary = """
+--------------------------------------------------------------------------------
+ autograd profiler output ({mode} mode)
+--------------------------------------------------------------------------------
+ {description}
+{cuda_warning}
+{output}
+""".strip()
+
+
+def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15):
+    valid_sortby = ['cpu_time', 'cuda_time', 'cpu_time_total', 'cuda_time_total', 'count']
+    if sortby not in valid_sortby:
+        warn = ('WARNING: invalid sorting option for autograd profiler results: {}\n'
+                'Expected one of `cpu_time`, `cuda_time`, `cpu_time_total`, '
+                '`cuda_time_total`, or `count`. Defaulting to `cpu_time`.')
+        print(warn.format(sortby))
+        sortby = 'cpu_time'
+
+    if mode == 'CUDA':
+        cuda_warning = ('\n\tBecause the autograd profiler uses the CUDA event API,\n'
+                        '\tthe CUDA time column reports approximately max(cuda_time, cpu_time).\n'
+                        '\tPlease ignore this output if your code does not use CUDA.\n')
+    else:
+        cuda_warning = ''
+
+    sorted_events = sorted(prof.function_events,
+                           key=lambda x: getattr(x, sortby), reverse=True)
+    topk_events = sorted_events[:topk]
+
+    result = {
+        'mode': mode,
+        'description': 'top {} events sorted by {}'.format(topk, sortby),
+        'output': torch.autograd.profiler.build_table(topk_events),
+        'cuda_warning': cuda_warning,
+    }
+
+    print(autograd_prof_summary.format(**result))
+
+
+descript = """
+`bottleneck` is a tool that can be used as an initial step for debugging
+bottlenecks in your program.
+
+It summarizes runs of your script with the Python profiler and PyTorch's
+autograd profiler. Because your script will be profiled, please ensure that it
+exits in a finite amount of time.
+
+For more complicated uses of the profilers, please see
+https://docs.python.org/3/library/profile.html and
+http://pytorch.org/docs/master/autograd.html#profiler.
+""".strip()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description=descript)
+    parser.add_argument('scriptfile', type=str,
+                        help='Path to the script to be run. '
+                             'Usually run with `python path/to/script`.')
+    parser.add_argument('args', type=str, nargs='*',
+                        help='Command-line arguments to be passed to the script.')
+    return parser.parse_args()
+
+
+def cpu_time_total(autograd_prof):
+    return sum([event.cpu_time_total for event in autograd_prof.function_events])
+
+
+def main():
+    args = parse_args()
+
+    # Customizable constants.
+    scriptfile = args.scriptfile
+    scriptargs = [] if args.args is None else args.args
+    scriptargs.insert(0, scriptfile)
+    cprofile_sortby = 'tottime'
+    cprofile_topk = 15
+    autograd_prof_sortby = 'cpu_time_total'
+    autograd_prof_topk = 15
+
+    redirect_argv(scriptargs)
+
+    sys.path.insert(0, os.path.dirname(scriptfile))
+    with open(scriptfile, 'rb') as stream:
+        code = compile(stream.read(), scriptfile, 'exec')
+    globs = {
+        '__file__': scriptfile,
+        '__name__': '__main__',
+        '__package__': None,
+        '__cached__': None,
+    }
+
+    print(descript)
+
+    env_summary = run_env_analysis()
+
+    if torch.cuda.is_available():
+        torch.cuda.init()
+    cprofile_prof = run_cprofile(code, globs)
+    autograd_prof_cpu, autograd_prof_cuda = run_autograd_prof(code, globs)
+
+    print(env_summary)
+    print_cprofile_summary(cprofile_prof, cprofile_sortby, cprofile_topk)
+
+    if not torch.cuda.is_available():
+        print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
+        return
+
+    # Print the result of the CPU-mode autograd profiler in addition to the
+    # CUDA-mode one if their execution times differ by more than 5%.
+    cuda_prof_exec_time = cpu_time_total(autograd_prof_cuda)
+    cpu_prof_exec_time = cpu_time_total(autograd_prof_cpu)
+    pct_diff = (cuda_prof_exec_time - cpu_prof_exec_time) / cuda_prof_exec_time
+    if abs(pct_diff) > 0.05:
+        print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk)
+    print_autograd_prof_summary(autograd_prof_cuda, 'CUDA', autograd_prof_sortby, autograd_prof_topk)
+
+
+if __name__ == '__main__':
+    main()
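
When `bottleneck` is too coarse-grained, the autograd profiler it wraps can be
driven directly; a minimal sketch using only the API exercised above:

::

    import torch
    from torch.autograd import profiler

    x = torch.ones((3, 3), requires_grad=True)
    with profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
        (3 * x).sum().backward()
    print(prof)  # the recorded events render as a table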