Time how long compilation takes.
Also, still report the elapsed time even if we throw an error midway
through compilation.
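The key to reporting a time even when compilation raises is to put the
timing epilogue in a finally block. A minimal CPU-only sketch of the
pattern (using time.time() rather than the CUDA events in the diff
below; names mirror the _time helper, but this is an illustration, not
the patched code):

    import time
    from contextlib import contextmanager

    @contextmanager
    def _time(trace_name, name, enabled=True):
        if not enabled:
            yield
            return
        start = time.time()
        try:
            yield
        finally:
            # Runs even if the timed block raised an exception.
            print("{} {} time: {} ms".format(
                trace_name, name, (time.time() - start) * 1000))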
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/torch/jit.py b/torch/jit.py
index 562a22f..402a999 100644
--- a/torch/jit.py
+++ b/torch/jit.py
@@ -314,6 +314,7 @@
         #
         # NB: Class variables are also accessible via self!
         cls.__params = ParameterList(list(params))
+        kwargs["time"] = time  # also want to pass on to ktrace
         cls.__ktrace_kwargs = kwargs
         cls.__enabled = enabled
         cls.__time = time
@@ -410,7 +411,7 @@
     # - Whenever we want to run this trace, we call 'maybe_closure'.  This
     #   returns None if we don't have a complete trace yet, or the
     #   autograd closure to actually run the trace if we do.
-    def __init__(self, name, key, nderivs=1, optimize=True, volatile=False):
+    def __init__(self, name, key, nderivs=1, optimize=True, volatile=False, time=False):
         self.name = name
         self.key = key
         # TODO: Not convinced about this volatile special case...
@@ -419,6 +420,7 @@
         self.traces = []
         self.closure = None
         self.out_struct = None  # initialized when we call trace, checked thereafter
+        self.time = time
 
     # The signature here is a little goofy; it's a perf optimization.
     # Additionally, f is passed in as an argument (even though it is fixed as
@@ -461,18 +463,19 @@
             _dump_trace(self.name, pass_name, self.key, trace)
             torch._C._jit_pass_lint(trace)
 
-        _dump_trace(self.name, "init", self.key, complete_trace)
+        with _time(self.name, "compiling", self.time):
+            _dump_trace(self.name, "init", self.key, complete_trace)
 
-        # It's important to always run DCE, because backward can create a lot of unnecessary nodes
-        _run_pass(torch._C._jit_pass_dce, complete_trace)
-        if self.optimize:
-            _run_pass(torch._C._jit_pass_onnx, complete_trace)
-            _run_pass(torch._C._jit_pass_fuse, complete_trace)
+            # It's important to always run DCE, because backward can create a lot of unnecessary nodes
+            _run_pass(torch._C._jit_pass_dce, complete_trace)
+            if self.optimize:
+                _run_pass(torch._C._jit_pass_onnx, complete_trace)
+                _run_pass(torch._C._jit_pass_fuse, complete_trace)
 
-        _dump_trace(self.name, "final", self.key, complete_trace)
+            _dump_trace(self.name, "final", self.key, complete_trace)
 
-        self.closure = torch._C._jit_createAutogradClosure(complete_trace)
-        return self.closure
+            self.closure = torch._C._jit_createAutogradClosure(complete_trace)
+            return self.closure
 
 
 def vars_key(in_vars):
@@ -575,10 +578,12 @@
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
     stream.record_event(start)
-    yield
-    stream.record_event(end)
-    end.synchronize()
-    print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
+    try:
+        yield
+    finally:
+        stream.record_event(end)
+        end.synchronize()
+        print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
 
 
 if not torch._C._jit_init():
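Assuming the time kwarg threads through the compile decorator the way
the first hunk suggests (kwargs["time"] is passed on to ktrace),
enabling the timing would look roughly like this hypothetical usage:

    @torch.jit.compile(time=True)
    def f(x):
        return x * x

    # The first call traces and compiles f; with time=True the compile
    # phase prints a "<name> compiling time: ... ms" line.

The helper times with CUDA events rather than wall-clock time because
kernels launch asynchronously; end.synchronize() waits until the end
event has actually completed on the stream, which is what makes
start.elapsed_time(end) meaningful.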