Add support for Constant nodes in AutogradClosureFactory
diff --git a/test/test_jit.py b/test/test_jit.py
index ae828e2..2e42591 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -161,7 +161,7 @@
trace = torch._C._tracer_enter((x, y))
z = torch.sigmoid(x * (x + y))
- w = torch.abs(x * x * x + y)
+ w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
torch._C._tracer_exit((z, w))
torch._C._jit_pass_lint(trace)
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index c9993bd..54aabba 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -127,6 +127,9 @@
// operation in some way that defines the graph structure AND the backward function
// is traceable. In particular, parametrization MUST NOT depend on the data
// of any Variable.
+ // TODO: it might be possible to handle cases where backward is non-traceable
+ // but state passing could be considered transparent. This will probably depend
+ // on saved_variable_list being mutable.
// NOTE: this value matters only if is_traceable() returns false.
virtual inline bool passes_state_transparently() { return false; };
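As a concrete, purely illustrative reading of the contract documented above: a Function whose behaviour is a fixed permutation of its inputs, parametrized only by graph structure and never by Variable data, could opt into tracing roughly as follows. This is a minimal sketch, not part of the patch; the class name is made up, and it assumes the is_traceable() virtual mentioned in the NOTE plus the Function and variable_list declarations from this header.

    #include "torch/csrc/autograd/function.h"

    using torch::autograd::Function;
    using torch::autograd::variable_list;

    // Sketch only: this node is a fixed permutation of its inputs. Its
    // parametrization is purely structural and it never inspects Variable
    // data, so marking it traceable is consistent with the comment above.
    struct ReverseInputs : public Function {
      virtual variable_list apply(const variable_list& inputs) override {
        return variable_list(inputs.rbegin(), inputs.rend());
      }
      virtual inline bool is_traceable() { return true; }
      // passes_state_transparently() is left at its default (false); per the
      // NOTE above it is only consulted when is_traceable() returns false.
    };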
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
index 0431226..ce0bf75 100644
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ b/torch/csrc/autograd/functions/jit_closure.cpp
@@ -40,7 +40,7 @@
// Used for inputs of previous previous stages
struct PrevStageInput : public Replicate {};
-// Used for inputs to the closure and
+// Used for inputs to the closure
struct InputPlaceholder : public Placeholder {};
// Used to mark places that will have to apply Evals from previous stages
struct EvalPlaceholder : public Placeholder {};
@@ -111,6 +111,7 @@
if (scalar_it == scalar_args.end())
throw std::runtime_error("expected too many scalar args");
obj = (scalar_it++)->get();
+ Py_INCREF(obj);
} else if (arg_type == 't') {
if (var_it == inputs.end())
throw std::runtime_error("expected too many inputs");
@@ -118,7 +119,6 @@
} else {
throw std::runtime_error("unexpected calling convention");
}
- Py_INCREF(obj);
PyTuple_SET_ITEM(py_inputs.get(), input_nr++, obj);
}
@@ -281,7 +281,8 @@
struct StageClosure {
using node_fn_map_type = std::unordered_map<Node*, std::shared_ptr<Function>>;
- StageClosure(Graph *graph, const CrossStageStateDesc& xstate, std::size_t stage) {
+ StageClosure(Graph *graph, const CrossStageStateDesc& xstate, std::size_t stage)
+ : const_factory(std::make_shared<ConstantFactory>()) {
node_fn_map_type node_map;
node_fn_map_type prev_stage_input_map;
@@ -349,6 +350,8 @@
captured_variables.emplace_back(fn.get(), 0, captured_node->unique());
}
}
+
+ roots.emplace_back(const_factory, 0);
}
// Returns a function implementing functionality of a given node,
@@ -391,10 +394,10 @@
fn->num_inputs = 1;
return fn;
IR_ELSEIF(Constant)
- throw std::runtime_error("constants not supported");
- //fn = std::make_shared<torch::autograd::WrapConstant>(value->t(kValue));
- //const_factory->next_functions.emplace_back(fn, 0);
- //fn->num_inputs = 1;
+ auto fn = std::make_shared<torch::autograd::WrapConstant>(value->t(kValue));
+ const_factory->next_functions.emplace_back(fn, 0);
+ fn->num_inputs = 1;
+ return fn;
IR_ELSEIF(Chunk)
return std::make_shared<Chunk>(value->i(kNumChunks), value->i(kDim));
IR_ELSE()
@@ -463,12 +466,13 @@
}
}
- // Roots for a call to the engine. The list begins with nodes corresponding to inputs
- // to apply, and PrevStageInput nodes afterwards
+ // Roots for a call to the engine. The list contains functions in this order:
+ // [ apply input roots | prev stage input roots | constant factory ]
function_list roots;
// Output node
std::shared_ptr<Function> output;
+ std::shared_ptr<ConstantFactory> const_factory;
// These will be used by each instantiation of AutogradClosure to register hooks.
std::vector<int> prev_stage_variables; // unique
@@ -568,6 +572,7 @@
});
for (auto unique : desc->stages[stage].prev_stage_variables)
input_leaves.emplace_back(std::make_shared<Variable>(saved_vars.at(unique), true, false));
+ input_leaves.emplace_back(nullptr); // for ConstantFactory
auto& engine = python::PythonEngine::getDefaultEngine();
engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);
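Putting the two jit_closure.cpp hunks above together: the root ordering documented in StageClosure must be mirrored, element for element, by the input list handed to the engine. The fragment below is a hedged restatement of that convention, not code from this patch; closure_inputs and prev_stage_variables stand in for whatever the surrounding AutogradClosure code actually iterates over.

    // Roots:        [ apply input roots | prev stage input roots | constant factory ]
    // Input leaves: [ closure inputs    | saved prev-stage vars  | nullptr          ]
    variable_list input_leaves;
    for (auto& input : closure_inputs)            // one leaf per apply input root
      input_leaves.emplace_back(input);
    for (auto unique : prev_stage_variables)      // one leaf per prev stage input root
      input_leaves.emplace_back(
          std::make_shared<Variable>(saved_vars.at(unique), true, false));
    input_leaves.emplace_back(nullptr);           // the ConstantFactory root takes no real input
    engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);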
diff --git a/torch/csrc/autograd/functions/special.h b/torch/csrc/autograd/functions/special.h
index d72acc9..f623c9a 100644
--- a/torch/csrc/autograd/functions/special.h
+++ b/torch/csrc/autograd/functions/special.h
@@ -15,7 +15,11 @@
EvalOutput(const edge_type& next_edge)
: next_edge(next_edge) {
num_inputs = 1;
- is_executable = next_edge.first->is_executable;
+ // It would be nice if we could inherit this from the function of next_edge,
+ // but we want to always run this node to capture the output. This might
+ // confuse some of the functions, causing them to do unnecessary work.
+ // TODO: it should be possible to improve this once we get rid of NULL Variables
+ is_executable = true;
}
virtual variable_list apply(const variable_list& inputs) override {
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index 1268be4..bba6744 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -114,56 +114,6 @@
}
}
-PyObject *THPEngine_run_forward(THPEngine *self, PyObject *args, PyObject *kwargs)
-{
- HANDLE_TH_ERRORS
- //PyObject *pyclosure = NULL;
- //PyObject *inputs = NULL;
- //const char *accepted_kwargs[] = {"closure", "inputs", NULL};
- //if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", (char**)accepted_kwargs,
- //&pyclosure, &inputs))
- //return NULL;
-
- //THPUtils_assert(THPWrapper_check(pyclosure), "closure should be a PtrWrapper object");
- //THPUtils_assert(PyTuple_Check(inputs), "inputs should be a tuple");
-
- //variable_list var_inputs;
- //auto num_inputs = PyTuple_GET_SIZE(inputs);
- //var_inputs.reserve(1 + num_inputs);
- //var_inputs.emplace_back(nullptr); // For ConstantFactory
- //for (int i = 0; i < num_inputs; ++i) {
- //PyObject *input = PyTuple_GET_ITEM(inputs, i);
- //THPUtils_assert(THPVariable_Check(input), "%d input is not a Variable", i);
- //var_inputs.emplace_back(((THPVariable*)input)->cdata);
- //}
-
- //AutogradClosure *closure = reinterpret_cast<AutogradClosure*>(THPWrapper_get(pyclosure));
-
- //variable_list outputs;
- //Engine::callback_map callbacks;
- //callbacks.emplace(closure->output.get(), [&outputs](Function* _unused, variable_list& inputs) -> bool {
- //outputs = inputs;
- //return false;
- //});
-
- //try {
- //AutoNoGIL no_gil;
- //engine.execute(closure->roots, var_inputs, true, callbacks);
- //} catch (python_error &e) {
- //e.restore();
- //return nullptr;
- //}
-
- //int num_outputs = outputs.size();
- //THPObjectPtr pyoutputs { PyTuple_New(num_outputs) };
- //for (int i = 0; i < num_outputs; ++i) {
- //PyTuple_SET_ITEM(pyoutputs.get(), i, THPVariable_Wrap(outputs[i]));
- //}
- //return pyoutputs.release();
- Py_RETURN_NONE;
- END_HANDLE_TH_ERRORS
-}
-
// Implementation of torch._C._EngineBase.run_backward
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
@@ -300,7 +250,6 @@
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, NULL},
- {(char*)"run_forward", (PyCFunction)THPEngine_run_forward, METH_VARARGS | METH_KEYWORDS, NULL},
{(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, NULL},
{NULL}
};
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 61e1614..4f64588 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -1,37 +1,9 @@
-#include <pybind11/pybind11.h>
-// DO NOT REMOVE, this enables std containers to be recognized
-// with pybind11, removing the include disables the support
-#include <pybind11/stl.h>
-
-namespace py = pybind11;
+#include "torch/csrc/utils/pybind.h"
#include <iostream>
#include <sstream>
#include "torch/csrc/jit/ir.h"
-struct THPGenerator;
-#include "torch/csrc/Module.h"
-#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/jit/python_tracer.h"
-// handle Tensor <-> at::Tensor conversions
-namespace pybind11 { namespace detail {
- template <> struct type_caster<at::Tensor> {
- public:
- PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor"));
-
- bool load(handle src, bool) {
- /* Extract PyObject from handle */
- PyObject *source = src.ptr();
- if(!THPModule_isTensor(source))
- return false;
- value = torch::createTensor(source);
- return true;
- }
- static handle cast(at::Tensor src, return_value_policy /* policy */, handle /* parent */) {
- return handle(torch::createPyObject(src));
- }
- };
-}} // namespace pybind11::detail
-
namespace torch { namespace jit {
void initPythonIRBindings(PyObject * module_) {
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 983c4fc..7b412df 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -128,6 +128,7 @@
if (mustExist) throw std::runtime_error("untraced variable");
Node *constant = state->graph->appendNode(state->graph->createConstant(var->data));
+ constant->inferTypeFrom(var->data);
setValueTrace(state, var, constant);
return constant;
}
diff --git a/torch/jit.py b/torch/jit.py
index acfc34f..296d475 100644
--- a/torch/jit.py
+++ b/torch/jit.py
@@ -128,7 +128,6 @@
def run_trace(self, trace_inputs):
if self.saved_closure is None:
- print(self.saved_trace)
self.saved_closure = torch._C._jit_createAutogradClosure(
self.saved_trace)
with _time("run_trace", self.time):