Make JIT tracing a thread-local property (#9414)
Summary:
As in the title: the tracing state becomes a thread-local property instead of per-Variable state, which lets us simplify a lot of code.
Depends on #9363, so please review only the last commit.
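
The gist of the change, as a minimal self-contained sketch (Graph and Value are stand-ins here, and the map is keyed on a raw pointer instead of at::WeakTensor; the real definitions are in torch/csrc/jit/tracer.h in the diff below):

    #include <memory>
    #include <unordered_map>

    struct Graph {};   // stand-in for torch::jit::Graph
    struct Value {};   // stand-in for torch::jit::Value

    // The whole trace lives in one per-thread object instead of per-Variable lists.
    struct TracingState {
      std::unordered_map<const void*, Value*> value_map;  // real code keys on at::WeakTensor
      std::shared_ptr<Graph> graph = std::make_shared<Graph>();
    };

    thread_local std::shared_ptr<TracingState> tracing_state;

    const std::shared_ptr<TracingState>& getTracingState() { return tracing_state; }
    void setTracingState(std::shared_ptr<TracingState> state) { tracing_state = std::move(state); }

    // No tensor arguments needed any more: tracing is a property of the thread.
    bool isTracing() { return static_cast<bool>(getTracingState()); }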
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9414
Reviewed By: zdevito
Differential Revision: D8836496
Pulled By: apaszke
fbshipit-source-id: 9b3c3d1f001a9dc522f8478abc005b6b86cfa3e3
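
For call sites, the practical consequence is that entering and leaving a trace just sets or clears the thread-local state, and there is now an explicit abandon() path for errors. A simplified sketch of that lifecycle, reusing the stand-ins from the snippet above (illustrative only; the real enter/exit/abandon in torch/csrc/jit/tracer.h also register trace inputs and outputs):

    void enter()   { setTracingState(std::make_shared<TracingState>()); }  // start a trace
    void exit()    { setTracingState(nullptr); }                           // finish: detach the state
    void abandon() { setTracingState(nullptr); }                           // error path: reset the state

    void example() {
      enter();
      // ... run the computation; isTracing() is true on this thread, so
      // generated code records nodes into tracing_state->graph ...
      exit();
    }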
diff --git a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
index 2a87a36..4c9e2e2 100644
--- a/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_python_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 3)) {
- %2 : Double(3, 4) = aten::neg(%0)
- %4 : Double(3, 3) = aten::mm(%2, %1)
+ %2 : Double(3, 4) = aten::neg(%0), scope: TracedModule
+ %4 : Double(3, 3) = aten::mm(%2, %1), scope: TracedModule
return (%4);
}
diff --git a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
index 925bbf1..d39acaf 100644
--- a/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_python_mod_from_traced_module.expect
@@ -1,8 +1,8 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)
%2 : Double(5, 7)) {
- %4 : Double(3, 5) = aten::mm(%0, %1)
- %6 : Double(3, 7) = aten::mm(%4, %2)
- %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6)
+ %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+ %6 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/PythonModule[mod]
+ %7 : Double(3, 7) = aten::add[other={1}, alpha={1}](%6), scope: TracedModule
return (%7);
}
diff --git a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
index 4de15a5..ea847d6 100644
--- a/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_python_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) {
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
- %3 : Double(3, 3) = aten::mm(%0, %1)
+ %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: PythonMod
+ %3 : Double(3, 3) = aten::mm(%0, %1), scope: PythonMod
%4 : Double(3, 3) = aten::add[other={1}, alpha={1}](%3)
return (%4);
}
diff --git a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
index adaab38..6bf57b8 100644
--- a/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_script_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)) {
- %3 : Double(3, 5) = aten::mm(%0, %1)
- %5 : Double(3, 5) = aten::neg(%3)
+ %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+ %5 : Double(3, 5) = aten::neg(%3), scope: TracedModule/ScriptModule
return (%5);
}
diff --git a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
index cffec80..dc8b494 100644
--- a/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_script_fn_from_tracing_fn.expect
@@ -1,5 +1,5 @@
graph(%0 : Double(3, 4)) {
- %2 : Double(3, 4) = aten::neg(%0)
+ %2 : Double(3, 4) = aten::neg(%0), scope: ScriptModule
%3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2)
return (%3);
}
diff --git a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
index d446882..fc7039b 100644
--- a/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_script_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) {
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
- %4 : Double(3, 3) = aten::mm(%0, %1)
+ %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: ScriptMod
+ %4 : Double(3, 3) = aten::mm(%0, %1), scope: ScriptMod
%5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4)
return (%5);
}
diff --git a/test/expect/TestScript.test_call_script_module_from_traced_module.expect b/test/expect/TestScript.test_call_script_module_from_traced_module.expect
index c249ddc..21b14a2 100644
--- a/test/expect/TestScript.test_call_script_module_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_script_module_from_traced_module.expect
@@ -1,8 +1,8 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)
%2 : Double(5, 7)) {
- %4 : Double(3, 5) = aten::mm(%0, %1)
- %7 : Double(3, 7) = aten::mm(%4, %2)
- %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7)
+ %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+ %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/ScriptMod[mod]
+ %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule
return (%8);
}
diff --git a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
index 4e25a85..f45c3f1 100644
--- a/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_traced_fn_from_traced_module.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)) {
- %3 : Double(3, 5) = aten::mm(%0, %1)
- %5 : Double(3, 4) = aten::neg(%3)
+ %3 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+ %5 : Double(3, 4) = aten::neg(%3), scope: TracedModule/traced_fn
return (%5);
}
diff --git a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
index cffec80..ed737f4 100644
--- a/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_traced_fn_from_tracing_fn.expect
@@ -1,5 +1,5 @@
graph(%0 : Double(3, 4)) {
- %2 : Double(3, 4) = aten::neg(%0)
+ %2 : Double(3, 4) = aten::neg(%0), scope: traced_fn1
%3 : Double(3, 4) = aten::add[other={1}, alpha={1}](%2)
return (%3);
}
diff --git a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
index d446882..3fac45f 100644
--- a/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
+++ b/test/expect/TestScript.test_call_traced_mod_from_tracing_fn.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(3, 4)) {
- %1 : Double(4, 3) = prim::Constant[value=<Tensor>]()
- %4 : Double(3, 3) = aten::mm(%0, %1)
+ %1 : Double(4, 3) = prim::Constant[value=<Tensor>](), scope: TracedModule[TracedModule]
+ %4 : Double(3, 3) = aten::mm(%0, %1), scope: TracedModule[TracedModule]
%5 : Double(3, 3) = aten::add[other={1}, alpha={1}](%4)
return (%5);
}
diff --git a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
index c249ddc..471f9f1 100644
--- a/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
+++ b/test/expect/TestScript.test_call_traced_module_from_traced_module.expect
@@ -1,8 +1,8 @@
graph(%0 : Double(3, 4)
%1 : Double(4, 5)
%2 : Double(5, 7)) {
- %4 : Double(3, 5) = aten::mm(%0, %1)
- %7 : Double(3, 7) = aten::mm(%4, %2)
- %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7)
+ %4 : Double(3, 5) = aten::mm(%0, %1), scope: TracedModule
+ %7 : Double(3, 7) = aten::mm(%4, %2), scope: TracedModule/TracedModule[TracedModule1][mod]
+ %8 : Double(3, 7) = aten::add[other={1}, alpha={1}](%7), scope: TracedModule
return (%8);
}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index cbde1fa..dd60188 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -432,7 +432,7 @@
return g.op('Sum', x, y[0], y[1]), (
g.op('Neg', x), g.op('Neg', y[0]))
- @torch.onnx.symbolic_override_first_arg_based(symb)
+ @torch.onnx.symbolic_override(symb)
def foo(x, y):
return x + y[0] + y[1], (-x, -y[0])
diff --git a/test/test_jit.py b/test/test_jit.py
index 9a1661e..9a12ca9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -333,9 +333,9 @@
def f(x, y):
out = x + y
- with torch.jit.scope('Foo', out):
+ with torch.jit.scope('Foo'):
out = x * out
- with torch.jit.scope('Bar', out):
+ with torch.jit.scope('Bar'):
out = torch.tanh(out)
out = torch.sigmoid(out)
return out
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 679e3c9..97130b1 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -54,6 +54,13 @@
's_native_mul': 'mul',
'th_addmm': 'addmm',
's_native_addmm': 'addmm',
+ 'zero': 'zeros_like',
+ 'fill': 'full_like',
+}
+
+# (declaration name, argument name) -> attribute name
+RENAME_ATTRIBUTES = {
+ ('fill_', 'value'): 'fill_value'
}
# These functions are not worth profiling because they are very cheap and may
@@ -126,7 +133,7 @@
PRE_RECORD_TRACE = CodeTemplate("""\
jit::tracer::PreTraceInfo trace_info;
-if (jit::tracer::isTracing( ${tensor_args} )) {
+if (jit::tracer::isTracing()) {
trace_info = jit::tracer::preRecordTrace( jit::aten::${trace_name}, ${trace_inputs} );
if (!jit::tracer::ArgumentStash::empty()) {
${record_positional_attributes}
@@ -138,13 +145,13 @@
""")
POST_RECORD_TRACE = CodeTemplate("""\
-if (trace_info.state != nullptr) {
+if (jit::tracer::isTracing()) {
jit::tracer::postRecordTrace( trace_info, ${trace_outputs} );
}
""")
RECORD_ATTRIBUTE = CodeTemplate("""\
-setattr(trace_info.n, jit::attr::${name}, ${name});""")
+setattr(trace_info.n, jit::attr::${attr_name}, ${name});""")
RECORD_POSITIONAL_ATTRIBUTE = CodeTemplate("""\
setposattr(trace_info.n, ${i}, "${name}", ${name});""")
@@ -417,7 +424,8 @@
for arg in declaration['arguments']:
if arg['simple_type'] in {'Tensor', 'TensorList'}:
continue
- local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(name=arg['name']))
+ attr_name = RENAME_ATTRIBUTES.get((declaration['name'], arg['name']), arg['name'])
+ local['record_attributes'].append(RECORD_ATTRIBUTE.substitute(attr_name=attr_name, name=arg['name']))
local['record_positional_attributes'] = []
for i, arg in enumerate(declaration['arguments']):
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 0ca293f..03e0f64 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -39,10 +39,14 @@
namespace torch { namespace autograd {
// Helper methods for working with Attributes (torch/csrc/jit/attributes.h)
+at::Tensor maybeUnwrapVar(const at::Tensor& t) {
+ return t.is_variable() ? Variable(t).data() : t;
+}
+
// The overloaded accessors are convenient for the generated code (since we
// don't want to make the codegen do the dispatch manually)
static void setattr(jit::Node* n, jit::Symbol name, int64_t v) { n->i_(name, v); }
-static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v) { n->t_(name, v.toTensor()); }
+static void setattr(jit::Node* n, jit::Symbol name, const at::Scalar& v) { n->t_(name, maybeUnwrapVar(v.toTensor())); }
static void setattr(jit::Node* n, jit::Symbol name, SparseTensorRef s) { n->t_(name, s.tref); }
static void setattr(jit::Node* n, jit::Symbol name, const at::IntList& v) { n->is_(name, v); }
static void setattr(jit::Node* n, jit::Symbol name, bool v) { n->i_(name, v); }
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 8e8c876..9e44054 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -140,7 +140,7 @@
ParsedArgs<3> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- if (jit::tracer::isTracing(self_)) {
+ if (jit::tracer::isTracing()) {
return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
} else {
return wrap(self_.size(r.toInt64(0)));
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index c6fe078..02fd042 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -159,10 +159,8 @@
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp
- ${TORCH_SRC_DIR}/csrc/jit/tracer_state.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/type.cpp
- ${TORCH_SRC_DIR}/csrc/jit/variable_flags.cpp
${TORCH_SRC_DIR}/csrc/onnx/onnx.cpp
${TORCH_SRC_DIR}/csrc/onnx/onnx.npb.cpp
${TORCH_SRC_DIR}/csrc/torch.cpp
diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp
index 0708eba..95c7c64 100644
--- a/torch/csrc/Size.cpp
+++ b/torch/csrc/Size.cpp
@@ -14,7 +14,7 @@
PyObject * THPSize_New(const torch::autograd::Variable& var)
{
- if (!torch::jit::tracer::isTracing(var)) {
+ if (!torch::jit::tracer::isTracing()) {
auto sizes = var.sizes();
return THPSize_NewFromSizes(var.dim(), sizes.data());
}
@@ -38,10 +38,10 @@
return self.release();
}
-static bool isTracedVar(PyObject *item) {
+static bool isTracedZeroDimVar(PyObject *item) {
if (!THPVariable_Check(item)) return false;
auto & var = reinterpret_cast<THPVariable*>(item)->cdata;
- return torch::jit::tracer::isTracing(var);
+ return var.dim() == 0 && torch::jit::tracer::getValueTrace(var);
}
static PyObject * THPSize_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -50,10 +50,10 @@
if (self) {
for (Py_ssize_t i = 0; i < PyTuple_Size(self); ++i) {
PyObject *item = PyTuple_GET_ITEM(self.get(), i);
- if (isTracedVar(item)) {
+ if (THPUtils_checkLong(item)) {
continue;
}
- if (THPUtils_checkLong(item)) {
+ if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) {
continue;
}
// item.__index__() works with 0-dim tensors and tensors with one element
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index e4b6bef..b65a706 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -8,7 +8,6 @@
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/autograd/type_and_shape.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
#include "torch/csrc/utils/python_stub.h"
#include "torch/csrc/utils/variadic.h"
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 5494334..8c20c6c 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -550,7 +550,7 @@
}
static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
- if (tracer::isTracingVar(input_vars)) {
+ if (tracer::isTracing()) {
std::ostringstream oss;
oss << "Attempted to trace " << name;
oss << ", but tracing of legacy functions is not supported";
@@ -562,7 +562,7 @@
PyObject* op_obj,
PyObject *input_objects,
const variable_list& input_vars) {
- if (!tracer::isTracingVar(input_vars)) {
+ if (!jit::tracer::isTracing()) {
return jit::tracer::PreTraceInfo();
}
@@ -598,7 +598,7 @@
const variable_list& input_vars,
PyObject *output_objects,
bool is_inplace) {
- if (!trace_info.state) {
+ if (!jit::tracer::isTracing()) {
return;
}
@@ -612,7 +612,6 @@
jit::tracer::postRecordTrace(trace_info, output_vars);
- auto state_lock = trace_info.state->lock();
trace_info.n->i_(attr::inplace, is_inplace);
}
@@ -640,11 +639,6 @@
bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
_wrap_outputs(grad_fn, inputs, raw_output, outputs, is_executable);
- // NOTE: _trace_post_record has to run before _save_variables, because we need
- // to assign traces to outputs before we convert them to SavedVariables.
- // On the other hand, it needs to go after _mark_non_differentiable, because
- // it might be wraping backwards in Evals, and _mark_non_differentiable uses
- // grad_fn pointer equality for error checking.
_trace_post_record(trace_info, op_obj, unpacked.input_vars, outputs, is_inplace);
if (is_executable) {
_save_variables(grad_fn);
@@ -715,10 +709,6 @@
// Record input nodes if tracing
auto trace_info = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
- if (trace_info.state) {
- // TODO: ezyang suggests this is unused and can be removed
- ctx->is_traced = true;
- }
// Initialize backward function (and ctx)
bool is_executable = input_info.is_executable;
@@ -1009,7 +999,6 @@
{"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr},
{"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr},
{"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
- {"_is_tracing", &getMember<char, &THPFunction::is_traced, PyBool_FromLong>, nullptr, nullptr, nullptr},
{"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
{nullptr}
};
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 7bc7548..bdbca10 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -90,7 +90,6 @@
// For each input, true if the input is a THPVariable
std::vector<bool> is_variable_input;
char has_freed_buffers;
- char is_traced;
// The C++ wrapper for this Python function.
// See a comment in THPFunction_asFunction for details about this field.
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index a93f4e6..607f3b7 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -16,7 +16,6 @@
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/utils/python_error_messages.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
-#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/tensor/python_tensor.h"
#include "torch/csrc/utils/auto_gil.h"
#include "torch/csrc/utils/cuda_lazy_init.h"
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp
index 889f456..0f93c13 100644
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -3,7 +3,6 @@
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/tracer_state.h"
#include <ATen/Tensor.h>
@@ -29,10 +28,6 @@
}
version_counter_ = variable.version_counter();
saved_version_ = version_counter_.current_version();
- if (variable.has_tracing_state()) {
- tracing_state_.reset(
- new jit::tracer::ValueTracingState(variable.tracing_state()));
- }
}
}
@@ -78,9 +73,6 @@
if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
var.set_grad_accumulator(grad_accumulator_);
- if (tracing_state_) {
- var.set_tracing_state(new jit::tracer::ValueTracingState(*tracing_state_));
- }
return var;
}
diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h
index 96dc8bd..ff5a36b 100644
--- a/torch/csrc/autograd/saved_variable.h
+++ b/torch/csrc/autograd/saved_variable.h
@@ -2,7 +2,6 @@
#include "torch/csrc/WindowsTorchApiMacro.h"
#include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/jit/tracer_state.h"
#include <ATen/ATen.h>
@@ -44,7 +43,6 @@
// passed in to the unpack function when reconstructing the Variable.
std::shared_ptr<Function> grad_fn_;
std::weak_ptr<Function> grad_accumulator_;
- std::unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
VariableVersion version_counter_;
uint32_t saved_version_;
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 16e8105..7654c4e 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -9,8 +9,6 @@
#include "torch/csrc/autograd/generated/Functions.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
#include <ATen/ATen.h>
@@ -141,7 +139,6 @@
grad_.reset();
grad_fn_.reset();
hooks_.clear();
- tracing_state_.reset();
}
Variable::ViewImpl::ViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
@@ -205,13 +202,4 @@
}
}
-void Variable::set_tracing_state(
- jit::tracer::ValueTracingState* new_tracing_state) {
- get()->tracing_state_.reset(new_tracing_state);
-}
-
-jit::tracer::ValueTracingState& Variable::tracing_state() const noexcept {
- return *get()->tracing_state_;
-}
-
}} // namespace torch::autograd
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 260eb93..2def489 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -7,7 +7,6 @@
#include "torch/csrc/autograd/edge.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable_version.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
#include <ATen/ATen.h>
#include <ATen/Error.h>
@@ -20,20 +19,10 @@
#include <utility>
#include <vector>
-namespace torch {
-namespace autograd {
-struct Function;
-} // namespace autograd
-namespace jit { namespace tracer {
-// Has to be forward declared because tracer_state.h has a dependency on
-// variable.h.
-struct ValueTracingStateElem;
-using ValueTracingState = std::list<ValueTracingStateElem>;
-}} // namespace jit::tracer
-} // namespace torch
-
namespace torch { namespace autograd {
+struct Function;
+
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -228,15 +217,6 @@
const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
void clear_hooks();
- // JIT Tracing
- //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
- void set_tracing_state(jit::tracer::ValueTracingState* new_tracing_state);
- jit::tracer::ValueTracingState& tracing_state() const noexcept;
-
- /// Returns true if the `Variable`'s tracing state is not null.
- bool has_tracing_state() const noexcept;
-
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -379,9 +359,6 @@
// state are still thread-safe. Used by get_grad_fn and
// get_grad_accumulator.
std::mutex mutex_;
-
- // For use in torch::jit::tracer
- auto_unique_ptr<jit::tracer::ValueTracingState> tracing_state_;
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -564,13 +541,6 @@
get()->hooks_.clear();
}
-// JIT Tracing
-//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-inline bool Variable::has_tracing_state() const noexcept {
- return get()->tracing_state_ != nullptr;
-}
-
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 5ef60d9..2b0e0b4 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -226,7 +226,7 @@
// there is no need to optimize, but we do need to splice the graph of
     // this executor into the trace. Otherwise we might unroll control-flow
// operations.
- if(isTracing(inputs)) {
+ if(tracer::isTracing()) {
return runTraced(std::move(inputs));
}
@@ -274,26 +274,11 @@
private:
friend struct GraphExecutor;
- // TODO: switching tracing to be part of the local thread state, instead of
- // a per-variable property will make this check significantly faster.
- // It is along the fast path, so this is important.
- static bool isTracing(const variable_tensor_list& inputs) {
- for(auto & i : inputs) {
- if(i.defined() && tracer::isTracingVar(autograd::as_variable_ref(i)))
- return true;
- }
- return false;
- }
variable_tensor_list runTraced(variable_tensor_list inputs) {
- // TODO: unnecessary copy to variable_list
- variable_list input_vars(inputs.begin(), inputs.end());
- auto state = tracer::getTracingState(input_vars);
- auto input_values = fmap(input_vars, [&](const Variable& v) {
- return tracer::getValueTrace(state, v);
- });
+ auto state = tracer::getTracingState();
+ auto input_values = fmap(inputs, tracer::getValueTrace);
ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs);
- input_vars.clear(); // don't hold inputs during execution
auto outputs = runFallback(std::move(inputs));
auto all_dynamic = [](const at::ArrayRef<Value*> xs) {
@@ -316,7 +301,7 @@
auto output_values = script::inlineCallTo(*state->graph, *local_graph, input_values);
for(size_t i = 0; i < outputs.size(); ++i) {
- tracer::setValueTrace(state, outputs[i], output_values[i]);
+ tracer::setValueTrace(outputs[i], output_values[i]);
}
return outputs;
}
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index a42915b..2b55502 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -7,7 +7,6 @@
#include "torch/csrc/jit/resource_guard.h"
#include "torch/csrc/jit/source_location.h"
#include "torch/csrc/jit/type.h"
-#include "torch/csrc/jit/variable_flags.h"
#include "torch/csrc/utils/disallow_copy.h"
#include "torch/csrc/utils/functional.h"
diff --git a/torch/csrc/jit/passes/onnx.h b/torch/csrc/jit/passes/onnx.h
index bd6f6e4..a58d421 100644
--- a/torch/csrc/jit/passes/onnx.h
+++ b/torch/csrc/jit/passes/onnx.h
@@ -1,7 +1,6 @@
#pragma once
#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/onnx/onnx.h"
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 534297a..f03edfe 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -444,7 +444,7 @@
return t.expect<TensorType>()->strides();
})
.def("contiguous",[](Type& t) {
- return t.expect<TensorType>()->contiguous();
+ return std::static_pointer_cast<Type>(t.expect<TensorType>()->contiguous());
})
.def("scalarType",[](Type& t) {
return at::toString(t.expect<TensorType>()->scalarType());
@@ -471,8 +471,5 @@
}
return std::make_tuple(graph, variables);
});
- m.def("_jit_is_tracing", [](const autograd::Variable& var) {
- return tracer::isTracing(var);
- });
}
}}
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 2ad7a79..6397877 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -46,21 +46,26 @@
tracer::variable_list trace_inputs,
size_t num_func_inputs) {
auto enter_info = tracer::enter(std::move(trace_inputs));
- py::tuple py_inputs(num_func_inputs);
- for(size_t i = 0; i < num_func_inputs; ++i) {
- py_inputs[i] = py::cast(enter_info.second[i]);
+ try {
+ py::tuple py_inputs(num_func_inputs);
+ for(size_t i = 0; i < num_func_inputs; ++i) {
+ py_inputs[i] = py::cast(enter_info.second[i]);
+ }
+ auto out = func(*py_inputs);
+ std::vector<autograd::Variable> outputs;
+ if(PyTuple_Check(out.ptr())) {
+ outputs = py::cast<std::vector<autograd::Variable>>(out);
+ } else {
+ outputs.push_back(py::cast<autograd::Variable>(out));
+ }
+ tracer::exit(outputs);
+ auto graph = enter_info.first->graph;
+ EliminateDeadCode(graph);
+ return graph;
+ } catch (...) {
+ tracer::abandon();
+ throw;
}
- auto out = func(*py_inputs);
- std::vector<autograd::Variable> outputs;
- if(PyTuple_Check(out.ptr())) {
- outputs = py::cast<std::vector<autograd::Variable>>(out);
- } else {
- outputs.push_back(py::cast<autograd::Variable>(out));
- }
- tracer::exit(outputs);
- auto graph = enter_info.first->graph;
- EliminateDeadCode(graph);
- return graph;
}
PreTraceInfo preRecordPythonTrace(THPObjectPtr pyobj,
@@ -119,17 +124,17 @@
m.def("_tracer_exit", [](variable_list var_outputs) {
tracer::exit(var_outputs);
});
- m.def("_get_tracing_state", [](const variable_list& vars) {
- return getTracingState(vars);
+ m.def("_tracer_abandon", []() {
+ tracer::abandon();
});
- m.def("_get_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var) {
- return getValueTrace(state, var);
+ m.def("_get_tracing_state", []() {
+ return getTracingState();
});
- m.def("_set_value_trace", [](std::shared_ptr<TracingState>& state, const Variable& var, Value* value) {
- return setValueTrace(state, var, value);
+ m.def("_get_value_trace", [](const Variable& var) {
+ return getValueTrace(var);
});
- m.def("_is_tracing", [](const variable_list& vars) {
- return isTracingVar(vars);
+ m.def("_set_value_trace", [](const Variable& var, Value* value) {
+ return setValueTrace(var, value);
});
}
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 0fda835..a86059f 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -13,6 +13,28 @@
namespace torch { namespace jit { namespace tracer {
+////////////////////////////////////////////////////////////////////////////////
+// Recording the traces
+////////////////////////////////////////////////////////////////////////////////
+namespace detail {
+
+thread_local std::shared_ptr<TracingState> tracing_state;
+
+} // namespace detail
+
+const std::shared_ptr<TracingState>& getTracingState() {
+ return detail::tracing_state;
+}
+
+void setTracingState(std::shared_ptr<TracingState> state) {
+ detail::tracing_state = std::move(state);
+}
+
+TracingState::TracingState()
+ : graph(new Graph()) {}
+
+TracingState::~TracingState() = default;
+
PreTraceInfo preRecordTrace(Symbol op,
at::ArrayRef<Variable> inputs) {
return makePreTraceInfo(inputs, [&op](const std::shared_ptr<TracingState>& state, Graph& graph) {
@@ -22,14 +44,10 @@
void postRecordTrace(const PreTraceInfo& info,
at::ArrayRef<Variable> outputs) {
- // TODO: Technically, we could reduce the scope of the lock, but since we
- // haven't actually specified what the locking contract is, be conservative.
- auto state_lock = info.state->lock();
-
auto assignOutput = [&info](const Variable & output, Value * value) {
if (output.defined()) {
value->inferTypeFrom(output.data());
- setValueTrace(info.state, output, value);
+ setValueTrace(output, value);
}
};
@@ -38,35 +56,39 @@
}
}
-thread_local ArgumentStash ArgumentStash::stash;
-
-void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
- // TODO: check type?
- if (!isTracing(var)) return;
- auto tracing_state = getTracingState({var});
- auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
- JIT_ASSERT(size == list_trace.size());
- JIT_ASSERT(idx < list_trace.size());
- JIT_ASSERT(list_trace[idx] == nullptr);
- list_trace[idx] = getValueTrace(tracing_state, var);
-}
-
autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
- auto tracing_state = getTracingState({var});
+ auto & tracing_state = getTracingState();
auto & graph = tracing_state->graph;
auto size_var = autograd::make_variable(at::Scalar(var.size(dim)).toTensor());
- auto* value = getValueTrace(tracing_state, var);
+ auto* value = getValueTrace(var);
auto* node = graph->create(aten::size, {value})
->i_(attr::dim, dim);
node->output()->inferTypeFrom(size_var);
graph->appendNode(node);
- setValueTrace(tracing_state, size_var, node->output());
+ setValueTrace(size_var, node->output());
return size_var;
}
+////////////////////////////////////////////////////////////////////////////////
+// Argument stash
+////////////////////////////////////////////////////////////////////////////////
+thread_local ArgumentStash ArgumentStash::stash;
+void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
+ // TODO: check type?
+ if (!isTracing()) return;
+ auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
+ JIT_ASSERT(size == list_trace.size());
+ JIT_ASSERT(idx < list_trace.size());
+ JIT_ASSERT(list_trace[idx] == nullptr);
+ list_trace[idx] = getValueTrace(var);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Stack trace recording
+////////////////////////////////////////////////////////////////////////////////
// no python present so we just do not record source information
void defaultRecordSourceLocation(Node* n) {}
std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(defaultRecordSourceLocation);
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 5775091..2b8f32e 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -1,13 +1,12 @@
#pragma once
#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/tracer_state.h"
#include "torch/csrc/assertions.h"
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/utils/variadic.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/utils/auto_unique_ptr.h"
+
#include <memory>
#include <mutex>
#include <vector>
@@ -20,38 +19,27 @@
using torch::autograd::Variable;
using variable_list = std::vector<Variable>;
-namespace detail {
+struct TracingState : public std::enable_shared_from_this<TracingState> {
+ TracingState();
+ ~TracingState();
-inline ValueTracingStateElem* getValueState(const std::shared_ptr<TracingState>& state, const Variable& var, bool alloc = true) {
- auto& tracing_state = var.tracing_state();
- for (auto it = tracing_state.begin(); it != tracing_state.end();) {
- auto ts = it->state.lock();
- // GC of invalidated tracing states
- if (!ts) {
- auto current_it = it++;
- tracing_state.erase(current_it);
- continue;
- } else if (ts == state) {
- return &(*it);
+ using WeakTensor = at::WeakTensor;
+
+ struct WeakTensorHasher {
+ size_t operator()(const WeakTensor& t) const {
+ return std::hash<void*>()(t.unsafeGetTensorImpl());
}
- ++it;
- }
- if (alloc) {
- tracing_state.emplace_front();
- auto & vts = tracing_state.front();
- vts.state = state;
- return &vts;
- } else {
- return nullptr;
- }
-}
+ };
-inline bool isElemActive(const ValueTracingStateElem& vts) {
- auto state = vts.state.lock();
- return state && state->active;
-}
+ struct WeakTensorEq {
+ bool operator()(const WeakTensor& t1, const WeakTensor& t2) const {
+ return t1.unsafeGetTensorImpl() == t2.unsafeGetTensorImpl();
+ }
+ };
-} // namespace detail
+ std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq> value_map;
+ std::shared_ptr<Graph> graph;
+};
// This is meant to be used as a thread local place, where we can store extra
@@ -91,76 +79,27 @@
std::unordered_map<std::string, IntListTrace> intlists;
};
-// Should a function which takes 'vars' as inputs be traced?
-// It suffices for ONE variable to be tracing: any "untraced" variables
-// are treated as constants.
-//
-// NB: This code lives in the hotpath; make sure it is fast
-//
-// NB: Variable overload is not variadic because we don't actually
-// need it (in most cases if we have a variable_list it is already
-// flattened).
-inline bool isTracingVar(const Variable& var) {
- if (!var.defined() || !var.has_tracing_state()) return false;
- return std::any_of(var.tracing_state().begin(), var.tracing_state().end(), detail::isElemActive);
+// Retrieve or set the current tracing state. Returns a nullptr if tracing is disabled.
+const std::shared_ptr<TracingState>& getTracingState();
+void setTracingState(std::shared_ptr<TracingState> state);
+
+inline bool isTracing() {
+ return static_cast<bool>(getTracingState());
}
-inline bool isTracingVar(at::ArrayRef<Variable> vars) {
- // Reference to avoid refcount bump
- for (const Variable& var : vars) {
- if (isTracingVar(var)) return true;
- }
- return false;
-}
-
-struct IsTracing : IterArgs<IsTracing> {
- bool out = false;
- using IterArgs<IsTracing>::operator();
- void operator()(const at::Tensor& var) {
- out = out || isTracingVar(var);
- }
- bool short_circuit() { return out; }
-};
-
-// To be called with Tensor arguments from generated code
-template<typename... Args>
-inline bool isTracing(Args&&... args) {
- return IsTracing().apply(std::forward<Args>(args)...).out;
-}
-
-// Retrieve the tracing state which a function applied with 'vars' should
-// be recorded to. Precondition: isTracing(vars) == true. At the moment,
-// we don't support mixing up variables from different traces; this code
-// will need to be revisited if that ever becomes supported.
-inline std::shared_ptr<TracingState> getTracingState(const variable_list& vars) {
- std::shared_ptr<TracingState> state;
- for (auto& var : vars) {
- if (!var.defined() || !var.has_tracing_state()) continue;
- for (auto & vts : var.tracing_state()) {
- auto var_state = vts.state.lock();
- if (!var_state || !var_state->active) continue;
- if (!state) state = var_state;
- JIT_ASSERT(var_state == state);
- }
- }
- JIT_ASSERT(state);
- return state;
-}
-
-// Having finished adding a new 'node' to the graph IR owned by TracingState 'state',
-// 'setValueTrace' associates this node with an output variable, so that further operations
-// involving this variable know which node in the IR to reference.
-inline void setValueTrace(const std::shared_ptr<TracingState>& state, const Variable& var, Value *value) {
+// Having finished adding a new 'node' to the graph IR, 'setValueTrace' associates
+// this node with an output variable, so that further operations involving this
+// variable know which node in the IR to reference.
+inline void setValueTrace(const Variable& var, Value *value) {
JIT_ASSERT(var.defined());
- auto vts = detail::getValueState(state, var);
- vts->trace = value;
+ getTracingState()->value_map[var] = value;
}
// Given a variable 'var', return the 'node' which represents the instruction
-// which computes the value of this variable in the IR. When 'mustExist' is
-// false, we interpret untraced variables as constants that are just embedded
+// which computes the value of this variable in the IR.
+// Here, we interpret untraced variables as constants that are just embedded
// in the graph. This is useful to handle code which does things like this
-// (from torch.autograd.variable):
+// (from torch.autograd.variable, now moved to C++):
//
// def mm(self, matrix):
// output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
@@ -170,19 +109,21 @@
// update on, but subsequently ignores it because the alpha scaling factor is zero.
// This is one of the cases where a Variable can be created inside of a trace, and
// if we treat it as a constant, everything will work out.
-inline Value* getValueTrace(const std::shared_ptr<TracingState>& state, const Variable& var) {
+inline Value* getValueTrace(const Variable& var) {
+ auto &state = getTracingState();
if (!var.defined()) {
Node *n = state->graph->createUndefined();
return state->graph->appendNode(n)->output();
}
- auto vts = detail::getValueState(state, var, true);
- if (vts->trace) return vts->trace;
-
- Value *constant = state->graph->appendNode(state->graph->createConstant(var.data()))->output();
- constant->inferTypeFrom(var.data());
- setValueTrace(state, var, constant);
- return constant;
+ auto & value_map = getTracingState()->value_map;
+ auto it = value_map.find(var);
+ if (it == value_map.end()) {
+ Value *constant = state->graph->appendNode(state->graph->createConstant(var.data()))->output();
+ constant->inferTypeFrom(var.data());
+ it = value_map.emplace_hint(it, var, constant);
+ }
+ return it->second;
}
inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
@@ -191,36 +132,37 @@
return state->graph->appendNode(n)->output();
}
- auto vts = detail::getValueState(state, var, false);
- if (!vts) {
+ auto & value_map = getTracingState()->value_map;
+ auto it = value_map.find(var);
+ if (it == value_map.end()) {
std::ostringstream os;
os << "output " << output_no << " of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
- return vts->trace;
+ return it->second;
}
// Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables
// will be treated as constants.
-//
-// NB: Why does this take an rvalue reference? We need to get a non-const
-// reference to at::Tensor buffer to call unsafeGetTH, but you can't get this
-// out of a const vector (silly std::vector...)
inline std::pair<std::shared_ptr<TracingState>, variable_list> enter(
variable_list inputs) {
+ if (isTracing()) {
+ AT_ERROR("Tracing can't be nested");
+ }
auto state = std::make_shared<TracingState>();
+ setTracingState(state);
for (auto& input : inputs) {
- auto * value_state = detail::getValueState(state, input, false);
+ auto * value_state = state->value_map[input];
if (value_state) {
// See Note [Repeated inputs] in tracer.cpp
input = input.view(input.sizes());
}
auto input_node = state->graph->addInput(input.name());
- setValueTrace(state, input, input_node);
input_node->inferTypeFrom(input.data());
+ state->value_map[input] = input_node;
}
return std::make_pair(state, inputs);
}
@@ -229,27 +171,29 @@
// are the variables whose values will be computed upon subsequent
// invocations of the trace.
inline void exit(const variable_list& outputs) {
- auto state = getTracingState(outputs);
+ auto & state = getTracingState();
size_t i = 0;
for (auto& output : outputs) {
state->graph->registerOutput(getOutputTrace(state, output, i));
i++;
}
- state->active = false;
+ setTracingState(nullptr);
+}
+
+// Abort tracing. Used to reset the state in case of errors.
+inline void abandon() {
+ setTracingState(nullptr);
}
// Pre-recorded information about the trace before we actually carry
// out the trace
struct PreTraceInfo {
- std::shared_ptr<TracingState> state;
Node *n;
};
PreTraceInfo preRecordTrace(Symbol op, at::ArrayRef<Variable> inputs);
void postRecordTrace(const PreTraceInfo& info, at::ArrayRef<Variable> outputs);
-autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
-
void recordSourceLocation(Node* n);
void setRecordSourceLocation(void (*v)(Node*));
@@ -259,15 +203,14 @@
template<typename F>
PreTraceInfo makePreTraceInfo(at::ArrayRef<Variable> inputs, F ctor) {
PreTraceInfo info;
- info.state = getTracingState(inputs);
- auto& graph = info.state->graph;
- auto state_lock = info.state->lock();
+ auto & state = getTracingState();
+ auto & graph = state->graph;
- Node *n = ctor(info.state, *graph);
+ Node *n = ctor(state, *graph);
recordSourceLocation(n);
- for (Variable input : inputs) {
- n->addInput(getValueTrace(info.state, input));
+ for (const Variable & input : inputs) {
+ n->addInput(getValueTrace(input));
}
// NB: Order matters. This must append after inputs but before outputs.
@@ -278,4 +221,6 @@
return info;
}
+autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
+
}}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/tracer_state.cpp b/torch/csrc/jit/tracer_state.cpp
deleted file mode 100644
index 6f44562..0000000
--- a/torch/csrc/jit/tracer_state.cpp
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "torch/csrc/jit/tracer_state.h"
-#include "torch/csrc/jit/ir.h"
-
-namespace torch { namespace jit { namespace tracer {
-
-TracingState::TracingState()
- : graph(new Graph())
- , active(true) {}
-
-TracingState::~TracingState() = default;
-
-}}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
deleted file mode 100644
index 887ad94..0000000
--- a/torch/csrc/jit/tracer_state.h
+++ /dev/null
@@ -1,59 +0,0 @@
-#pragma once
-
-#include "torch/csrc/autograd/edge.h"
-#include "torch/csrc/autograd/variable.h"
-
-#include <atomic>
-#include <cstdint>
-#include <list>
-#include <memory>
-#include <mutex>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-namespace torch { namespace jit {
-struct Graph;
-struct Value;
-}} // namespace torch::jit
-
-namespace torch { namespace jit { namespace tracer {
-
-// TracingState tracks the necessary state when we are tracing the execution of
-// autograd code; most importantly, it holds a reference to the actual IR
-// graph which we are recording the trace to.
-//
-// The liveness of a TracingState is expected to be a superset of the region
-// of code being traced; in particular, Variables do not keep a TracingState
-// live. Instead, they hold weak pointers to TracingState, to prevent leaks
-// from arising when a variable that participated in a trace outlives the
-// actual trace itself.
-
-struct TracingState : public std::enable_shared_from_this<TracingState> {
- TracingState();
- ~TracingState();
-
- std::shared_ptr<Graph> graph;
- std::mutex mutex;
- bool active;
-
- std::unique_lock<std::mutex> lock() {
- return std::unique_lock<std::mutex>(mutex);
- }
-};
-
-struct ValueTracingStateElem {
- std::weak_ptr<TracingState> state;
- // it's only valid to use this field if !state.exired()
- Value* trace = nullptr;
-
- void reset() {
- state.reset();
- trace = nullptr;
- }
-};
-
-using ValueTracingState = std::list<ValueTracingStateElem>;
-
-}}} // namespace torch::jit::tracer
diff --git a/torch/csrc/jit/variable_flags.cpp b/torch/csrc/jit/variable_flags.cpp
deleted file mode 100644
index 8ab565d..0000000
--- a/torch/csrc/jit/variable_flags.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "torch/csrc/jit/variable_flags.h"
-
-#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/tracer_state.h"
-
-using torch::autograd::Variable;
-
-namespace torch { namespace jit {
-
-// These definitions require Variable struct to be defined, so they can't be
-// in tracer_state.h
-VariableFlags VariableFlags::of(const Variable& var) {
- VariableFlags f;
- f.defined = var.defined();
- f.requires_grad = f.defined && var.requires_grad();
- return f;
-}
-
-}}
diff --git a/torch/csrc/jit/variable_flags.h b/torch/csrc/jit/variable_flags.h
deleted file mode 100644
index 43c3ef9..0000000
--- a/torch/csrc/jit/variable_flags.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-#include <iostream>
-namespace torch { namespace autograd {
-struct Variable;
-}}
-
-namespace torch { namespace jit {
-
-struct VariableFlags {
- static VariableFlags of(const autograd::Variable& var);
-
- bool requires_grad;
- bool defined;
-};
-
-static inline std::ostream & operator<<(std::ostream & out, const VariableFlags& v) {
- return out
- << "(requires_grad=" << v.requires_grad
- << ", defined=" << v.defined << ")";
-}
-
-}}
diff --git a/torch/csrc/utils/auto_unique_ptr.h b/torch/csrc/utils/auto_unique_ptr.h
deleted file mode 100644
index d49a036..0000000
--- a/torch/csrc/utils/auto_unique_ptr.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#pragma once
-
-#include <memory>
-
-namespace torch {
-
-// A unique_ptr that automatically constructs the object on first dereference.
-template<typename T>
-struct auto_unique_ptr : public std::unique_ptr<T> {
- T& operator*() {
- if (!this->get()) this->reset(new T());
- return *this->get();
- }
-
- T* operator->() {
- if (!this->get()) this->reset(new T());
- return this->get();
- }
-};
-
-} // namespace torch
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index fbf3fab..4b60541 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -24,21 +24,10 @@
_jit_script_compile = torch._C._jit_script_compile
BatchTensor = torch._C._jit.BatchTensor
-# This global variable is set when we are tracing a *forwards* computation.
-# It is intended to be a cheap way to test if tracing has occurred, before
-# doing the slower path using `get_tracing_state` (below.)
-_tracing = False
-
-
-def get_tracing_state(args):
- if not torch._C._is_tracing(args):
- return None
- return torch._C._get_tracing_state(args)
-
@contextlib.contextmanager
-def scope(scope_name, *vars):
- tracing_state = get_tracing_state(vars)
+def scope(scope_name):
+ tracing_state = torch._C._get_tracing_state()
if tracing_state:
tracing_state.push_scope(scope_name)
try:
@@ -98,18 +87,19 @@
self.inner = inner
def forward(self, *args):
- global _tracing
in_vars, in_desc = _flatten(args)
# NOTE: use full state, because we need it for BatchNorm export
# This differs from the compiler path, which doesn't support it at the moment.
module_state = list(_unique_state_dict(self, keep_vars=True).values())
trace, all_trace_inputs = torch._C._tracer_enter(in_vars + module_state)
- _tracing = True
- trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
- out = self.inner(*trace_inputs)
- out_vars, _ = _flatten(out)
- _tracing = False
- torch._C._tracer_exit(out_vars)
+ try:
+ trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
+ out = self.inner(*trace_inputs)
+ out_vars, _ = _flatten(out)
+ torch._C._tracer_exit(out_vars)
+ except Exception:
+ torch._C._tracer_abandon()
+ raise
return trace, out
@@ -289,13 +279,7 @@
if len(kwargs) != 0:
raise TypeError("got unexpected keyword arguments: {}".format(", ".join(kwargs.keys())))
- if isinstance(func, torch.nn.Module):
- orig = func
- else:
- # traced functions become a method on an Empty module
- orig = Module()
-
- module = TopLevelTracedModule(orig, **executor_options)
+ module = TopLevelTracedModule(func, **executor_options)
module._create_method_from_trace('forward', func, args)
return module
@@ -683,10 +667,17 @@
__frozen = False
def __init__(self, orig, id_set=None, optimize=True):
+ # XXX: orig can be a nn.Module or a function!
super(TracedModule, self).__init__(optimize=optimize)
if id_set is None:
id_set = set()
+ if not isinstance(orig, torch.nn.Module):
+ self._name = orig.__name__
+ orig = torch.nn.Module()
+ else:
+ self._name = 'TracedModule[' + type(orig).__name__ + ']'
+
def check_unique(param):
if param in id_set:
raise ValueError("TracedModules don't support parameter sharing between modules")
@@ -702,7 +693,6 @@
if buf is not None:
self._buffers[name] = buf
check_unique(buf)
- self._orig_class = type(orig)
if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
raise ValueError("Modules that have hooks assigned can't be compiled")
@@ -719,7 +709,7 @@
self.__frozen = True
def _get_name(self):
- return 'TracedModule[' + self._orig_class.__name__ + ']'
+ return self._name
def __setattr__(self, attr, value):
if not self.__frozen or hasattr(self, attr):
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index c7f5d10..1cccb77 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -310,7 +310,7 @@
# function gets reconstructed each and every time when RNN() is invoked
# and we don't want to pay the cost of decorator invocation
import torch
- if torch._C._jit_is_tracing(input):
+ if torch._C._get_tracing_state():
import torch.onnx.symbolic
sym = torch.onnx.symbolic.RNN_symbolic_builder(*args, **kwargs)
cell_type = args[0]
@@ -318,7 +318,7 @@
bound_symbolic = partial(torch.onnx.symbolic.rnn_trace_override_symbolic,
cell_type, func, sym)
- decorator = torch.onnx.symbolic_override_first_arg_based(bound_symbolic)
+ decorator = torch.onnx.symbolic_override(bound_symbolic)
func = decorator(func)
return func(input, *fargs, **fkwargs)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index f1c4f75..17a7c09 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1274,7 +1274,7 @@
import torch.onnx.symbolic
- @torch.onnx.symbolic_override_first_arg_based(torch.onnx.symbolic.instance_norm)
+ @torch.onnx.symbolic_override(torch.onnx.symbolic.instance_norm)
def _instance_norm(input, running_mean=None, running_var=None, weight=None,
bias=None, use_input_stats=None, momentum=None, eps=None):
# Repeat stored stats and affine transform params if necessary
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 91bab5c..a00ff3d 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -450,16 +450,16 @@
def _slow_forward(self, *input, **kwargs):
input_vars = tuple(torch.autograd.function._iter_tensors(input))
- tracing_state = torch.jit.get_tracing_state(input_vars)
+ tracing_state = torch._C._get_tracing_state()
if not tracing_state:
return self.forward(*input, **kwargs)
if not hasattr(tracing_state, '_traced_module_stack'):
tracing_state._traced_module_stack = []
name = self._tracing_name(tracing_state)
if name:
- tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name))
+ tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
else:
- tracing_state.push_scope(self.__class__.__name__)
+ tracing_state.push_scope(self._get_name())
tracing_state._traced_module_stack.append(self)
try:
result = self.forward(*input, **kwargs)
@@ -471,7 +471,7 @@
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
- if torch.jit._tracing:
+ if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index 4c1eee0..d91797f 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -168,8 +168,7 @@
return tuple(o for o in outputs)
-pack_padded_sequence = torch.onnx.symbolic_override_first_arg_based(
- _symbolic_pack_padded_sequence)(pack_padded_sequence)
+pack_padded_sequence = torch.onnx.symbolic_override(_symbolic_pack_padded_sequence)(pack_padded_sequence)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
@@ -264,8 +263,7 @@
return data, lengths
-pad_packed_sequence = torch.onnx.symbolic_override_packed_sequence_based(
- _symbolic_pad_packed_sequence)(pad_packed_sequence)
+pad_packed_sequence = torch.onnx.symbolic_override(_symbolic_pad_packed_sequence)(pad_packed_sequence)
def pad_sequence(sequences, batch_first=False, padding_value=0):
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 1807b71..0514343 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -56,48 +56,6 @@
return utils._run_symbolic_method(*args, **kwargs)
-def _symbolic_override_wrapper_maker(symbolic_fn, might_trace, fn):
-
- def wrapper(*args, **kwargs):
- import torch
- import torch.jit
- from torch.autograd import Function, function
-
- # fast pass
- if not might_trace(args):
- return fn(*args, **kwargs)
-
- flat_args = tuple(function._iter_tensors_permissive(args))
- flat_args_only_tensors = tuple(t for t in flat_args if isinstance(t, torch.Tensor))
- if not any(map(torch._C._jit_is_tracing, flat_args_only_tensors)):
- return fn(*args, **kwargs)
-
- tstate = torch._C._get_tracing_state(flat_args_only_tensors)
-
- arg_values = [torch._C._get_value_trace(tstate, x) if isinstance(x, torch.Tensor) else x for x in flat_args]
-
- # This must come after the calls to get_value_trace, lest we
- # lose information due to in-place operations.
- output_vars = fn(*args, **kwargs)
-
- symbolic_args = function._unflatten(arg_values, args)
- output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)
-
- for var, val in zip(
- function._iter_tensors(output_vars),
- function._iter_jit_values(output_vals)):
- val.inferTypeFrom(var.data)
- torch._C._set_value_trace(tstate, var, val)
-
- return output_vars
-
- # fn might be autograd.Function too, in this case wrapping doesn't work
- if isinstance(fn, types.FunctionType):
- wrapper = functools.wraps(fn)(wrapper)
-
- return wrapper
-
-
def symbolic_override(symbolic_fn):
r"""
Decorator to override ONNX export of a function with specified subgraph.
@@ -123,47 +81,36 @@
return x + y[0] + y[1]
```
"""
-
- return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, lambda x: True)
-
-
-def symbolic_override_first_arg_based(symbolic_fn):
- r"""
- Decorator to override ONNX export of the a function with specified subgraph.
-
- Equivalent to :func:`symbolic_override` but checks only the first argument
- of the function to figure out whether the tracing is on. Thus the first arg
- needs to be a Tensor.
- """
-
- def might_trace(args):
+ def decorator(fn):
import torch
- first_arg = args[0]
- if not isinstance(first_arg, torch.Tensor):
- raise ValueError('First argument of {} is expected to be a tensor, '
- 'but got an object of type {}'
- .format(symbolic_fn.__name__, type(first_arg)))
- return torch._C._jit_is_tracing(first_arg)
+ from torch.autograd import function
- return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, might_trace)
+ def wrapper(*args, **kwargs):
+ tstate = torch._C._get_tracing_state()
+ if not tstate:
+ return fn(*args, **kwargs)
+ flat_args = tuple(function._iter_tensors_permissive(args))
+ arg_values = [torch._C._get_value_trace(x) if isinstance(x, torch.Tensor) else x for x in flat_args]
-def symbolic_override_packed_sequence_based(symbolic_fn):
- r"""
- Decorator to override ONNX export of the a function with specified subgraph.
+ # This must come after the calls to get_value_trace, lest we
+ # lose information due to in-place operations.
+ output_vars = fn(*args, **kwargs)
- Equivalent to :func:`symbolic_override` but checks only the first argument
- of the function to figure out whether the tracing is on. Thus the first arg
- needs to be a Tensor.
- """
+ symbolic_args = function._unflatten(arg_values, args)
+ output_vals = symbolic_fn(tstate.graph(), *symbolic_args, **kwargs)
- def might_trace(args):
- import torch
- first_arg = args[0]
- if not isinstance(first_arg, torch.nn.utils.rnn.PackedSequence):
- raise ValueError('pad_packed_sequence expects sequence to be a '
- 'PackedSequence, but got an object of type {}'
- .format(type(first_arg)))
- return torch._C._jit_is_tracing(first_arg[0])
+ for var, val in zip(
+ function._iter_tensors(output_vars),
+ function._iter_jit_values(output_vals)):
+ val.inferTypeFrom(var.data)
+ torch._C._set_value_trace(var, val)
- return functools.partial(_symbolic_override_wrapper_maker, symbolic_fn, might_trace)
+ return output_vars
+
+ # fn might be autograd.Function too, in this case wrapping doesn't work
+ if isinstance(fn, types.FunctionType):
+ wrapper = functools.wraps(fn)(wrapper)
+
+ return wrapper
+ return decorator
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index f547750..a88739c 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -745,6 +745,15 @@
globals()[name] = partial(_cast_func_template, v)
+def zeros_like(g, input):
+ return g.op("Sub", input, input).setType(input.type().contiguous())
+
+
+def full_like(g, input, fill_value):
+ # TODO: a more efficient implementation (ConstantFill?)
+ return add(g, zeros_like(g, input), fill_value, alpha=torch.tensor(1))
+
+
def slice(g, self, dim, start, end, step):
if step != 1:
_unimplemented("slice", "step!=1 is currently not supported")