add @torch.jit.script, @torch.jit.trace annotation, torch.jit.CompilationUnit(str) (#5367)
* torch.jit.trace annotation now creates a GraphExecutor
The other torch.jit.trace, which was used for testing purposes and by torch.onnx to get the trace graph, is now called torch.jit.get_trace_graph.
* @script annotation, and a CompilationUnit for compiling script from strings (see the sketch below)
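
For reference, here is a minimal sketch of how the three new entry points fit together, adapted from the tests added in this patch (the function names and tensor shapes are illustrative):

    import torch
    from torch.autograd import Variable

    # torch.jit.trace(*example_inputs) returns a decorator: the function is
    # traced once on the example inputs and wrapped in a GraphExecutor.
    @torch.jit.trace(Variable(torch.rand(1)))
    def traced(a):
        return a + a + a

    # @torch.jit.script compiles the function body through the script
    # frontend instead of tracing it.
    @torch.jit.script
    def scripted(a):
        return a + a + a

    # torch.jit.CompilationUnit(str) compiles script source from a string;
    # each defined function is exposed as an attribute backed by its own
    # GraphExecutor.
    cu = torch.jit.CompilationUnit('''
    def foo(a) -> (b):
        b = a
    ''')

    x = Variable(torch.rand(2))
    traced(x)    # == x + x + x
    scripted(x)  # == x + x + x
    cu.foo(x)    # == x

Note that, as in test_trace_annotation below, a trace recorded on a size-1 example input still runs on a size-2 input.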
diff --git a/test/test_jit.py b/test/test_jit.py
index eceeb1e..257109f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -122,7 +122,7 @@
def f(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
- trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+ trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
self.assertExpectedTrace(trace)
# matmul is currently implemented as a native function, which
@@ -134,7 +134,7 @@
x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
- trace, z = torch.jit.trace(lambda x, y: x.matmul(y), (x, y), nderivs=0)
+ trace, z = torch.jit.get_trace_graph(lambda x, y: x.matmul(y), (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
self.assertExpectedTrace(trace)
@@ -203,7 +203,7 @@
out = torch.sigmoid(out)
return out
- trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+ trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
self.assertExpectedTrace(trace)
def test_scopes_intermediate_node(self):
@@ -214,7 +214,7 @@
net = Net()
t = Variable(torch.ones(2), requires_grad=True)
- trace, _ = torch.jit.trace(net, (t, ))
+ trace, _ = torch.jit.get_trace_graph(net, (t, ))
torch.onnx._optimize_trace(trace, False)
self.assertExpectedTrace(trace)
@@ -240,7 +240,7 @@
t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
with torch.onnx.set_training(model, False):
- trace, _ = torch.jit.trace(model, (t, ))
+ trace, _ = torch.jit.get_trace_graph(model, (t, ))
torch.onnx._optimize_trace(trace, False)
@@ -254,7 +254,7 @@
cx = Variable(torch.randn(3, 20).float().cuda())
module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
- trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
+ trace, _ = torch.jit.get_trace_graph(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
torch._C._jit_pass_lint(trace)
@@ -316,7 +316,7 @@
def Foo(hx, cx):
return torch.cat((hx + cx, hx * cx))
- trace, _ = torch.jit.trace(Foo, (hx, cx))
+ trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx))
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_fuse(trace)
self.assertExpectedTrace(trace)
@@ -329,7 +329,7 @@
return z1 * z2
x = Variable(torch.randn(4, 4).float().cuda())
y = Variable(torch.randn(4, 4).float().cuda())
- trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
+ trace, _ = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
self.assertExpectedTrace(trace, 'raw')
@@ -487,7 +487,7 @@
return a * grad_a
x = Variable(torch.randn(10, 10), requires_grad=True)
- trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
+ trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1)
out.sum().backward()
torch._C._jit_pass_dce(trace)
self.assertExpectedTrace(trace)
@@ -901,7 +901,7 @@
def doit(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
- traced, _ = torch.jit.trace(doit, (x, y))
+ traced, _ = torch.jit.get_trace_graph(doit, (x, y))
g = torch._C._jit_get_graph(traced)
g2 = torch._C.Graph()
g_to_g2 = {}
@@ -931,12 +931,12 @@
def test_batchnorm(self):
x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
- trace, _ = torch.jit.trace(nn.BatchNorm2d(2), x)
+ trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x)
self.assertExpectedTrace(trace)
def test_dropout(self):
x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
- trace, _ = torch.jit.trace(nn.Dropout(0.6), x)
+ trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
self.assertExpectedTrace(trace)
def test_batchnorm_run_twice(self):
@@ -957,7 +957,7 @@
def test_conv(self):
x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
- trace, _ = torch.jit.trace(nn.Conv2d(16, 13, 3, bias=False), x)
+ trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
self.assertExpectedTrace(trace)
def test_reuse_function(self):
@@ -1145,7 +1145,7 @@
def test_alexnet(self):
return
x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
- trace, _ = torch.jit.trace(torchvision.models.AlexNet(), x)
+ trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
self.assertExpectedTrace(trace)
# NB: Purposely NOT testing protobuf export here
@@ -1183,14 +1183,14 @@
out.copy_(x)
return out
- trace, z = torch.jit.trace(f, (x, ), nderivs=0)
+ trace, z = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
self.assertExpectedTrace(trace)
def test_index_trace(self):
x = Variable(torch.randn(4, 4), requires_grad=True)
- trace, z = torch.jit.trace(lambda x: x[0], (x, ), nderivs=1)
+ trace, z = torch.jit.get_trace_graph(lambda x: x[0], (x, ), nderivs=1)
z.sum().backward()
torch._C._jit_pass_lint(trace)
torch._C._jit_pass_dce(trace)
@@ -1217,13 +1217,13 @@
return x * self.a + self.b
m = MyModule()
- trace, _ = torch.jit.trace(m, (Variable(torch.randn(2, 2)),), nderivs=0)
+ trace, _ = torch.jit.get_trace_graph(m, (Variable(torch.randn(2, 2)),), nderivs=0)
self.assertEqual(len(list(trace.graph().inputs())), 2)
self.assertExpected(str(trace))
def test_nested_inplace(self):
x = Variable(torch.randn(2, 2))
- trace, _ = torch.jit.trace(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
+ trace, _ = torch.jit.get_trace_graph(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
self.assertExpectedTrace(trace)
def checkGraphExecutor(self, func, reference_tensors, input_tensors=None, optimize=True, drop=None):
@@ -1238,12 +1238,6 @@
if input_tensors is None:
input_tensors = reference_tensors
- def wrapped(*inputs):
- res = func(*inputs)
- if isinstance(res, torch.Tensor):
- return (res,)
- return res
-
nograd_inputs = [Variable(t) for t in reference_tensors]
recording_inputs = [Variable(t, requires_grad=True)
for t in reference_tensors]
@@ -1252,13 +1246,13 @@
# test no gradients case
- outputs = wrapped(*nograd_inputs)
+ outputs = func(*nograd_inputs)
outputs_ge = ge(*nograd_inputs)
self.assertEqual(outputs, outputs_ge)
# test single grad case
- outputs = wrapped(*recording_inputs)
+ outputs = func(*recording_inputs)
grads = torch.autograd.grad(allSum(outputs), recording_inputs)
outputs_ge = ge(*recording_inputs)
@@ -1268,7 +1262,7 @@
# test the grad grad case
- outputs = wrapped(*recording_inputs)
+ outputs = func(*recording_inputs)
l1 = allSum(outputs)
grads = torch.autograd.grad(l1, recording_inputs, create_graph=True)
l2 = (allSum(grads) * l1)
@@ -1352,13 +1346,10 @@
def checkScript(self, script, inputs, outputs, optimize, name='func'):
if isinstance(script, str):
- cu = torch.jit._jit_script_compile(script)
+ cu = torch.jit.CompilationUnit(script, optimize)
+ ge = getattr(cu, name)
else:
- ast = torch.jit.frontend.get_jit_ast(script)
- cu = torch._C.CompilationUnit()
- cu.define_function(ast)
- graph = cu.get_graph(name)
- ge = torch._C.GraphExecutor(graph, optimize)
+ ge = torch.jit.script(script)
with capture_stdout() as captured:
outputs_ge = ge(*inputs)
self.assertEqual(outputs, outputs_ge)
@@ -1375,7 +1366,7 @@
a = Variable(torch.rand(1), requires_grad=True)
b = Variable(torch.rand(1), requires_grad=True)
outputs = a + b + a
- self.checkScript(script, [a, b], [outputs], False)
+ self.checkScript(script, [a, b], outputs, False)
def test_script_mul(self):
script = '''
@@ -1386,7 +1377,7 @@
a = Variable(torch.rand(1), requires_grad=True)
b = Variable(torch.rand(1), requires_grad=True)
outputs = a * b
- self.checkScript(script, [a, b], [outputs], False)
+ self.checkScript(script, [a, b], outputs, False)
def test_script_triple(self):
script = '''
@@ -1395,7 +1386,7 @@
'''
x = Variable(torch.rand(1).float(), requires_grad=True)
outputs = 3 * x
- self.checkScript(script, [x], [outputs], False)
+ self.checkScript(script, [x], outputs, False)
def test_script_slice(self):
script = '''
@@ -1404,7 +1395,7 @@
'''
x = Variable(torch.rand(10).float(), requires_grad=True)
outputs = x[:5]
- self.checkScript(script, [x], [outputs], False)
+ self.checkScript(script, [x], outputs, False)
def test_script_gather(self):
script = '''
@@ -1413,7 +1404,7 @@
'''
x = Variable(torch.rand(10).float(), requires_grad=True)
outputs = x[0]
- self.checkScript(script, [x], [outputs], False)
+ self.checkScript(script, [x], outputs, False)
def test_script_func_call(self):
script = '''
@@ -1431,7 +1422,7 @@
x = Variable(torch.rand(3).float(), requires_grad=True)
y = Variable(torch.rand(3).float(), requires_grad=True)
outputs = alpha * x + beta * y
- self.checkScript(script, [alpha, beta, x, y], [outputs], False)
+ self.checkScript(script, [alpha, beta, x, y], outputs, False)
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_script_cast(self):
@@ -1473,7 +1464,7 @@
'''
inputs = self._make_scalar_vars([1, 1, 10], np.int32)
outputs = self._make_scalar_vars([20], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_while')
+ self.checkScript(script, inputs, outputs[0], False, 'test_while')
def test_script_fibb(self):
script = '''
@@ -1515,7 +1506,7 @@
'''
inputs = self._make_scalar_vars([1, -1], np.int32)
outputs = self._make_scalar_vars([7], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_if')
+ self.checkScript(script, inputs, outputs[0], False, 'test_if')
def test_script_if_noelse(self):
script = '''
@@ -1526,11 +1517,11 @@
'''
inputs = self._make_scalar_vars([-1, 1], np.int32)
outputs = self._make_scalar_vars([0], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_if_noelse')
+ self.checkScript(script, inputs, outputs[0], False, 'test_if_noelse')
def test_script_while_nonexistant_value(self):
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
- torch.jit._jit_script_compile('''
+ torch.jit.CompilationUnit('''
def test_while(a, b) -> (c):
while a < 10:
a = a + x
@@ -1540,7 +1531,7 @@
def test_script_while_nonexistant_cond_value(self):
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
- torch.jit._jit_script_compile('''
+ torch.jit.CompilationUnit('''
def test_while(a, b) -> (c):
while a < x:
a = a + 1
@@ -1558,7 +1549,7 @@
'''
inputs = self._make_scalar_vars([42, 1337], np.int32)
outputs = self._make_scalar_vars([1379], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_while')
+ self.checkScript(script, inputs, outputs[0], False, 'test_while')
def test_script_while_nest_if(self):
script = '''
@@ -1575,7 +1566,7 @@
'''
inputs = self._make_scalar_vars([-1234, 4321], np.int32)
outputs = self._make_scalar_vars([-5564], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_while_if')
+ self.checkScript(script, inputs, outputs[0], False, 'test_while_if')
def test_script_if_nest_while(self):
script = '''
@@ -1588,10 +1579,10 @@
'''
inputs = self._make_scalar_vars([4321, 1234], np.int32)
outputs = self._make_scalar_vars([-4321], np.int32)
- self.checkScript(script, inputs, outputs, False, 'test_if_while')
+ self.checkScript(script, inputs, outputs[0], False, 'test_if_while')
def test_script_ternary(self):
- cu = torch.jit._jit_script_compile('''
+ cu = torch.jit.CompilationUnit('''
def test_ternary_control(a, b) -> (c):
c = 3
if a > 3:
@@ -1599,14 +1590,14 @@
else:
c = b
''')
- cu2 = torch.jit._jit_script_compile('''
+ cu2 = torch.jit.CompilationUnit('''
def test_ternary(a, b) -> (c):
c = 3
c = a + b if a > 3 else b
''')
self.assertEqual(
- str(cu.get_graph('test_ternary_control')),
- str(cu2.get_graph('test_ternary')),
+ str(cu.cu.get_graph('test_ternary_control')),
+ str(cu2.cu.get_graph('test_ternary')),
)
def test_python_frontend_run(self):
@@ -1621,8 +1612,29 @@
with capture_stdout():
expected_out = func(x, y)
expected_out = (x + y).sigmoid().pow(2)
- self.checkScript(func, [x, y], [expected_out], False)
+ self.checkScript(func, [x, y], expected_out, False)
+ def test_trace_annotation(self):
+ @torch.jit.trace(Variable(torch.rand(1)))
+ def foo(a):
+ return a + a + a
+ s = Variable(torch.rand(2))
+ self.assertEqual(s + s + s, foo(s))
+
+ def test_script_cu(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(a) -> (b):
+ b = a
+ ''')
+ a = Variable(torch.rand(1))
+ self.assertEqual(a, cu.foo(a))
+
+ def test_script_annotation(self):
+ @torch.jit.script
+ def foo(a):
+ return a + a + a
+ s = Variable(torch.rand(2))
+ self.assertEqual(s + s + s, foo(s))
if __name__ == '__main__':
run_tests()
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index c8600e0..36b5483 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -37,12 +37,6 @@
return F(state->graph);
}
-GraphExecutor createExecutorByGraph(
- std::shared_ptr<Graph> graph,
- bool optimize) {
- return GraphExecutor(std::move(graph), optimize);
-}
-
// This is a temporary constructor so that we can write python tests of
// the executor. It does not have most of the functionality of CompiledFunction
// such as being able to hold parameters...
@@ -65,7 +59,7 @@
tracer::exit(outputs);
auto graph = enter_info.first->graph;
EliminateDeadCode(graph);
- return createExecutorByGraph(std::move(graph), optimize);
+ return GraphExecutor(std::move(graph), optimize);
}
// we cannot use the default py:cast<autograd::Variable> because it currently
@@ -119,17 +113,25 @@
py::arg("optimize") = true)
.def(
py::init([](std::shared_ptr<Graph> graph, bool optimize) {
- return createExecutorByGraph(std::move(graph), optimize);
+ return GraphExecutor(std::move(graph), optimize);
}),
py::arg("graph"),
py::arg("optimize") = true)
- .def("__call__", [](GraphExecutor& ge, py::args args) {
+ .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
auto inputs = createVariableTensorList(args);
auto outputs = ge.run(std::move(inputs));
// if we don't tell pybind these are variables it chokes on the
// conversion.
// TODO: fix conversions to be sane and make sure this works.
- return std::vector<autograd::Variable>(outputs.begin(), outputs.end());
+ if(outputs.size() == 1) {
+ return py::cast(static_cast<autograd::Variable&>(outputs[0]));
+ } else {
+ py::tuple tuple(outputs.size());
+ for(size_t i = 0; i < outputs.size(); i++) {
+ tuple[i] = py::cast(static_cast<autograd::Variable&>(outputs[i]));
+ }
+ return tuple;
+ }
});
initPythonIRBindings(module);
initPythonTracerBindings(module);
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 536ef67..9ae9098 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -728,10 +728,11 @@
CompilationUnit::~CompilationUnit() {}
-std::unique_ptr<CompilationUnit> jitScriptCompile(const std::string& script) {
- std::unique_ptr<CompilationUnit> cu{new CompilationUnit};
- cu->define(script);
- return cu;
+std::shared_ptr<Graph> jitScriptCompile(Def def) {
+ FunctionTable empty;
+ FunctionDefinition fd(def);
+ to_ir(fd, empty);
+ return fd.graph;
}
} // namespace script
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index c43931f..746facb 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -21,7 +21,7 @@
std::unique_ptr<CompilationUnitImpl> pImpl;
};
-std::unique_ptr<CompilationUnit> jitScriptCompile(const std::string& script);
+std::shared_ptr<Graph> jitScriptCompile(Def def);
} // namespace script
} // namespace jit
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 3add566..ed9ff11 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -8,10 +8,12 @@
void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<CompilationUnit>(m, "CompilationUnit")
- .def(py::init<>())
- .def("get_graph", &CompilationUnit::getGraph,
- py::return_value_policy::reference)
- .def("define_function", &CompilationUnit::defineFunction);
+ .def(
+ "get_graph",
+ &CompilationUnit::getGraph,
+ py::return_value_policy::reference)
+ .def(py::init<>())
+ .def("define", &CompilationUnit::define);
m.def("_jit_script_compile", jitScriptCompile);
}
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 1b52765..5f4b237 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -211,7 +211,7 @@
return _compile(arg)
-def trace(f, args=tuple(), kwargs=None, nderivs=0):
+def get_trace_graph(f, args=tuple(), kwargs=None, nderivs=0):
"""
Trace a function or model, returning a tuple consisting of both the
*trace* of an execution and the original return value.
@@ -429,5 +429,58 @@
raise RuntimeError("JIT and real computation mismatch")
+def trace(*args, **kwargs):
+ """
+ Trace a function and return an executable trace that will be optimized
+ using just-in-time compilation.
+
+ .. warning::
+
+ Just-in-time compilation currently only works for functions/modules
+ that are not data dependent (i.e., do not have conditionals on the
+ values in tensors) and do not have any untracked external dependencies
+ (e.g., perform input/output or access global variables). If you trace
+ such models, you will silently get incorrect results on subsequent
+ invocations of the model.
+
+ Args:
+ *args - a list of example tensors that will be passed to the function
+ as inputs while tracing. The resulting trace can be run with
+ inputs of different types and shapes assuming the traced operations
+ support those types and shapes.
+
+ Keyword arguments:
+ optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
+
+ >>> @jit.trace(torch.autograd.Variable(torch.rand(1)))
+ >>> def f(x):
+ >>> return x * 2
+ """
+ return lambda func: torch._C.GraphExecutor(func, args, kwargs.pop('optimize', True))
+
+
+class CompilationUnit(object):
+ def __init__(self, lang=None, optimize=True):
+ self.cu = torch._C.CompilationUnit()
+ if lang is not None:
+ self.define(lang)
+ self.execution_engines = {}
+ self.optimize = optimize
+
+ def define(self, lang):
+ self.cu.define(lang)
+
+ def __getattr__(self, attr):
+ if attr not in self.execution_engines:
+ graph = self.cu.get_graph(attr)
+ self.execution_engines[attr] = torch._C.GraphExecutor(graph, self.optimize)
+ return self.execution_engines[attr]
+
+
+def script(fn):
+ ast = torch.jit.frontend.get_jit_ast(fn)
+ graph = _jit_script_compile(ast)
+ return torch._C.GraphExecutor(graph, True)
+
if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index b949702..9943465 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -106,7 +106,7 @@
if isinstance(args, torch.autograd.Variable):
args = (args, )
- trace, torch_out = torch.jit.trace(func, args)
+ trace, torch_out = torch.jit.get_trace_graph(func, args)
_optimize_trace(trace, aten)
if return_outs:
return trace, torch_out
@@ -129,7 +129,7 @@
# can turn training=True (or None, to preserve whatever the original
# training mode was.)
with set_training(model, training):
- trace, torch_out = torch.jit.trace(model, args)
+ trace, torch_out = torch.jit.get_trace_graph(model, args)
if orig_state_dict_keys != _unique_state_dict(model).keys():
raise RuntimeError("state_dict changed after running the tracer; "