Make JIT tracing a thread-local property (#9414)

Summary:
As in the title: tracing state now lives in a single thread-local pointer instead of being attached to individual Variables. This lets us simplify a lot of code. `isTracing()` no longer has to inspect its tensor arguments, the per-Variable `ValueTracingState`/`VariableFlags` machinery (`tracer_state.h`, `variable_flags.h`, `auto_unique_ptr.h`) goes away, and on the Python side the `torch.jit._tracing` flag and the `symbolic_override_*_based` variants collapse into `torch._C._get_tracing_state()` and a single `symbolic_override`.

Depends on #9363, so please review only the last commit.
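For reviewers, here is a condensed sketch of the new mechanism, pulled from torch/csrc/jit/tracer.{h,cpp} in the diff below (the full TracingState definition and the enter/exit helpers are elided):

```cpp
// Condensed sketch of the thread-local tracing state introduced by this patch
// (simplified from torch/csrc/jit/tracer.{h,cpp} in the diff below).
#include <memory>

struct TracingState {
  // In the real struct: the Graph being built plus a WeakTensor -> Value*
  // value_map that replaces the old per-Variable ValueTracingState lists.
};

namespace detail {
// One tracing state per thread; nullptr means "this thread is not tracing".
thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail

const std::shared_ptr<TracingState>& getTracingState() {
  return detail::tracing_state;
}

void setTracingState(std::shared_ptr<TracingState> state) {
  detail::tracing_state = std::move(state);
}

// The check is now argument-free and cheap, so generated code and the graph
// executor can call it unconditionally instead of scanning their tensor inputs.
inline bool isTracing() {
  return static_cast<bool>(getTracingState());
}
```

`tracer::enter()` installs a fresh state, while `tracer::exit()` and the new `tracer::abandon()` reset it to nullptr, so a trace that throws no longer leaves stale tracing state behind for subsequent calls on the same thread.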

zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9414

Reviewed By: zdevito

Differential Revision: D8836496

Pulled By: apaszke

fbshipit-source-id: 9b3c3d1f001a9dc522f8478abc005b6b86cfa3e3
diff --git a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
index 2a87a36..4c9e2e2 100644
--- a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 3)) {
-  %2 : Double(3, 4) = aten::neg(%0)
-  %4 : Double(3, 3) = aten::mm(%2, %1)
+  %2 : Double(3, 4) = aten::neg(%0), scope: TracedModule
+  %4 : Double(3, 3) = aten::mm(%2, %1), scope: TracedModule
   return (%4);
 }
diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
index 925bbf1..d39acaf 100644
--- a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
@@ -1,8 +1,8 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 5)
       %2 : Double(5, 7)) {
-  %4 : Double(3, 5) = aten::mm(%0, %1)
-  %6 : Double(3, 7) = aten::mm(%4, %2)
-  %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6)
+  %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+  %6 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/PythonModule[mod]
+  %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6), scope: TracedModule
   return (%7);
 }
diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
index 4de15a5..ea847d6 100644
--- a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)) {
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
-  %3 : Double(3, 3) = aten::mm(%0, %1)
+  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: PythonMod
+  %3 : Double(3, 3) = aten::mm(%0, %1), scope: PythonMod
   %4 : Double(3, 3) = aten::add[other={1}, alpha={1}](%3)
   return (%4);
 }
diff --git a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
index adaab38..6bf57b8 100644
--- a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 5)) {
-  %3 : Double(3, 5) = aten::mm(%0, %1)
-  %5 : Double(3, 5) = aten::neg(%3)
+  %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+  %5 : Double(3, 5) = aten::neg(%3), scope: TracedModule/ScriptModule
   return (%5);
 }
diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
index cffec80..dc8b494 100644
--- a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
@@ -1,5 +1,5 @@
 graph(%0 : Double(3, 4)) {
-  %2 : Double(3, 4) = aten::neg(%0)
+  %2 : Double(3, 4) = aten::neg(%0), scope: ScriptModule
   %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2)
   return (%3);
 }
diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
index d446882..fc7039b 100644
--- a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)) {
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
-  %4 : Double(3, 3) = aten::mm(%0, %1)
+  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
+  %4 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod
   %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4)
   return (%5);
 }
diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect
index c249ddc..21b14a2 100644
--- a/test/expect/TestScript.test_call_script_module_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_script_module_from_traced_module.expect
@@ -1,8 +1,8 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 5)
       %2 : Double(5, 7)) {
-  %4 : Double(3, 5) = aten::mm(%0, %1)
-  %7 : Double(3, 7) = aten::mm(%4, %2)
-  %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7)
+  %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+  %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/ScriptMod[mod]
+  %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule
   return (%8);
 }
diff --git a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
index 4e25a85..f45c3f1 100644
--- a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 5)) {
-  %3 : Double(3, 5) = aten::mm(%0, %1)
-  %5 : Double(3, 4) = aten::neg(%3)
+  %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+  %5 : Double(3, 4) = aten::neg(%3), scope: TracedModule/traced_fn
   return (%5);
 }
diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
index cffec80..ed737f4 100644
--- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
@@ -1,5 +1,5 @@
 graph(%0 : Double(3, 4)) {
-  %2 : Double(3, 4) = aten::neg(%0)
+  %2 : Double(3, 4) = aten::neg(%0), scope: traced_fn1
   %3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2)
   return (%3);
 }
diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
index d446882..3fac45f 100644
--- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
 graph(%0 : Double(3, 4)) {
-  %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
-  %4 : Double(3, 3) = aten::mm(%0, %1)
+  %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
+  %4 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule]
   %5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4)
   return (%5);
 }
diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
index c249ddc..471f9f1 100644
--- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
@@ -1,8 +1,8 @@
 graph(%0 : Double(3, 4)
       %1 : Double(4, 5)
       %2 : Double(5, 7)) {
-  %4 : Double(3, 5) = aten::mm(%0, %1)
-  %7 : Double(3, 7) = aten::mm(%4, %2)
-  %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7)
+  %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+  %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
+  %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule
   return (%8);
 }
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index cbde1fa..dd60188 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -432,7 +432,7 @@
             return g.op('Sum', x, y[0], y[1]), (
                 g.op('Neg', x), g.op('Neg', y[0]))
 
-        @torch.onnx.symbolic_override_first_arg_based(symb)
+        @torch.onnx.symbolic_override(symb)
         def foo(x, y):
             return x + y[0] + y[1], (-x, -y[0])
 
diff --git a/test/test_jit.py b/test/test_jit.py
index 9a1661e..9a12ca9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -333,9 +333,9 @@
 
         def f(x, y):
             out = x + y
-            with torch.jit.scope('Foo', out):
+            with torch.jit.scope('Foo'):
                 out = x * out
-                with torch.jit.scope('Bar', out):
+                with torch.jit.scope('Bar'):
                     out = torch.tanh(out)
                 out = torch.sigmoid(out)
             return out
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 679e3c9..97130b1 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -54,6 +54,13 @@
     's_native_mul': 'mul',
     'th_addmm': 'addmm',
     's_native_addmm': 'addmm',
+    'zero': 'zeros_like',
+    'fill': 'full_like',
+}
+
+# (declaration name, argument name) -> attribute name
+RENAME_ATTRIBUTES = {
+    ('fill_', 'value'): 'fill_value'
 }
 
 # These functions are not worth profiling because they are very cheap and may
@@ -126,7 +133,7 @@
 
 PRE_RECORD_TRACE = CodeTemplate("""\
 jit::tracer::PreTraceInfo trace_info;
-if (jit::tracer::isTracing( ${tensor_args} )) {
+if (jit::tracer::isTracing()) {
   trace_info = jit::tracer::preRecordTrace( jit::aten::${trace_name}, ${trace_inputs} );
   if (!jit::tracer::ArgumentStash::empty()) {
     ${record_positional_attributes}
@@ -138,13 +145,13 @@
 """)
 
 POST_RECORD_TRACE = CodeTemplate("""\
-if (trace_info.state != nullptr) {
+if (jit::tracer::isTracing()) {
   jit::tracer::postRecordTrace( trace_info,  ${trace_outputs} );
 }
 """)
 
 RECORD_ATTRIBUTE = CodeTemplate("""\
-setattr(trace_info.n, jit::attr::${name}, ${name});""")
+setattr(trace_info.n, jit::attr::${attr_name}, ${name});""")
 
 RECORD_POSITIONAL_ATTRIBUTE = CodeTemplate("""\
 setposattr(trace_info.n, ${i}, "${name}", ${name});""")
@@ -417,7 +424,8 @@
         for arg in declaration['arguments']:
             if arg['simple_type'] in {'Tensor', 'TensorList'}:
                 continue
-            local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(name=arg['name']))
+            attr_name = RENAME_ATTRIBUTES.get((declaration['name'], arg['name']), arg['name'])
+            local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(attr_name=attr_name, name=arg['name']))
 
         local['record_positional_attributes'] = []
         for i, arg in enumerate(declaration['arguments']):
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 0ca293f..03e0f64 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -39,10 +39,14 @@
 namespace torch { namespace autograd {
 // Helper methods for working with Attributes (torch/csrc/jit/attributes.h)
 
+at::Tensor maybeUnwrapVar(const at::Tensor& t) {
+  return t.is_variable() ? Variable(t).data() : t;
+}
+
 // The overloaded accessors are convenient for the generated code (since we
 // don't want to make the codegen do the dispatch manually)
 static void setattr(jit::Node* n, jit::Symbol name, int64_t v)             { n->i_(name, v); }
-static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v)   { n->t_(name, v.toTensor()); }
+static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v)   { n->t_(name, maybeUnwrapVar(v.toTensor())); }
 static void setattr(jit::Node* n, jit::Symbol name, SparseTensorRef s)     { n->t_(name, s.tref); }
 static void setattr(jit::Node* n, jit::Symbol name, const at::IntList& v)  { n->is_(name, v); }
 static void setattr(jit::Node* n, jit::Symbol name, bool v)                { n->i_(name, v); }
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 8e8c876..9e44054 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -140,7 +140,7 @@
   ParsedArgs<3> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    if (jit::tracer::isTracing(self_)) {
+    if (jit::tracer::isTracing()) {
       return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
     } else {
       return wrap(self_.size(r.toInt64(0)));
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index c6fe078..02fd042 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -159,10 +159,8 @@
   ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
   ${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/tracer_state.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/type.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/variable_flags.cpp
   ${TORCH_SRC_DIR}/csrc/onnx/onnx.cpp
   ${TORCH_SRC_DIR}/csrc/onnx/onnx.npb.cpp
   ${TORCH_SRC_DIR}/csrc/torch.cpp
diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp
index 0708eba..95c7c64 100644
--- a/torch/csrc/Size.cpp
+++ b/torch/csrc/Size.cpp
@@ -14,7 +14,7 @@
 
 PyObject * THPSize_New(const torch::autograd::Variable& var)
 {
-  if (!torch::jit::tracer::isTracing(var)) {
+  if (!torch::jit::tracer::isTracing()) {
     auto sizes = var.sizes();
     return THPSize_NewFromSizes(var.dim(), sizes.data());
   }
@@ -38,10 +38,10 @@
   return self.release();
 }
 
-static bool isTracedVar(PyObject *item) {
+static bool isTracedZeroDimVar(PyObject *item) {
   if (!THPVariable_Check(item)) return false;
   auto & var = reinterpret_cast<THPVariable*>(item)->cdata;
-  return torch::jit::tracer::isTracing(var);
+  return var.dim() == 0 && torch::jit::tracer::getValueTrace(var);
 }
 
 static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -50,10 +50,10 @@
   if (self) {
     for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
       PyObject *item = PyTuple_GET_ITEM(self.get(), i);
-      if (isTracedVar(item)) {
+      if (THPUtils_checkLong(item)) {
         continue;
       }
-      if (THPUtils_checkLong(item)) {
+      if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
         continue;
       }
       // item.__index__() works with 0-dim tensors and tensors with one element
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index e4b6bef..b65a706 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -8,7 +8,6 @@
 #include "torch/csrc/autograd/saved_variable.h"
 #include "torch/csrc/autograd/type_and_shape.h"
 #include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
 #include "torch/csrc/utils/python_stub.h"
 #include "torch/csrc/utils/variadic.h"
 
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 5494334..8c20c6c 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -550,7 +550,7 @@
 }
 
 static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
-  if (tracer::isTracingVar(input_vars)) {
+  if (tracer::isTracing()) {
     std::ostringstream oss;
     oss << "Attempted to trace " << name;
     oss << ", but tracing of legacy functions is not supported";
@@ -562,7 +562,7 @@
     PyObject* op_obj,
     PyObject *input_objects,
     const variable_list& input_vars) {
-  if (!tracer::isTracingVar(input_vars)) {
+  if (!jit::tracer::isTracing()) {
     return jit::tracer::PreTraceInfo();
   }
 
@@ -598,7 +598,7 @@
     const variable_list& input_vars,
     PyObject *output_objects,
     bool is_inplace) {
-  if (!trace_info.state) {
+  if (!jit::tracer::isTracing()) {
     return;
   }
 
@@ -612,7 +612,6 @@
 
   jit::tracer::postRecordTrace(trace_info, output_vars);
 
-  auto state_lock = trace_info.state->lock();
   trace_info.n->i_(attr::inplace, is_inplace);
 
 }
@@ -640,11 +639,6 @@
 
   bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
   _wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
-  // NOTE: _trace_post_record has to run before _save_variables, because we need
-  // to assign traces to outputs before we convert them to SavedVariables.
-  // On the other hand, it needs to go after _mark_non_differentiable, because
-  // it might be wraping backwards in Evals, and _mark_non_differentiable uses
-  // grad_fn pointer equality for error checking.
   _trace_post_record(trace_info, op_obj, unpacked.input_vars, outputs, is_inplace);
   if (is_executable) {
     _save_variables(grad_fn);
@@ -715,10 +709,6 @@
 
   // Record input nodes if tracing
   auto trace_info = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
-  if (trace_info.state) {
-    // TODO: ezyang suggests this is unused and can be removed
-    ctx->is_traced = true;
-  }
 
   // Initialize backward function (and ctx)
   bool is_executable = input_info.is_executable;
@@ -1009,7 +999,6 @@
   {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr},
   {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr},
   {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
-  {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, nullptr, nullptr, nullptr},
   {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
   {nullptr}
 };
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 7bc7548..bdbca10 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -90,7 +90,6 @@
     // For each input, true if the input is a THPVariable
     std::vector<bool> is_variable_input;
     char has_freed_buffers;
-    char is_traced;
 
     // The C++ wrapper for this Python function.
     // See a comment in THPFunction_asFunction for details about this field.
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index a93f4e6..607f3b7 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -16,7 +16,6 @@
 #include "torch/csrc/autograd/generated/VariableType.h"
 #include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
-#include "torch/csrc/jit/tracer_state.h"
 #include "torch/csrc/tensor/python_tensor.h"
 #include "torch/csrc/utils/auto_gil.h"
 #include "torch/csrc/utils/cuda_lazy_init.h"
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp
index 889f456..0f93c13 100644
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -3,7 +3,6 @@
 #include "torch/csrc/autograd/edge.h"
 #include "torch/csrc/autograd/function.h"
 #include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/tracer_state.h"
 
 #include <ATen/Tensor.h>
 
@@ -29,10 +28,6 @@
     }
     version_counter_ = variable.version_counter();
     saved_version_ = version_counter_.current_version();
-    if (variable.has_tracing_state()) {
-      tracing_state_.reset(
-          new jit::tracer::ValueTracingState(variable.tracing_state()));
-    }
   }
 }
 
@@ -78,9 +73,6 @@
   if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
     throw std::logic_error("No grad accumulator for a saved leaf!");
   var.set_grad_accumulator(grad_accumulator_);
-  if (tracing_state_) {
-    var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_));
-  }
 
   return var;
 }
diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h
index 96dc8bd..ff5a36b 100644
--- a/torch/csrc/autograd/saved_variable.h
+++ b/torch/csrc/autograd/saved_variable.h
@@ -2,7 +2,6 @@
 
 #include "torch/csrc/WindowsTorchApiMacro.h"
 #include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/jit/tracer_state.h"
 
 #include <ATen/ATen.h>
 
@@ -44,7 +43,6 @@
   // passed in to the unpack function when reconstructing the Variable.
   std::shared_ptr<Function> grad_fn_;
   std::weak_ptr<Function> grad_accumulator_;
-  std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
   VariableVersion version_counter_;
 
   uint32_t saved_version_;
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 16e8105..7654c4e 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -9,8 +9,6 @@
 #include "torch/csrc/autograd/generated/Functions.h"
 #include "torch/csrc/autograd/generated/VariableType.h"
 #include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
 
 #include <ATen/ATen.h>
 
@@ -141,7 +139,6 @@
   grad_.reset();
   grad_fn_.reset();
   hooks_.clear();
-  tracing_state_.reset();
 }
 
 Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
@@ -205,13 +202,4 @@
   }
 }
 
-void Variable::set_tracing_state(
-    jit::tracer::ValueTracingState* new_tracing_state) {
-  get()->tracing_state_.reset(new_tracing_state);
-}
-
-jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept {
-  return *get()->tracing_state_;
-}
-
 }} // namespace torch::autograd
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 260eb93..2def489 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -7,7 +7,6 @@
 #include "torch/csrc/autograd/edge.h"
 #include "torch/csrc/autograd/function_hook.h"
 #include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
 
 #include <ATen/ATen.h>
 #include <ATen/Error.h>
@@ -20,20 +19,10 @@
 #include <utility>
 #include <vector>
 
-namespace torch {
-namespace autograd {
-struct Function;
-} // namespace autograd
-namespace jit { namespace tracer {
-// Has to be forward declared because tracer_state.h has a dependency on
-// variable.h.
-struct ValueTracingStateElem;
-using ValueTracingState = std::list<ValueTracingStateElem>;
-}} // namespace jit::tracer
-} // namespace torch
-
 namespace torch { namespace autograd {
 
+struct Function;
+
 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 ///                                Variable
 ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -228,15 +217,6 @@
   const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
   void clear_hooks();
 
-  // JIT Tracing
-  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-  void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state);
-  jit::tracer::ValueTracingState& tracing_state() const noexcept;
-
-  /// Returns true if the `Variable`'s tracing state is not null.
-  bool has_tracing_state() const noexcept;
-
   // View Variables
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -379,9 +359,6 @@
   // state are still thread-safe. Used by get_grad_fn and
   // get_grad_accumulator.
   std::mutex mutex_;
-
-  // For use in torch::jit::tracer
-  auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
 };
 
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -564,13 +541,6 @@
   get()->hooks_.clear();
 }
 
-// JIT Tracing
-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-inline bool Variable::has_tracing_state() const noexcept {
-  return get()->tracing_state_ != nullptr;
-}
-
 // View Variables
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 5ef60d9..2b0e0b4 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -226,7 +226,7 @@
     // there is no need to optimize, but we do need to splice the graph of
     // this excutor into the trace. Otherwise we might unroll control-flow
     // operations.
-    if(isTracing(inputs)) {
+    if(tracer::isTracing()) {
       return runTraced(std::move(inputs));
     }
 
@@ -274,26 +274,11 @@
 private:
   friend struct GraphExecutor;
 
-  // TODO: switching tracing to be part of the local thread state, instead of
-  // a per-variable property will make this check significantly faster.
-  // It is along the fast path, so this is important.
-  static bool isTracing(const variable_tensor_list& inputs) {
-    for(auto & i : inputs) {
-      if(i.defined() && tracer::isTracingVar(autograd::as_variable_ref(i)))
-        return true;
-    }
-    return false;
-  }
   variable_tensor_list runTraced(variable_tensor_list inputs) {
-    // TODO: unnecessary copy to variable_list
-    variable_list input_vars(inputs.begin(), inputs.end());
-    auto state = tracer::getTracingState(input_vars);
-    auto input_values = fmap(input_vars, [&](const Variable& v) {
-      return tracer::getValueTrace(state, v);
-    });
+    auto state = tracer::getTracingState();
+    auto input_values = fmap(inputs, tracer::getValueTrace);
 
     ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
-    input_vars.clear(); // don't hold inputs during execution
     auto outputs = runFallback(std::move(inputs));
 
     auto all_dynamic = [](const at::ArrayRef<Value*> xs) {
@@ -316,7 +301,7 @@
     auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values);
 
     for(size_t i = 0; i < outputs.size(); ++i) {
-      tracer::setValueTrace(state, outputs[i], output_values[i]);
+      tracer::setValueTrace(outputs[i], output_values[i]);
     }
     return outputs;
   }
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index a42915b..2b55502 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -7,7 +7,6 @@
 #include "torch/csrc/jit/resource_guard.h"
 #include "torch/csrc/jit/source_location.h"
 #include "torch/csrc/jit/type.h"
-#include "torch/csrc/jit/variable_flags.h"
 
 #include "torch/csrc/utils/disallow_copy.h"
 #include "torch/csrc/utils/functional.h"
diff --git a/torch/csrc/jit/passes/onnx.h b/torch/csrc/jit/passes/onnx.h
index bd6f6e4..a58d421 100644
--- a/torch/csrc/jit/passes/onnx.h
+++ b/torch/csrc/jit/passes/onnx.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
 #include "torch/csrc/onnx/onnx.h"
 
 namespace torch { namespace jit {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 534297a..f03edfe 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -444,7 +444,7 @@
       return t.expect<TensorType>()->strides();
     })
     .def("contiguous",[](Type& t) {
-      return t.expect<TensorType>()->contiguous();
+      return std::static_pointer_cast<Type>(t.expect<TensorType>()->contiguous());
     })
     .def("scalarType",[](Type& t) {
       return at::toString(t.expect<TensorType>()->scalarType());
@@ -471,8 +471,5 @@
     }
     return std::make_tuple(graph, variables);
   });
-  m.def("_jit_is_tracing", [](const autograd::Variable& var) {
-    return tracer::isTracing(var);
-  });
 }
 }}
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 2ad7a79..6397877 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -46,21 +46,26 @@
         tracer::variable_list trace_inputs,
         size_t num_func_inputs) {
   auto enter_info = tracer::enter(std::move(trace_inputs));
-  py::tuple py_inputs(num_func_inputs);
-  for(size_t i = 0; i < num_func_inputs; ++i) {
-    py_inputs[i] = py::cast(enter_info.second[i]);
+  try {
+    py::tuple py_inputs(num_func_inputs);
+    for(size_t i = 0; i < num_func_inputs; ++i) {
+      py_inputs[i] = py::cast(enter_info.second[i]);
+    }
+    auto out = func(*py_inputs);
+    std::vector<autograd::Variable> outputs;
+    if(PyTuple_Check(out.ptr())) {
+      outputs = py::cast<std::vector<autograd::Variable>>(out);
+    } else {
+      outputs.push_back(py::cast<autograd::Variable>(out));
+    }
+    tracer::exit(outputs);
+    auto graph = enter_info.first->graph;
+    EliminateDeadCode(graph);
+    return graph;
+  } catch (...) {
+    tracer::abandon();
+    throw;
   }
-  auto out = func(*py_inputs);
-  std::vector<autograd::Variable> outputs;
-  if(PyTuple_Check(out.ptr())) {
-    outputs = py::cast<std::vector<autograd::Variable>>(out);
-  } else {
-    outputs.push_back(py::cast<autograd::Variable>(out));
-  }
-  tracer::exit(outputs);
-  auto graph = enter_info.first->graph;
-  EliminateDeadCode(graph);
-  return graph;
 }
 
 PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj,
@@ -119,17 +124,17 @@
   m.def("_tracer_exit", [](variable_list var_outputs) {
     tracer::exit(var_outputs);
   });
-  m.def("_get_tracing_state", [](const variable_list& vars) {
-    return getTracingState(vars);
+  m.def("_tracer_abandon", []() {
+    tracer::abandon();
   });
-  m.def("_get_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var) {
-    return getValueTrace(state, var);
+  m.def("_get_tracing_state", []() {
+    return getTracingState();
   });
-  m.def("_set_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var, Value* value) {
-    return setValueTrace(state, var, value);
+  m.def("_get_value_trace", [](const Variable& var) {
+    return getValueTrace(var);
   });
-  m.def("_is_tracing", [](const variable_list& vars) {
-    return isTracingVar(vars);
+  m.def("_set_value_trace", [](const Variable& var, Value* value) {
+    return setValueTrace(var, value);
   });
 }
 
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 0fda835..a86059f 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -13,6 +13,28 @@
 
 namespace torch { namespace jit { namespace tracer {
 
+////////////////////////////////////////////////////////////////////////////////
+// Recording the traces
+////////////////////////////////////////////////////////////////////////////////
+namespace detail {
+
+thread_local std::shared_ptr<TracingState> tracing_state;
+
+} // namespace detail
+
+const std::shared_ptr<TracingState>& getTracingState() {
+  return detail::tracing_state;
+}
+
+void setTracingState(std::shared_ptr<TracingState> state) {
+  detail::tracing_state = std::move(state);
+}
+
+TracingState::TracingState()
+    : graph(new Graph()) {}
+
+TracingState::~TracingState() = default;
+
 PreTraceInfo preRecordTrace(Symbol op,
                             at::ArrayRef<Variable> inputs) {
   return makePreTraceInfo(inputs, [&op](const std::shared_ptr<TracingState>& state, Graph& graph) {
@@ -22,14 +44,10 @@
 
 void postRecordTrace(const PreTraceInfo& info,
                      at::ArrayRef<Variable> outputs) {
-  // TODO: Technically, we could reduce the scope of the lock, but since we
-  // haven't actually specified what the locking contract is, be conservative.
-  auto state_lock = info.state->lock();
-
   auto assignOutput = [&info](const Variable & output, Value * value) {
     if (output.defined()) {
       value->inferTypeFrom(output.data());
-      setValueTrace(info.state, output, value);
+      setValueTrace(output, value);
     }
   };
 
@@ -38,35 +56,39 @@
   }
 }
 
-thread_local ArgumentStash ArgumentStash::stash;
-
-void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
-  // TODO: check type?
-  if (!isTracing(var)) return;
-  auto tracing_state = getTracingState({var});
-  auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
-  JIT_ASSERT(size == list_trace.size());
-  JIT_ASSERT(idx < list_trace.size());
-  JIT_ASSERT(list_trace[idx] == nullptr);
-  list_trace[idx] = getValueTrace(tracing_state, var);
-}
-
 autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
-  auto tracing_state = getTracingState({var});
+  auto & tracing_state = getTracingState();
   auto & graph = tracing_state->graph;
 
   auto size_var = autograd::make_variable(at::Scalar(var.size(dim)).toTensor());
-  auto* value = getValueTrace(tracing_state, var);
+  auto* value = getValueTrace(var);
   auto* node = graph->create(aten::size, {value})
                     ->i_(attr::dim, dim);
   node->output()->inferTypeFrom(size_var);
   graph->appendNode(node);
-  setValueTrace(tracing_state, size_var, node->output());
+  setValueTrace(size_var, node->output());
 
   return size_var;
 }
 
+////////////////////////////////////////////////////////////////////////////////
+// Argument stash
+////////////////////////////////////////////////////////////////////////////////
+thread_local ArgumentStash ArgumentStash::stash;
 
+void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
+  // TODO: check type?
+  if (!isTracing()) return;
+  auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
+  JIT_ASSERT(size == list_trace.size());
+  JIT_ASSERT(idx < list_trace.size());
+  JIT_ASSERT(list_trace[idx] == nullptr);
+  list_trace[idx] = getValueTrace(var);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Stack trace recording
+////////////////////////////////////////////////////////////////////////////////
 // no python present so we just do not record source information
 void defaultRecordSourceLocation(Node* n) {}
 std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(defaultRecordSourceLocation);
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 5775091..2b8f32e 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -1,13 +1,12 @@
 #pragma once
 
 #include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
 #include "torch/csrc/assertions.h"
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/utils/variadic.h"
 #include "torch/csrc/autograd/function_hook.h"
 #include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
+
 #include <memory>
 #include <mutex>
 #include <vector>
@@ -20,38 +19,27 @@
 using torch::autograd::Variable;
 using variable_list = std::vector<Variable>;
 
-namespace detail {
+struct TracingState : public std::enable_shared_from_this<TracingState> {
+  TracingState();
+  ~TracingState();
 
-inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>& state, const Variable& var, bool alloc = true) {
-  auto& tracing_state = var.tracing_state();
-  for (auto it = tracing_state.begin(); it != tracing_state.end();) {
-    auto ts = it->state.lock();
-    // GC of invalidated tracing states
-    if (!ts) {
-      auto current_it = it++;
-      tracing_state.erase(current_it);
-      continue;
-    } else if (ts == state) {
-      return &(*it);
+  using WeakTensor = at::WeakTensor;
+
+  struct WeakTensorHasher {
+    size_t operator()(const WeakTensor& t) const {
+      return std::hash<void*>()(t.unsafeGetTensorImpl());
     }
-    ++it;
-  }
-  if (alloc) {
-    tracing_state.emplace_front();
-    auto & vts = tracing_state.front();
-    vts.state = state;
-    return &vts;
-  } else {
-    return nullptr;
-  }
-}
+  };
 
-inline bool isElemActive(const ValueTracingStateElem& vts) {
-  auto state = vts.state.lock();
-  return state && state->active;
-}
+  struct WeakTensorEq {
+    bool operator()(const WeakTensor& t1, const WeakTensor& t2) const {
+      return t1.unsafeGetTensorImpl() == t2.unsafeGetTensorImpl();
+    }
+  };
 
-} // namespace detail
+  std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq> value_map;
+  std::shared_ptr<Graph> graph;
+};
 
 
 // This is meant to be used as a thread local place, where we can store extra
@@ -91,76 +79,27 @@
   std::unordered_map<std::string, IntListTrace> intlists;
 };
 
-// Should a function which takes 'vars' as inputs be traced?
-// It suffices for ONE variable to be tracing: any "untraced" variables
-// are treated as constants.
-//
-// NB: This code lives in the hotpath; make sure it is fast
-//
-// NB: Variable overload is not variadic because we don't actually
-// need it (in most cases if we have a variable_list it is already
-// flattened).
-inline bool isTracingVar(const Variable& var) {
-  if (!var.defined() || !var.has_tracing_state()) return false;
-  return std::any_of(var.tracing_state().begin(), var.tracing_state().end(), detail::isElemActive);
+// Retrieve or set the current tracing state. Returns a nullptr if tracing is disabled.
+const std::shared_ptr<TracingState>& getTracingState();
+void setTracingState(std::shared_ptr<TracingState> state);
+
+inline bool isTracing() {
+  return static_cast<bool>(getTracingState());
 }
 
-inline bool isTracingVar(at::ArrayRef<Variable> vars) {
-  // Reference to avoid refcount bump
-  for (const Variable& var : vars) {
-    if (isTracingVar(var)) return true;
-  }
-  return false;
-}
-
-struct IsTracing : IterArgs<IsTracing> {
-  bool out = false;
-  using IterArgs<IsTracing>::operator();
-  void operator()(const at::Tensor& var) {
-    out = out || isTracingVar(var);
-  }
-  bool short_circuit() { return out; }
-};
-
-// To be called with Tensor arguments from generated code
-template<typename... Args>
-inline bool isTracing(Args&&... args) {
-  return IsTracing().apply(std::forward<Args>(args)...).out;
-}
-
-// Retrieve the tracing state which a function applied with 'vars' should
-// be recorded to.  Precondition: isTracing(vars) == true.  At the moment,
-// we don't support mixing up variables from different traces; this code
-// will need to be revisited if that ever becomes supported.
-inline std::shared_ptr<TracingState> getTracingState(const variable_list& vars) {
-  std::shared_ptr<TracingState> state;
-  for (auto& var : vars) {
-    if (!var.defined() || !var.has_tracing_state()) continue;
-    for (auto & vts : var.tracing_state()) {
-      auto var_state = vts.state.lock();
-      if (!var_state || !var_state->active) continue;
-      if (!state) state = var_state;
-      JIT_ASSERT(var_state == state);
-    }
-  }
-  JIT_ASSERT(state);
-  return state;
-}
-
-// Having finished adding a new 'node' to the graph IR owned by TracingState 'state',
-// 'setValueTrace' associates this node with an output variable, so that further operations
-// involving this variable know which node in the IR to reference.
-inline void setValueTrace(const std::shared_ptr<TracingState>& state, const Variable& var, Value *value) {
+// Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
+// this node with an output variable, so that further operations involving this
+// variable know which node in the IR to reference.
+inline void setValueTrace(const Variable& var, Value *value) {
   JIT_ASSERT(var.defined());
-  auto vts = detail::getValueState(state, var);
-  vts->trace = value;
+  getTracingState()->value_map[var] = value;
 }
 
 // Given a variable 'var', return the 'node' which represents the instruction
-// which computes the value of this variable in the IR.  When 'mustExist' is
-// false, we interpret untraced variables as constants that are just embedded
+// which computes the value of this variable in the IR.
+// Here, we interpret untraced variables as constants that are just embedded
 // in the graph.  This is useful to handle code which does things like this
-// (from torch.autograd.variable):
+// (from torch.autograd.variable, now moved to C++):
 //
 //    def mm(self, matrix):
 //      output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
@@ -170,19 +109,21 @@
 // update on, but subsequently ignores it because the alpha scaling factor is zero.
 // This is one of the cases where a Variable can be created inside of a trace, and
 // if we treat it as a constant, everything will work out.
-inline Value* getValueTrace(const std::shared_ptr<TracingState>& state, const Variable& var) {
+inline Value* getValueTrace(const Variable& var) {
+  auto &state = getTracingState();
   if (!var.defined()) {
     Node *n = state->graph->createUndefined();
     return state->graph->appendNode(n)->output();
   }
 
-  auto vts = detail::getValueState(state, var, true);
-  if (vts->trace) return vts->trace;
-
-  Value *constant = state->graph->appendNode(state->graph->createConstant(var.data()))->output();
-  constant->inferTypeFrom(var.data());
-  setValueTrace(state, var, constant);
-  return constant;
+  auto & value_map = getTracingState()->value_map;
+  auto it = value_map.find(var);
+  if (it == value_map.end()) {
+    Value *constant = state->graph->appendNode(state->graph->createConstant(var.data()))->output();
+    constant->inferTypeFrom(var.data());
+    it = value_map.emplace_hint(it, var, constant);
+  }
+  return it->second;
 }
 
 inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
@@ -191,36 +132,37 @@
     return state->graph->appendNode(n)->output();
   }
 
-  auto vts = detail::getValueState(state, var, false);
-  if (!vts) {
+  auto & value_map = getTracingState()->value_map;
+  auto it = value_map.find(var);
+  if (it == value_map.end()) {
     std::ostringstream os;
     os << "output " << output_no << " of traced region did not have observable "
        << "data dependence with trace inputs; this probably indicates your program "
        << "cannot be understood by the tracer.";
     throw std::runtime_error(os.str());
   }
-  return vts->trace;
+  return it->second;
 }
 
 // Start tracing, treating 'inputs' as inputs to the trace, which can be
 // varied on subsequent invocations of the trace.  Any other variables
 // will be treated as constants.
-//
-// NB: Why does this take an rvalue reference?  We need to get a non-const
-// reference to at::Tensor buffer to call unsafeGetTH, but you can't get this
-// out of a const vector (silly std::vector...)
 inline std::pair<std::shared_ptr<TracingState>, variable_list> enter(
     variable_list inputs) {
+  if (isTracing()) {
+    AT_ERROR("Tracing can't be nested");
+  }
   auto state = std::make_shared<TracingState>();
+  setTracingState(state);
   for (auto& input : inputs) {
-    auto * value_state = detail::getValueState(state, input, false);
+    auto * value_state = state->value_map[input];
     if (value_state) {
       // See Note [Repeated inputs] in tracer.cpp
       input = input.view(input.sizes());
     }
     auto input_node = state->graph->addInput(input.name());
-    setValueTrace(state, input, input_node);
     input_node->inferTypeFrom(input.data());
+    state->value_map[input] = input_node;
   }
   return std::make_pair(state, inputs);
 }
@@ -229,27 +171,29 @@
 // are the variables whose values will be computed upon subsequent
 // invocations of the trace.
 inline void exit(const variable_list& outputs) {
-  auto state = getTracingState(outputs);
+  auto & state = getTracingState();
   size_t i = 0;
   for (auto& output : outputs) {
     state->graph->registerOutput(getOutputTrace(state, output, i));
     i++;
   }
-  state->active = false;
+  setTracingState(nullptr);
+}
+
+// Abort tracing. Used to reset the state in case of errors.
+inline void abandon() {
+  setTracingState(nullptr);
 }
 
 // Pre-recorded information about the trace before we actually carry
 // out the trace
 struct PreTraceInfo {
-  std::shared_ptr<TracingState> state;
   Node *n;
 };
 
 PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef<Variable> inputs);
 void postRecordTrace(const PreTraceInfo& info, at::ArrayRef<Variable> outputs);
 
-autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
-
 void recordSourceLocation(Node* n);
 void setRecordSourceLocation(void (*v)(Node*));
 
@@ -259,15 +203,14 @@
 template<typename F>
 PreTraceInfo makePreTraceInfo(at::ArrayRef<Variable> inputs, F ctor) {
   PreTraceInfo info;
-  info.state = getTracingState(inputs);
-  auto& graph = info.state->graph;
-  auto state_lock = info.state->lock();
+  auto & state = getTracingState();
+  auto & graph = state->graph;
 
-  Node *n = ctor(info.state, *graph);
+  Node *n = ctor(state, *graph);
   recordSourceLocation(n);
 
-  for (Variable input : inputs) {
-    n->addInput(getValueTrace(info.state, input));
+  for (const Variable & input : inputs) {
+    n->addInput(getValueTrace(input));
   }
 
   // NB: Order matters. This must append after inputs but before outputs.
@@ -278,4 +221,6 @@
   return info;
 }
 
+autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
+
 }}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/tracer_state.cpp b/torch/csrc/jit/tracer_state.cpp
deleted file mode 100644
index 6f44562..0000000
--- a/torch/csrc/jit/tracer_state.cpp
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/jit/ir.h"
-
-namespace torch { namespace jit { namespace tracer {
-
-TracingState::TracingState()
-    : graph(new Graph())
-    , active(true) {}
-
-TracingState::~TracingState() = default;
-
-}}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
deleted file mode 100644
index 887ad94..0000000
--- a/torch/csrc/jit/tracer_state.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include "torch/csrc/autograd/edge.h"
-#include "torch/csrc/autograd/variable.h"
-
-#include <atomic>
-#include <cstdint>
-#include <list>
-#include <memory>
-#include <mutex>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-namespace torch { namespace jit {
-struct Graph;
-struct Value;
-}} // namespace torch::jit
-
-namespace torch { namespace jit { namespace tracer {
-
-// TracingState tracks the necessary state when we are tracing the execution of
-// autograd code; most importantly, it holds a reference to the actual IR
-// graph which we are recording the trace to.
-//
-// The liveness of a TracingState is expected to be a superset of the region
-// of code being traced; in particular, Variables do not keep a TracingState
-// live.  Instead, they hold weak pointers to TracingState, to prevent leaks
-// from arising when a variable that participated in a trace outlives the
-// actual trace itself.
-
-struct TracingState : public std::enable_shared_from_this<TracingState> {
-  TracingState();
-  ~TracingState();
-
-  std::shared_ptr<Graph> graph;
-  std::mutex mutex;
-  bool active;
-
-  std::unique_lock<std::mutex> lock() {
-    return std::unique_lock<std::mutex>(mutex);
-  }
-};
-
-struct ValueTracingStateElem {
-  std::weak_ptr<TracingState> state;
-  // it's only valid to use this field if !state.exired()
-  Value* trace = nullptr;
-
-  void reset() {
-    state.reset();
-    trace = nullptr;
-  }
-};
-
-using ValueTracingState = std::list<ValueTracingStateElem>;
-
-}}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/variable_flags.cpp b/torch/csrc/jit/variable_flags.cpp
deleted file mode 100644
index 8ab565d..0000000
--- a/torch/csrc/jit/variable_flags.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "torch/csrc/jit/variable_flags.h"
-
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/tracer_state.h"
-
-using torch::autograd::Variable;
-
-namespace torch { namespace jit {
-
-// These definitions require Variable struct to be defined, so they can't be
-// in tracer_state.h
-VariableFlags VariableFlags::of(const Variable& var) {
-  VariableFlags f;
-  f.defined = var.defined();
-  f.requires_grad = f.defined && var.requires_grad();
-  return f;
-}
-
-}}
diff --git a/torch/csrc/jit/variable_flags.h b/torch/csrc/jit/variable_flags.h
deleted file mode 100644
index 43c3ef9..0000000
--- a/torch/csrc/jit/variable_flags.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-#include <iostream>
-namespace torch { namespace autograd {
-struct Variable;
-}}
-
-namespace torch { namespace jit {
-
-struct VariableFlags {
-  static VariableFlags of(const autograd::Variable& var);
-
-  bool requires_grad;
-  bool defined;
-};
-
-static inline std::ostream & operator<<(std::ostream & out, const VariableFlags& v) {
-  return out
-    << "(requires_grad=" << v.requires_grad
-    << ", defined=" << v.defined << ")";
-}
-
-}}
diff --git a/torch/csrc/utils/auto_unique_ptr.h b/torch/csrc/utils/auto_unique_ptr.h
deleted file mode 100644
index d49a036..0000000
--- a/torch/csrc/utils/auto_unique_ptr.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#pragma once
-
-#include <memory>
-
-namespace torch {
-
-// A unique_ptr that automatically constructs the object on first dereference.
-template<typename T>
-struct auto_unique_ptr : public std::unique_ptr<T> {
-  T& operator*() {
-    if (!this->get()) this->reset(new T());
-    return *this->get();
-  }
-
-  T* operator->() {
-    if (!this->get()) this->reset(new T());
-    return this->get();
-  }
-};
-
-} // namespace torch
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index fbf3fab..4b60541 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -24,21 +24,10 @@
 _jit_script_compile = torch._C._jit_script_compile
 BatchTensor = torch._C._jit.BatchTensor
 
-# This global variable is set when we are tracing a *forwards* computation.
-# It is intended to be a cheap way to test if tracing has occurred, before
-# doing the slower path using `get_tracing_state` (below.)
-_tracing = False
-
-
-def get_tracing_state(args):
-    if not torch._C._is_tracing(args):
-        return None
-    return torch._C._get_tracing_state(args)
-
 
 @contextlib.contextmanager
-def scope(scope_name, *vars):
-    tracing_state = get_tracing_state(vars)
+def scope(scope_name):
+    tracing_state = torch._C._get_tracing_state()
     if tracing_state:
         tracing_state.push_scope(scope_name)
     try:
@@ -98,18 +87,19 @@
         self.inner = inner
 
     def forward(self, *args):
-        global _tracing
         in_vars, in_desc = _flatten(args)
         # NOTE: use full state, because we need it for BatchNorm export
         # This differs from the compiler path, which doesn't support it at the moment.
         module_state = list(_unique_state_dict(self, keep_vars=True).values())
         trace, all_trace_inputs = torch._C._tracer_enter(in_vars + module_state)
-        _tracing = True
-        trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
-        out = self.inner(*trace_inputs)
-        out_vars, _ = _flatten(out)
-        _tracing = False
-        torch._C._tracer_exit(out_vars)
+        try:
+            trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
+            out = self.inner(*trace_inputs)
+            out_vars, _ = _flatten(out)
+            torch._C._tracer_exit(out_vars)
+        except Exception:
+            torch._C._tracer_abandon()
+            raise
         return trace, out
 
 
@@ -289,13 +279,7 @@
         if len(kwargs) != 0:
             raise TypeError("got unexpected keyword arguments: {}".format(", ".join(kwargs.keys())))
 
-        if isinstance(func, torch.nn.Module):
-            orig = func
-        else:
-            # traced functions become a method on an Empty module
-            orig = Module()
-
-        module = TopLevelTracedModule(orig, **executor_options)
+        module = TopLevelTracedModule(func, **executor_options)
         module._create_method_from_trace('forward', func, args)
         return module
 
@@ -683,10 +667,17 @@
     __frozen = False
 
     def __init__(self, orig, id_set=None, optimize=True):
+        # XXX: orig can be a nn.Module or a function!
         super(TracedModule, self).__init__(optimize=optimize)
         if id_set is None:
             id_set = set()
 
+        if not isinstance(orig, torch.nn.Module):
+            self._name = orig.__name__
+            orig = torch.nn.Module()
+        else:
+            self._name = 'TracedModule[' + type(orig).__name__ + ']'
+
         def check_unique(param):
             if param in id_set:
                 raise ValueError("TracedModules don't support parameter sharing between modules")
@@ -702,7 +693,6 @@
             if buf is not None:
                 self._buffers[name] = buf
                 check_unique(buf)
-        self._orig_class = type(orig)
 
         if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
             raise ValueError("Modules that have hooks assigned can't be compiled")
@@ -719,7 +709,7 @@
         self.__frozen = True
 
     def _get_name(self):
-        return 'TracedModule[' + self._orig_class.__name__ + ']'
+        return self._name
 
     def __setattr__(self, attr, value):
         if not self.__frozen or hasattr(self, attr):
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index c7f5d10..1cccb77 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -310,7 +310,7 @@
         # function gets reconstructed each and every time when RNN() is invoked
         # and we don't want to pay the cost of decorator invocation
         import torch
-        if torch._C._jit_is_tracing(input):
+        if torch._C._get_tracing_state():
             import torch.onnx.symbolic
             sym = torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs)
             cell_type = args[0]
@@ -318,7 +318,7 @@
             bound_symbolic = partial(torch.onnx.symbolic.rnn_trace_override_symbolic,
                                      cell_type, func, sym)
 
-            decorator = torch.onnx.symbolic_override_first_arg_based(bound_symbolic)
+            decorator = torch.onnx.symbolic_override(bound_symbolic)
             func = decorator(func)
 
         return func(input, *fargs, **fkwargs)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index f1c4f75..17a7c09 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1274,7 +1274,7 @@
 
     import torch.onnx.symbolic
 
-    @torch.onnx.symbolic_override_first_arg_based(torch.onnx.symbolic.instance_norm)
+    @torch.onnx.symbolic_override(torch.onnx.symbolic.instance_norm)
     def _instance_norm(input, running_mean=None, running_var=None, weight=None,
                        bias=None, use_input_stats=None, momentum=None, eps=None):
         # Repeat stored stats and affine transform params if necessary
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 91bab5c..a00ff3d 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -450,16 +450,16 @@
 
     def _slow_forward(self, *input, **kwargs):
         input_vars = tuple(torch.autograd.function._iter_tensors(input))
-        tracing_state = torch.jit.get_tracing_state(input_vars)
+        tracing_state = torch._C._get_tracing_state()
         if not tracing_state:
             return self.forward(*input, **kwargs)
         if not hasattr(tracing_state, '_traced_module_stack'):
             tracing_state._traced_module_stack = []
         name = self._tracing_name(tracing_state)
         if name:
-            tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
+            tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
         else:
-            tracing_state.push_scope(self.__class__.__name__)
+            tracing_state.push_scope(self._get_name())
         tracing_state._traced_module_stack.append(self)
         try:
             result = self.forward(*input, **kwargs)
@@ -471,7 +471,7 @@
     def __call__(self, *input, **kwargs):
         for hook in self._forward_pre_hooks.values():
             hook(self, input)
-        if torch.jit._tracing:
+        if torch._C._get_tracing_state():
             result = self._slow_forward(*input, **kwargs)
         else:
             result = self.forward(*input, **kwargs)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 4c1eee0..d91797f 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -168,8 +168,7 @@
     return tuple(o for o in outputs)
 
 
-pack_padded_sequence = torch.onnx.symbolic_override_first_arg_based(
-    _symbolic_pack_padded_sequence)(pack_padded_sequence)
+pack_padded_sequence = torch.onnx.symbolic_override(_symbolic_pack_padded_sequence)(pack_padded_sequence)
 
 
 def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
@@ -264,8 +263,7 @@
     return data, lengths
 
 
-pad_packed_sequence = torch.onnx.symbolic_override_packed_sequence_based(
-    _symbolic_pad_packed_sequence)(pad_packed_sequence)
+pad_packed_sequence = torch.onnx.symbolic_override(_symbolic_pad_packed_sequence)(pad_packed_sequence)
 
 
 def pad_sequence(sequences, batch_first=False, padding_value=0):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 1807b71..0514343 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -56,48 +56,6 @@
     return utils._run_symbolic_method(*args, **kwargs)
 
 
-def _symbolic_override_wrapper_maker(symbolic_fn, might_trace, fn):
-
-    def wrapper(*args, **kwargs):
-        import torch
-        import torch.jit
-        from torch.autograd import Function, function
-
-        # fast pass
-        if not might_trace(args):
-            return fn(*args, **kwargs)
-
-        flat_args = tuple(function._iter_tensors_permissive(args))
-        flat_args_only_tensors = tuple(t for t in flat_args if isinstance(t, torch.Tensor))
-        if not any(map(torch._C._jit_is_tracing, flat_args_only_tensors)):
-            return fn(*args, **kwargs)
-
-        tstate = torch._C._get_tracing_state(flat_args_only_tensors)
-
-        arg_values = [torch._C._get_value_trace(tstate, x) if isinstance(x, torch.Tensor) else x for x in flat_args]
-
-        # This must come after the calls to get_value_trace, lest we
-        # lose information due to in-place operations.
-        output_vars = fn(*args, **kwargs)
-
-        symbolic_args = function._unflatten(arg_values, args)
-        output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)
-
-        for var, val in zip(
-                function._iter_tensors(output_vars),
-                function._iter_jit_values(output_vals)):
-            val.inferTypeFrom(var.data)
-            torch._C._set_value_trace(tstate, var, val)
-
-        return output_vars
-
-    # fn might be autograd.Function too, in this case wrapping doesn't work
-    if isinstance(fn, types.FunctionType):
-        wrapper = functools.wraps(fn)(wrapper)
-
-    return wrapper
-
-
 def symbolic_override(symbolic_fn):
     r"""
     Decorator to override ONNX export of the a function with specified subgraph.
@@ -123,47 +81,36 @@
         return x + y[0] + y[1]
     ```
     """
-
-    return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, lambda x: True)
-
-
-def symbolic_override_first_arg_based(symbolic_fn):
-    r"""
-    Decorator to override ONNX export of the a function with specified subgraph.
-
-    Equivalent to :func:`symbolic_override` but checks only the first argument
-    of the function to figure out whether the tracing is on. Thus the first arg
-    needs to be a Tensor.
-    """
-
-    def might_trace(args):
+    def decorator(fn):
         import torch
-        first_arg = args[0]
-        if not isinstance(first_arg, torch.Tensor):
-            raise ValueError('First argument of {} is expected to be a tensor, '
-                             'but got an object of type {}'
-                             .format(symbolic_fn.__name__, type(first_arg)))
-        return torch._C._jit_is_tracing(first_arg)
+        from torch.autograd import function
 
-    return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, might_trace)
+        def wrapper(*args, **kwargs):
+            tstate = torch._C._get_tracing_state()
+            if not tstate:
+                return fn(*args, **kwargs)
 
+            flat_args = tuple(function._iter_tensors_permissive(args))
+            arg_values = [torch._C._get_value_trace(x) if isinstance(x, torch.Tensor) else x for x in flat_args]
 
-def symbolic_override_packed_sequence_based(symbolic_fn):
-    r"""
-    Decorator to override ONNX export of the a function with specified subgraph.
+            # This must come after the calls to get_value_trace, lest we
+            # lose information due to in-place operations.
+            output_vars = fn(*args, **kwargs)
 
-    Equivalent to :func:`symbolic_override` but checks only the first argument
-    of the function to figure out whether the tracing is on. Thus the first arg
-    needs to be a Tensor.
-    """
+            symbolic_args = function._unflatten(arg_values, args)
+            output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)
 
-    def might_trace(args):
-        import torch
-        first_arg = args[0]
-        if not isinstance(first_arg, torch.nn.utils.rnn.PackedSequence):
-            raise ValueError('pad_packed_sequence expects sequence to be a '
-                             'PackedSequence, but got an object of type {}'
-                             .format(type(first_arg)))
-        return torch._C._jit_is_tracing(first_arg[0])
+            for var, val in zip(
+                    function._iter_tensors(output_vars),
+                    function._iter_jit_values(output_vals)):
+                val.inferTypeFrom(var.data)
+                torch._C._set_value_trace(var, val)
 
-    return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, might_trace)
+            return output_vars
+
+        # fn might be autograd.Function too, in this case wrapping doesn't work
+        if isinstance(fn, types.FunctionType):
+            wrapper = functools.wraps(fn)(wrapper)
+
+        return wrapper
+    return decorator
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index f547750..a88739c 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -745,6 +745,15 @@
     globals()[name] = partial(_cast_func_template, v)
 
 
+def zeros_like(g, input):
+    return g.op("Sub", input, input).setType(input.type().contiguous())
+
+
+def full_like(g, input, fill_value):
+    # TODO: a more efficient implementation (ConstantFill?)
+    return add(g, zeros_like(g, input), fill_value, alpha=torch.tensor(1))
+
+
 def slice(g, self, dim, start, end, step):
     if step != 1:
         _unimplemented("slice", "step!=1 is currently not supported")