Time how long compilation takes.

Also, still report the elapsed time even if an error is thrown midway
through compilation.
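
For reference, timing is enabled via the pre-existing time=True kwarg
on torch.jit.compile (a rough sketch, not verbatim from this patch;
assumes a CUDA device, since the timer is built on CUDA events):

    @torch.jit.compile(time=True)
    def f(x, y):
        return x * y

    # The first call traces and compiles; with this patch the compile
    # phase also prints something like "f compiling time: 1.23 ms" in
    # addition to the existing per-run timings.
    out = f(a, b)  # a, b: CUDA inputs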

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/torch/jit.py b/torch/jit.py
index 562a22f..402a999 100644
--- a/torch/jit.py
+++ b/torch/jit.py
@@ -314,6 +314,7 @@
         #
         # NB: Class variables are also accessible via self!
         cls.__params = ParameterList(list(params))
+        kwargs["time"] = time  # also want to pass onto ktrace
         cls.__ktrace_kwargs = kwargs
         cls.__enabled = enabled
         cls.__time = time
@@ -410,7 +411,7 @@
     #   - Whenever we want to run this trace, we call 'maybe_closure'.  This
     #     returns None if we don't have a complete trace yet, or the
     #     autograd closure to actually run the trace if we do.
-    def __init__(self, name, key, nderivs=1, optimize=True, volatile=False):
+    def __init__(self, name, key, nderivs=1, optimize=True, volatile=False, time=False):
         self.name = name
         self.key = key
         # TODO: Not convinced about this volatile special case...
@@ -419,6 +420,7 @@
         self.traces = []
         self.closure = None
         self.out_struct = None  # initialized when we call trace, checked thereafter
+        self.time = time
 
     # The signature here is a little goofy; it's a perf optimization.
     # Additionally, f is passed in as an argument (even though it is fixed as
@@ -461,18 +463,19 @@
             _dump_trace(self.name, pass_name, self.key, trace)
             torch._C._jit_pass_lint(trace)
 
-        _dump_trace(self.name, "init", self.key, complete_trace)
+        with _time(self.name, "compiling", self.time):
+            _dump_trace(self.name, "init", self.key, complete_trace)
 
-        # It's important to always run DCE, because backward can create a lot of unnecessary nodes
-        _run_pass(torch._C._jit_pass_dce, complete_trace)
-        if self.optimize:
-            _run_pass(torch._C._jit_pass_onnx, complete_trace)
-            _run_pass(torch._C._jit_pass_fuse, complete_trace)
+            # It's important to always run DCE, because backward can create a lot of unnecessary nodes
+            _run_pass(torch._C._jit_pass_dce, complete_trace)
+            if self.optimize:
+                _run_pass(torch._C._jit_pass_onnx, complete_trace)
+                _run_pass(torch._C._jit_pass_fuse, complete_trace)
 
-        _dump_trace(self.name, "final", self.key, complete_trace)
+            _dump_trace(self.name, "final", self.key, complete_trace)
 
-        self.closure = torch._C._jit_createAutogradClosure(complete_trace)
-        return self.closure
+            self.closure = torch._C._jit_createAutogradClosure(complete_trace)
+            return self.closure
 
 
 def vars_key(in_vars):
@@ -575,10 +578,12 @@
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
     stream.record_event(start)
-    yield
-    stream.record_event(end)
-    end.synchronize()
-    print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
+    try:
+        yield
+    finally:
+        stream.record_event(end)
+        end.synchronize()
+        print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
 
 
 if not torch._C._jit_init():
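
Note on the try/finally: the same pattern, reduced to a standalone
sketch (hypothetical _timed helper on a CPU clock; the real _time
above uses CUDA events and also takes a trace name and an enable
flag):

    from contextlib import contextmanager
    import time

    @contextmanager
    def _timed(label):
        start = time.time()
        try:
            yield
        finally:
            # Runs even when the body raises, so a failed compile
            # still reports how long it took.
            print("{} time: {} ms".format(label, (time.time() - start) * 1000))

    with _timed("f compiling"):
        raise RuntimeError("pass failed")  # the time is still printed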