add @torch.jit.script, @torch.jit.compile, torch.jit.CompilationUnit(str) (#5367)

* The torch.jit.trace annotation now traces the decorated function on example inputs and returns a GraphExecutor
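
For example, a minimal sketch of the decorator form, mirroring the new
test_trace_annotation test added below:

    import torch
    from torch.autograd import Variable

    # Decorating with example inputs traces the function once and returns a
    # GraphExecutor; later calls run the recorded graph rather than the
    # Python body.
    @torch.jit.trace(Variable(torch.rand(1)))
    def foo(a):
        return a + a + a

    s = Variable(torch.rand(2))
    out = foo(s)  # same result as s + s + s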

The other torch.jit.trace, which was used for testing purposes and by the ONNX exporter to get the trace graph, is now called torch.jit.get_trace_graph.
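
Call sites that need the raw trace (the tests and the ONNX exporter) switch to
the new name but keep the same signature and return values, e.g.:

    import torch
    from torch.autograd import Variable

    def f(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    x = Variable(torch.rand(2, 2), requires_grad=True)
    y = Variable(torch.rand(2, 2), requires_grad=True)

    # Returns the recorded trace plus the function's original outputs,
    # exactly as the old torch.jit.trace did.
    trace, out = torch.jit.get_trace_graph(f, (x, y), nderivs=0)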

* @script annotation, and a CompilationUnit for compiling script code from strings
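
A minimal usage sketch of the two new entry points, following the new
test_script_cu and test_script_annotation tests added below:

    import torch
    from torch.autograd import Variable

    # Compile a string of script functions; each defined function becomes a
    # callable attribute backed by a GraphExecutor.
    cu = torch.jit.CompilationUnit('''
        def foo(a) -> (b):
            b = a
    ''')
    a = Variable(torch.rand(1))
    out = cu.foo(a)

    # Or compile a Python function from its AST with the @script annotation.
    @torch.jit.script
    def triple(a):
        return a + a + a

    s = Variable(torch.rand(2))
    out = triple(s)
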
diff --git a/test/test_jit.py b/test/test_jit.py
index eceeb1e..257109f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -122,7 +122,7 @@
         def f(x, y):
             return torch.sigmoid(torch.tanh(x * (x + y)))
 
-        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
         self.assertExpectedTrace(trace)
 
     # matmul is currently implemented as a native function, which
@@ -134,7 +134,7 @@
         x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
         y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
 
-        trace, z = torch.jit.trace(lambda x, y: x.matmul(y), (x, y), nderivs=0)
+        trace, z = torch.jit.get_trace_graph(lambda x, y: x.matmul(y), (x, y), nderivs=0)
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_dce(trace)
         self.assertExpectedTrace(trace)
@@ -203,7 +203,7 @@
                 out = torch.sigmoid(out)
             return out
 
-        trace, z = torch.jit.trace(f, (x, y), nderivs=0)
+        trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
         self.assertExpectedTrace(trace)
 
     def test_scopes_intermediate_node(self):
@@ -214,7 +214,7 @@
 
         net = Net()
         t = Variable(torch.ones(2), requires_grad=True)
-        trace, _ = torch.jit.trace(net, (t, ))
+        trace, _ = torch.jit.get_trace_graph(net, (t, ))
         torch.onnx._optimize_trace(trace, False)
 
         self.assertExpectedTrace(trace)
@@ -240,7 +240,7 @@
         t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
 
         with torch.onnx.set_training(model, False):
-            trace, _ = torch.jit.trace(model, (t, ))
+            trace, _ = torch.jit.get_trace_graph(model, (t, ))
 
         torch.onnx._optimize_trace(trace, False)
 
@@ -254,7 +254,7 @@
         cx = Variable(torch.randn(3, 20).float().cuda())
         module = nn.LSTMCell(10, 20).float().cuda()  # Just to allocate weights with correct sizes
 
-        trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
+        trace, _ = torch.jit.get_trace_graph(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_dce(trace)
         torch._C._jit_pass_lint(trace)
@@ -316,7 +316,7 @@
         def Foo(hx, cx):
             return torch.cat((hx + cx, hx * cx))
 
-        trace, _ = torch.jit.trace(Foo, (hx, cx))
+        trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx))
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_fuse(trace)
         self.assertExpectedTrace(trace)
@@ -329,7 +329,7 @@
             return z1 * z2
         x = Variable(torch.randn(4, 4).float().cuda())
         y = Variable(torch.randn(4, 4).float().cuda())
-        trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
+        trace, _ = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_dce(trace)
         self.assertExpectedTrace(trace, 'raw')
@@ -487,7 +487,7 @@
                 return a * grad_a
 
         x = Variable(torch.randn(10, 10), requires_grad=True)
-        trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
+        trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1)
         out.sum().backward()
         torch._C._jit_pass_dce(trace)
         self.assertExpectedTrace(trace)
@@ -901,7 +901,7 @@
         def doit(x, y):
             return torch.sigmoid(torch.tanh(x * (x + y)))
 
-        traced, _ = torch.jit.trace(doit, (x, y))
+        traced, _ = torch.jit.get_trace_graph(doit, (x, y))
         g = torch._C._jit_get_graph(traced)
         g2 = torch._C.Graph()
         g_to_g2 = {}
@@ -931,12 +931,12 @@
 
     def test_batchnorm(self):
         x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
-        trace, _ = torch.jit.trace(nn.BatchNorm2d(2), x)
+        trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x)
         self.assertExpectedTrace(trace)
 
     def test_dropout(self):
         x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
-        trace, _ = torch.jit.trace(nn.Dropout(0.6), x)
+        trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
         self.assertExpectedTrace(trace)
 
     def test_batchnorm_run_twice(self):
@@ -957,7 +957,7 @@
 
     def test_conv(self):
         x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
-        trace, _ = torch.jit.trace(nn.Conv2d(16, 13, 3, bias=False), x)
+        trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
         self.assertExpectedTrace(trace)
 
     def test_reuse_function(self):
@@ -1145,7 +1145,7 @@
     def test_alexnet(self):
         return
         x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
-        trace, _ = torch.jit.trace(torchvision.models.AlexNet(), x)
+        trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
         self.assertExpectedTrace(trace)
         # NB: Purposely NOT testing protobuf export here
 
@@ -1183,14 +1183,14 @@
             out.copy_(x)
             return out
 
-        trace, z = torch.jit.trace(f, (x, ), nderivs=0)
+        trace, z = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_dce(trace)
         self.assertExpectedTrace(trace)
 
     def test_index_trace(self):
         x = Variable(torch.randn(4, 4), requires_grad=True)
-        trace, z = torch.jit.trace(lambda x: x[0], (x, ), nderivs=1)
+        trace, z = torch.jit.get_trace_graph(lambda x: x[0], (x, ), nderivs=1)
         z.sum().backward()
         torch._C._jit_pass_lint(trace)
         torch._C._jit_pass_dce(trace)
@@ -1217,13 +1217,13 @@
                 return x * self.a + self.b
 
         m = MyModule()
-        trace, _ = torch.jit.trace(m, (Variable(torch.randn(2, 2)),), nderivs=0)
+        trace, _ = torch.jit.get_trace_graph(m, (Variable(torch.randn(2, 2)),), nderivs=0)
         self.assertEqual(len(list(trace.graph().inputs())), 2)
         self.assertExpected(str(trace))
 
     def test_nested_inplace(self):
         x = Variable(torch.randn(2, 2))
-        trace, _ = torch.jit.trace(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
+        trace, _ = torch.jit.get_trace_graph(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
         self.assertExpectedTrace(trace)
 
     def checkGraphExecutor(self, func, reference_tensors, input_tensors=None, optimize=True, drop=None):
@@ -1238,12 +1238,6 @@
         if input_tensors is None:
             input_tensors = reference_tensors
 
-        def wrapped(*inputs):
-            res = func(*inputs)
-            if isinstance(res, torch.Tensor):
-                return (res,)
-            return res
-
         nograd_inputs = [Variable(t) for t in reference_tensors]
         recording_inputs = [Variable(t, requires_grad=True)
                             for t in reference_tensors]
@@ -1252,13 +1246,13 @@
 
         # test no gradients case
 
-        outputs = wrapped(*nograd_inputs)
+        outputs = func(*nograd_inputs)
         outputs_ge = ge(*nograd_inputs)
         self.assertEqual(outputs, outputs_ge)
 
         # test single grad case
 
-        outputs = wrapped(*recording_inputs)
+        outputs = func(*recording_inputs)
         grads = torch.autograd.grad(allSum(outputs), recording_inputs)
 
         outputs_ge = ge(*recording_inputs)
@@ -1268,7 +1262,7 @@
 
         # test the grad grad case
 
-        outputs = wrapped(*recording_inputs)
+        outputs = func(*recording_inputs)
         l1 = allSum(outputs)
         grads = torch.autograd.grad(l1, recording_inputs, create_graph=True)
         l2 = (allSum(grads) * l1)
@@ -1352,13 +1346,10 @@
 
     def checkScript(self, script, inputs, outputs, optimize, name='func'):
         if isinstance(script, str):
-            cu = torch.jit._jit_script_compile(script)
+            cu = torch.jit.CompilationUnit(script, optimize)
+            ge = getattr(cu, name)
         else:
-            ast = torch.jit.frontend.get_jit_ast(script)
-            cu = torch._C.CompilationUnit()
-            cu.define_function(ast)
-        graph = cu.get_graph(name)
-        ge = torch._C.GraphExecutor(graph, optimize)
+            ge = torch.jit.script(script)
         with capture_stdout() as captured:
             outputs_ge = ge(*inputs)
         self.assertEqual(outputs, outputs_ge)
@@ -1375,7 +1366,7 @@
         a = Variable(torch.rand(1), requires_grad=True)
         b = Variable(torch.rand(1), requires_grad=True)
         outputs = a + b + a
-        self.checkScript(script, [a, b], [outputs], False)
+        self.checkScript(script, [a, b], outputs, False)
 
     def test_script_mul(self):
         script = '''
@@ -1386,7 +1377,7 @@
         a = Variable(torch.rand(1), requires_grad=True)
         b = Variable(torch.rand(1), requires_grad=True)
         outputs = a * b
-        self.checkScript(script, [a, b], [outputs], False)
+        self.checkScript(script, [a, b], outputs, False)
 
     def test_script_triple(self):
         script = '''
@@ -1395,7 +1386,7 @@
         '''
         x = Variable(torch.rand(1).float(), requires_grad=True)
         outputs = 3 * x
-        self.checkScript(script, [x], [outputs], False)
+        self.checkScript(script, [x], outputs, False)
 
     def test_script_slice(self):
         script = '''
@@ -1404,7 +1395,7 @@
         '''
         x = Variable(torch.rand(10).float(), requires_grad=True)
         outputs = x[:5]
-        self.checkScript(script, [x], [outputs], False)
+        self.checkScript(script, [x], outputs, False)
 
     def test_script_gather(self):
         script = '''
@@ -1413,7 +1404,7 @@
         '''
         x = Variable(torch.rand(10).float(), requires_grad=True)
         outputs = x[0]
-        self.checkScript(script, [x], [outputs], False)
+        self.checkScript(script, [x], outputs, False)
 
     def test_script_func_call(self):
         script = '''
@@ -1431,7 +1422,7 @@
         x = Variable(torch.rand(3).float(), requires_grad=True)
         y = Variable(torch.rand(3).float(), requires_grad=True)
         outputs = alpha * x + beta * y
-        self.checkScript(script, [alpha, beta, x, y], [outputs], False)
+        self.checkScript(script, [alpha, beta, x, y], outputs, False)
 
     @unittest.skip("RuntimeError: VariableType::ID() not implemented")
     def test_script_cast(self):
@@ -1473,7 +1464,7 @@
         '''
         inputs = self._make_scalar_vars([1, 1, 10], np.int32)
         outputs = self._make_scalar_vars([20], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_while')
+        self.checkScript(script, inputs, outputs[0], False, 'test_while')
 
     def test_script_fibb(self):
         script = '''
@@ -1515,7 +1506,7 @@
         '''
         inputs = self._make_scalar_vars([1, -1], np.int32)
         outputs = self._make_scalar_vars([7], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_if')
+        self.checkScript(script, inputs, outputs[0], False, 'test_if')
 
     def test_script_if_noelse(self):
         script = '''
@@ -1526,11 +1517,11 @@
         '''
         inputs = self._make_scalar_vars([-1, 1], np.int32)
         outputs = self._make_scalar_vars([0], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_if_noelse')
+        self.checkScript(script, inputs, outputs[0], False, 'test_if_noelse')
 
     def test_script_while_nonexistant_value(self):
         with self.assertRaisesRegex(RuntimeError, "undefined value x"):
-            torch.jit._jit_script_compile('''
+            torch.jit.CompilationUnit('''
             def test_while(a, b) -> (c):
                 while a < 10:
                     a = a + x
@@ -1540,7 +1531,7 @@
 
     def test_script_while_nonexistant_cond_value(self):
         with self.assertRaisesRegex(RuntimeError, "undefined value x"):
-            torch.jit._jit_script_compile('''
+            torch.jit.CompilationUnit('''
             def test_while(a, b) -> (c):
                 while a < x:
                     a = a + 1
@@ -1558,7 +1549,7 @@
         '''
         inputs = self._make_scalar_vars([42, 1337], np.int32)
         outputs = self._make_scalar_vars([1379], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_while')
+        self.checkScript(script, inputs, outputs[0], False, 'test_while')
 
     def test_script_while_nest_if(self):
         script = '''
@@ -1575,7 +1566,7 @@
         '''
         inputs = self._make_scalar_vars([-1234, 4321], np.int32)
         outputs = self._make_scalar_vars([-5564], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_while_if')
+        self.checkScript(script, inputs, outputs[0], False, 'test_while_if')
 
     def test_script_if_nest_while(self):
         script = '''
@@ -1588,10 +1579,10 @@
         '''
         inputs = self._make_scalar_vars([4321, 1234], np.int32)
         outputs = self._make_scalar_vars([-4321], np.int32)
-        self.checkScript(script, inputs, outputs, False, 'test_if_while')
+        self.checkScript(script, inputs, outputs[0], False, 'test_if_while')
 
     def test_script_ternary(self):
-        cu = torch.jit._jit_script_compile('''
+        cu = torch.jit.CompilationUnit('''
         def test_ternary_control(a, b) -> (c):
             c = 3
             if a > 3:
@@ -1599,14 +1590,14 @@
             else:
                 c = b
         ''')
-        cu2 = torch.jit._jit_script_compile('''
+        cu2 = torch.jit.CompilationUnit('''
         def test_ternary(a, b) -> (c):
             c = 3
             c = a + b if a > 3 else b
         ''')
         self.assertEqual(
-            str(cu.get_graph('test_ternary_control')),
-            str(cu2.get_graph('test_ternary')),
+            str(cu.cu.get_graph('test_ternary_control')),
+            str(cu2.cu.get_graph('test_ternary')),
         )
 
     def test_python_frontend_run(self):
@@ -1621,8 +1612,29 @@
         with capture_stdout():
             expected_out = func(x, y)
         expected_out = (x + y).sigmoid().pow(2)
-        self.checkScript(func, [x, y], [expected_out], False)
+        self.checkScript(func, [x, y], expected_out, False)
 
+    def test_trace_annotation(self):
+        @torch.jit.trace(Variable(torch.rand(1)))
+        def foo(a):
+            return a + a + a
+        s = Variable(torch.rand(2))
+        self.assertEqual(s + s + s, foo(s))
+
+    def test_script_cu(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(a) -> (b):
+                b = a
+        ''')
+        a = Variable(torch.rand(1))
+        self.assertEqual(a, cu.foo(a))
+
+    def test_script_annotation(self):
+        @torch.jit.script
+        def foo(a):
+            return a + a + a
+        s = Variable(torch.rand(2))
+        self.assertEqual(s + s + s, foo(s))
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index c8600e0..36b5483 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -37,12 +37,6 @@
   return F(state->graph);
 }
 
-GraphExecutor createExecutorByGraph(
-    std::shared_ptr<Graph> graph,
-    bool optimize) {
-  return GraphExecutor(std::move(graph), optimize);
-}
-
 // This is a temporary constructor so that we can write python tests of
 // the executor. It does not have most of the functionality of CompiledFunction
 // such as being able to hold parameters...
@@ -65,7 +59,7 @@
   tracer::exit(outputs);
   auto graph = enter_info.first->graph;
   EliminateDeadCode(graph);
-  return createExecutorByGraph(std::move(graph), optimize);
+  return GraphExecutor(std::move(graph), optimize);
 }
 
 // we cannot use the default py:cast<autograd::Variable> because it currently
@@ -119,17 +113,25 @@
           py::arg("optimize") = true)
       .def(
           py::init([](std::shared_ptr<Graph> graph, bool optimize) {
-            return createExecutorByGraph(std::move(graph), optimize);
+            return GraphExecutor(std::move(graph), optimize);
           }),
           py::arg("graph"),
           py::arg("optimize") = true)
-      .def("__call__", [](GraphExecutor& ge, py::args args) {
+      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
         auto inputs = createVariableTensorList(args);
         auto outputs = ge.run(std::move(inputs));
         // if we don't tell pybind these are variables it chokes on the
         // conversion.
         // TODO: fix conversions to be sane and make sure this works.
-        return std::vector<autograd::Variable>(outputs.begin(), outputs.end());
+        if(outputs.size() == 1) {
+          return py::cast(static_cast<autograd::Variable&>(outputs[0]));
+        } else {
+          py::tuple tuple(outputs.size());
+          for(size_t i = 0; i < outputs.size(); i++) {
+            tuple[i] = py::cast(static_cast<autograd::Variable&>(outputs[i]));
+          }
+          return tuple;
+        }
       });
   initPythonIRBindings(module);
   initPythonTracerBindings(module);
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 536ef67..9ae9098 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -728,10 +728,11 @@
 
 CompilationUnit::~CompilationUnit() {}
 
-std::unique_ptr<CompilationUnit> jitScriptCompile(const std::string& script) {
-  std::unique_ptr<CompilationUnit> cu{new CompilationUnit};
-  cu->define(script);
-  return cu;
+std::shared_ptr<Graph> jitScriptCompile(Def def) {
+  FunctionTable empty;
+  FunctionDefinition fd(def);
+  to_ir(fd, empty);
+  return fd.graph;
 }
 
 } // namespace script
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index c43931f..746facb 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -21,7 +21,7 @@
   std::unique_ptr<CompilationUnitImpl> pImpl;
 };
 
-std::unique_ptr<CompilationUnit> jitScriptCompile(const std::string& script);
+std::shared_ptr<Graph> jitScriptCompile(Def def);
 
 } // namespace script
 } // namespace jit
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 3add566..ed9ff11 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -8,10 +8,12 @@
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
   py::class_<CompilationUnit>(m, "CompilationUnit")
-    .def(py::init<>())
-    .def("get_graph", &CompilationUnit::getGraph,
-         py::return_value_policy::reference)
-    .def("define_function", &CompilationUnit::defineFunction);
+      .def(
+          "get_graph",
+          &CompilationUnit::getGraph,
+          py::return_value_policy::reference)
+      .def(py::init<>())
+      .def("define", &CompilationUnit::define);
   m.def("_jit_script_compile", jitScriptCompile);
 }
 
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 1b52765..5f4b237 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -211,7 +211,7 @@
         return _compile(arg)
 
 
-def trace(f, args=tuple(), kwargs=None, nderivs=0):
+def get_trace_graph(f, args=tuple(), kwargs=None, nderivs=0):
     """
     Trace a function or model, returning a tuple consisting of the both the
     *trace* of an execution, as well as the original return value.
@@ -429,5 +429,58 @@
             raise RuntimeError("JIT and real computation mismatch")
 
 
+def trace(*args, **kwargs):
+    """
+    Trace a function and return an executable trace that will be optimized
+    using just-in-time compilation.
+
+    .. warning::
+
+        Just-in-time compilation currently only works for functions/modules
+        which are not data dependent (e.g., have conditionals on data in
+        tensors) and do not have any untracked external dependencies (e.g.,
+        perform input/output or access global variables). If you trace such
+        models, you will silently get incorrect results on subsequent
+        invocations of the model.
+
+    Arg:
+        *args - a list of example tensors that will be passed to the function
+                as inputs while tracing. The resulting trace can be run with
+                inputs of different types and shapes assuming the traced operations
+                support those types and shapes.
+
+    Keyword arguments:
+        optimize (bool, optional): whether or not to apply optimizations.  Default: ``True``.
+
+        >>> @jit.trace(torch.autograd.Variable(torch.rand(1)))
+        >>> def f(x):
+        >>>     return x * 2
+    """
+    return lambda func: torch._C.GraphExecutor(func, args, kwargs.pop('optimize', True))
+
+
+class CompilationUnit(object):
+    def __init__(self, lang=None, optimize=True):
+        self.cu = torch._C.CompilationUnit()
+        if lang is not None:
+            self.define(lang)
+        self.execution_engines = {}
+        self.optimize = optimize
+
+    def define(self, lang):
+        self.cu.define(lang)
+
+    def __getattr__(self, attr):
+        if attr not in self.execution_engines:
+            graph = self.cu.get_graph(attr)
+            self.execution_engines[attr] = torch._C.GraphExecutor(graph, self.optimize)
+        return self.execution_engines[attr]
+
+
+def script(fn):
+    ast = torch.jit.frontend.get_jit_ast(fn)
+    graph = _jit_script_compile(ast)
+    return torch._C.GraphExecutor(graph, True)
+
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index b949702..9943465 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -106,7 +106,7 @@
     if isinstance(args, torch.autograd.Variable):
         args = (args, )
 
-    trace, torch_out = torch.jit.trace(func, args)
+    trace, torch_out = torch.jit.get_trace_graph(func, args)
     _optimize_trace(trace, aten)
     if return_outs:
         return trace, torch_out
@@ -129,7 +129,7 @@
     # can turn training=True (or None, to preserve whatever the original
     # training mode was.)
     with set_training(model, training):
-        trace, torch_out = torch.jit.trace(model, args)
+        trace, torch_out = torch.jit.get_trace_graph(model, args)
 
     if orig_state_dict_keys != _unique_state_dict(model).keys():
         raise RuntimeError("state_dict changed after running the tracer; "