Re-enable and fix most JIT tests
diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect
index 93fa843..0319335 100644
--- a/test/expect/TestJit.test_alexnet.expect
+++ b/test/expect/TestJit.test_alexnet.expect
@@ -16,28 +16,31 @@
%16 : Double(1000, 4096)
%17 : Double(1000)) {
%19 : Double(10, 64, 55, 55), %20 : Handle = CppOp[ConvForward](%1, %2, %3), uses = [[%21.i0], []];
- %22 : Double(10, 64, 55, 55), %23 : Handle = ^Threshold(0, 0, True)(%19), uses = [[%24.i0], []];
- %25 : Double(10, 64, 27, 27), %26 : Long(10, 64, 27, 27), %27 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%22), uses = [[%28.i0], [], []];
- %29 : Double(10, 192, 27, 27), %30 : Handle = CppOp[ConvForward](%25, %4, %5), uses = [[%31.i0], []];
- %32 : Double(10, 192, 27, 27), %33 : Handle = ^Threshold(0, 0, True)(%29), uses = [[%34.i0], []];
- %35 : Double(10, 192, 13, 13), %36 : Long(10, 192, 13, 13), %37 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%32), uses = [[%38.i0], [], []];
- %39 : Double(10, 384, 13, 13), %40 : Handle = CppOp[ConvForward](%35, %6, %7), uses = [[%41.i0], []];
- %42 : Double(10, 384, 13, 13), %43 : Handle = ^Threshold(0, 0, True)(%39), uses = [[%44.i0], []];
- %45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%42, %8, %9), uses = [[%47.i0], []];
- %48 : Double(10, 256, 13, 13), %49 : Handle = ^Threshold(0, 0, True)(%45), uses = [[%50.i0], []];
- %51 : Double(10, 256, 13, 13), %52 : Handle = CppOp[ConvForward](%48, %10, %11), uses = [[%53.i0], []];
- %54 : Double(10, 256, 13, 13), %55 : Handle = ^Threshold(0, 0, True)(%51), uses = [[%56.i0], []];
- %57 : Double(10, 256, 6, 6), %58 : Long(10, 256, 6, 6), %59 : Handle = ^MaxPool2d(3, 2, 0, 1, False)(%54), uses = [[%60.i0], [], []];
- %61 : Double(10, 9216), %62 : Handle = ^View((10, 9216))(%57), uses = [[%63.i0], []];
- %64 : Double(10, 9216), %65 : Handle = ^Dropout(0.5, True, False)(%61), uses = [[%68.i1], []];
- %67 : Double(9216!, 4096!) = ^Transpose(0, 1)(%12), uses = [[%68.i2]];
- %69 : Double(10, 4096), %70 : Handle = ^Addmm(1, 1, False)(%13, %64, %67), uses = [[%71.i0], []];
- %72 : Double(10, 4096), %73 : Handle = ^Threshold(0, 0, True)(%69), uses = [[%74.i0], []];
- %75 : Double(10, 4096), %76 : Handle = ^Dropout(0.5, True, False)(%72), uses = [[%79.i1], []];
- %78 : Double(4096!, 4096!) = ^Transpose(0, 1)(%14), uses = [[%79.i2]];
- %80 : Double(10, 4096), %81 : Handle = ^Addmm(1, 1, False)(%15, %75, %78), uses = [[%82.i0], []];
- %83 : Double(10, 4096), %84 : Handle = ^Threshold(0, 0, True)(%80), uses = [[%87.i1], []];
- %86 : Double(4096!, 1000!) = ^Transpose(0, 1)(%16), uses = [[%87.i2]];
- %88 : Double(10, 1000), %89 : Handle = ^Addmm(1, 1, False)(%17, %83, %86), uses = [[%0.i0], []];
- return (%88);
+ %22 : Double(10, 64, 55, 55) = threshold[threshold={0}, value={0}, inplace=1](%19), uses = [[%23.i0]];
+ %24 : Double(10, 64, 27, 27), %25 : Long(10, 64, 27, 27) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%22), uses = [[%26.i0], []];
+ %27 : Double(10, 192, 27, 27), %28 : Handle = CppOp[ConvForward](%24, %4, %5), uses = [[%29.i0], []];
+ %30 : Double(10, 192, 27, 27) = threshold[threshold={0}, value={0}, inplace=1](%27), uses = [[%31.i0]];
+ %32 : Double(10, 192, 13, 13), %33 : Long(10, 192, 13, 13) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), uses = [[%34.i0], []];
+ %35 : Double(10, 384, 13, 13), %36 : Handle = CppOp[ConvForward](%32, %6, %7), uses = [[%37.i0], []];
+ %38 : Double(10, 384, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%35), uses = [[%39.i0]];
+ %40 : Double(10, 256, 13, 13), %41 : Handle = CppOp[ConvForward](%38, %8, %9), uses = [[%42.i0], []];
+ %43 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%40), uses = [[%44.i0]];
+ %45 : Double(10, 256, 13, 13), %46 : Handle = CppOp[ConvForward](%43, %10, %11), uses = [[%47.i0], []];
+ %48 : Double(10, 256, 13, 13) = threshold[threshold={0}, value={0}, inplace=1](%45), uses = [[%49.i0]];
+ %50 : Double(10, 256, 6, 6), %51 : Long(10, 256, 6, 6) = max_pool2d[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%48), uses = [[%52.i0], []];
+ %53 : Double(10, 9216) = view[size=[10, 9216]](%50), uses = [[%54.i0]];
+ %55 : Double(10, 9216), %56 : Handle = ^Dropout(0.5, True, False)(%53), uses = [[%61.i1], []];
+ %58 : Double(9216!, 4096!) = t(%12), uses = [[%61.i2]];
+ %60 : Double(10!, 4096) = expand[size=[10, 4096]](%13), uses = [[%61.i0]];
+ %62 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%60, %55, %58), uses = [[%63.i0]];
+ %64 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%62), uses = [[%65.i0]];
+ %66 : Double(10, 4096), %67 : Handle = ^Dropout(0.5, True, False)(%64), uses = [[%72.i1], []];
+ %69 : Double(4096!, 4096!) = t(%14), uses = [[%72.i2]];
+ %71 : Double(10!, 4096) = expand[size=[10, 4096]](%15), uses = [[%72.i0]];
+ %73 : Double(10, 4096) = addmm[beta={1}, alpha={1}](%71, %66, %69), uses = [[%74.i0]];
+ %75 : Double(10, 4096) = threshold[threshold={0}, value={0}, inplace=1](%73), uses = [[%80.i1]];
+ %77 : Double(4096!, 1000!) = t(%16), uses = [[%80.i2]];
+ %79 : Double(10!, 1000) = expand[size=[10, 1000]](%17), uses = [[%80.i0]];
+ %81 : Double(10, 1000) = addmm[beta={1}, alpha={1}](%79, %75, %77), uses = [[%0.i0]];
+ return (%81);
}
diff --git a/test/expect/TestJit.test_assign_traces.expect b/test/expect/TestJit.test_assign_traces.expect
index ebb830b..606f148 100644
--- a/test/expect/TestJit.test_assign_traces.expect
+++ b/test/expect/TestJit.test_assign_traces.expect
@@ -3,6 +3,6 @@
%4 : Double(10, 10!)) {
%3 : Double(10, 10) = ^MyFn()(%1), uses = [[%0.i0, %5.i0]];
---------------- stage 1 ----------------
- %6 : Double(10, 10) = ^Mul()(%3, %4), uses = [[%0.i1]];
+ %6 : Double(10, 10) = mul(%3, %4), uses = [[%0.i1]];
return (%3, %6);
}
diff --git a/test/expect/TestJit.test_backward.expect b/test/expect/TestJit.test_backward.expect
index ac243bc..5c8157b 100644
--- a/test/expect/TestJit.test_backward.expect
+++ b/test/expect/TestJit.test_backward.expect
@@ -5,19 +5,19 @@
-------- stage 2 --------
%14 : Double(2, 2!)
%15 : Double(2, 2)) {
- %4 : Double(2, 2) = ^MulConstant(2)(%2), uses = [[%5.i0, %10.i1, %16.i1]];
- %6 : Double(2, 2) = ^Mul()(%4, %1), uses = [[%0.i0]];
+ %4 : Double(2, 2) = mul[other={2}](%2), uses = [[%5.i0, %10.i1, %16.i1]];
+ %6 : Double(2, 2) = mul(%4, %1), uses = [[%0.i0]];
---------------- stage 1 ----------------
- %9 : Double(2, 2) = ^Mul()(%7, %1), uses = [[%12.i0]];
- %11 : Double(2, 2) = ^Mul()(%7, %4), uses = [[%0.i1]];
- %13 : Double(2, 2) = ^MulConstant(2)(%9), uses = [[%0.i2]];
+ %9 : Double(2, 2) = mul(%7, %1), uses = [[%12.i0]];
+ %11 : Double(2, 2) = mul(%7, %4), uses = [[%0.i1]];
+ %13 : Double(2, 2) = mul[other={2}](%9), uses = [[%0.i2]];
---------------- stage 2 ----------------
- %17 : Double(2, 2) = ^Mul()(%14, %4), uses = [[%28.i0]];
- %19 : Double(2, 2) = ^Mul()(%14, %7), uses = [[%22.i0]];
- %21 : Double(2, 2) = ^MulConstant(2)(%15), uses = [[%24.i0, %26.i0]];
- %23 : Double(2, 2) = ^MulConstant(2)(%19), uses = [[%0.i5]];
- %25 : Double(2, 2) = ^Mul()(%21, %1), uses = [[%28.i1]];
- %27 : Double(2, 2) = ^Mul()(%21, %7), uses = [[%0.i4]];
+ %17 : Double(2, 2) = mul(%14, %4), uses = [[%28.i0]];
+ %19 : Double(2, 2) = mul(%14, %7), uses = [[%22.i0]];
+ %21 : Double(2, 2) = mul[other={2}](%15), uses = [[%24.i0, %26.i0]];
+ %23 : Double(2, 2) = mul[other={2}](%19), uses = [[%0.i5]];
+ %25 : Double(2, 2) = mul(%21, %1), uses = [[%28.i1]];
+ %27 : Double(2, 2) = mul(%21, %7), uses = [[%0.i4]];
%29 : Double(2, 2) = CppOp[N5torch8autograd3AddE](%17, %25), uses = [[%0.i3]];
return (%6, %11, %13, %29, %27, %23);
}
diff --git a/test/expect/TestJit.test_backward_opaque.expect b/test/expect/TestJit.test_backward_opaque.expect
index 63f4a55..8b7f59a 100644
--- a/test/expect/TestJit.test_backward_opaque.expect
+++ b/test/expect/TestJit.test_backward_opaque.expect
@@ -1,9 +1,10 @@
graph(%1 : Double(3, 3)
%2 : Double(3, 3)
-------- stage 1 --------
- %6 : Double(3, 3)) {
- %4 : Double(3, 3), %5 : Handle = ^Cross()(%1, %2), uses = [[%0.i0], [%7.i1]];
+ %5 : Double(3, 3)) {
+ %4 : Double(3, 3) = cross[dim=-1](%1, %2), uses = [[%0.i0]];
---------------- stage 1 ----------------
- %17 : Double(3, 3), %18 : Double(3, 3), %19 : Handle = CppOp[N5torch8autograd4EvalE](%6, %5), uses = [[%0.i1], [%0.i2], []];
- return (%4, %17, %18);
+ %7 : Double(3, 3) = cross[dim=-1](%2, %5), uses = [[%0.i1]];
+ %9 : Double(3, 3) = cross[dim=-1](%5, %1), uses = [[%0.i2]];
+ return (%4, %7, %9);
}
diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect
index 80eb6b0..886f310 100644
--- a/test/expect/TestJit.test_cse.expect
+++ b/test/expect/TestJit.test_cse.expect
@@ -1,10 +1,10 @@
graph(%1 : Double(2)
%2 : Double(2)) {
- %3 : Double(2) = Add(%1, %2), uses = [%5.i0, %5.i1, %7.i1];
- %5 : Double(2) = Mul(%3, %3), uses = [%7.i0];
- %7 : Double(2) = Mul(%5, %3), uses = [%8.i0, %16.i0];
- %8 : Double(2) = Tanh(%7), uses = [%10.i0, %10.i1];
- %10 : Double(2) = Add(%8, %8), uses = [%16.i1];
- %16 : Double(2) = Add(%7, %10), uses = [%0.i0];
- return (%16);
+ %4 : Double(2) = add[alpha={1}](%1, %2), uses = [[%7.i0, %7.i1, %11.i1]];
+ %8 : Double(2) = mul(%4, %4), uses = [[%11.i0]];
+ %12 : Double(2) = mul(%8, %4), uses = [[%13.i0, %29.i0]];
+ %14 : Double(2) = tanh(%12), uses = [[%17.i0, %17.i1]];
+ %18 : Double(2) = add[alpha={1}](%14, %14), uses = [[%29.i1]];
+ %30 : Double(2) = add[alpha={1}](%12, %18), uses = [[%0.i0]];
+ return (%30);
}
diff --git a/test/expect/TestJit.test_fusion_distribute-onnx.expect b/test/expect/TestJit.test_fusion_distribute-onnx.expect
deleted file mode 100644
index 5b8a22b..0000000
--- a/test/expect/TestJit.test_fusion_distribute-onnx.expect
+++ /dev/null
@@ -1,7 +0,0 @@
-graph(%1 : Double(4, 4)
- %2 : Double(4, 4)) {
- %3 : Double(4, 4) = Add(%1, %2), uses = [%4.i0];
- %5 : Double(4!, 2), %6 : Double(4!, 2) = Split[split=[2, 2], axis=1](%3), uses = [[%7.i0], [%7.i1]];
- %7 : Double(4, 2) = Mul(%5, %6), uses = [%0.i0];
- return (%7);
-}
diff --git a/test/expect/TestJit.test_inplace_transplant.expect b/test/expect/TestJit.test_inplace_transplant.expect
index fff2e3e..afc6b4a 100644
--- a/test/expect/TestJit.test_inplace_transplant.expect
+++ b/test/expect/TestJit.test_inplace_transplant.expect
@@ -1,6 +1,6 @@
graph(%1 : Double(1)) {
- %3 : Double(1), %4 : Handle = ^Clone()(%1), uses = [[%5.i0], []];
- %6 : Double(1) = ^AddConstant(2, True)(%3), uses = [[%7.i0]];
- %8 : Double(1) = ^AddConstant(3, True)(%6), uses = [[%0.i0]];
- return (%8);
+ %3 : Double(1) = clone(%1), uses = [[%4.i0]];
+ %5 : Double(1) = add[other={2}, alpha={1}](%3), uses = [[%6.i0]];
+ %7 : Double(1) = add[other={3}, alpha={1}](%5), uses = [[%0.i0]];
+ return (%7);
}
diff --git a/test/expect/TestJit.test_python_ir.expect b/test/expect/TestJit.test_python_ir.expect
index 6848555..d91104f 100644
--- a/test/expect/TestJit.test_python_ir.expect
+++ b/test/expect/TestJit.test_python_ir.expect
@@ -1,9 +1,9 @@
graph(%1 : UNKNOWN_TYPE
%2 : UNKNOWN_TYPE) {
- %4 : Double(1) = Add[note=from_pyop, some_value=1](%1, %2), uses = [[%5.i1]];
- %6 : Double(1) = Mul[note=from_pyop, some_value=0](%1, %4), uses = [[%7.i0]];
- %8 : Double(1) = Tanh[note=from_pyop, some_value=0](%6), uses = [[%9.i0]];
- %10 : Double(1) = Sigmoid[note=from_pyop, some_value=0](%8), uses = [[%0.i0]];
+ %4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
+ %6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
+ %8 : Double(1) = tanh(%6), uses = [[%9.i0]];
+ %10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
%11 : UNKNOWN_TYPE = TensorTest[a= 1 1 1 1 [ CPUDoubleTensor{2,2} ]](), uses = [];
return (%10);
}
diff --git a/test/expect/TestJit.test_simple.expect b/test/expect/TestJit.test_simple.expect
index d621a06..97f1839 100644
--- a/test/expect/TestJit.test_simple.expect
+++ b/test/expect/TestJit.test_simple.expect
@@ -1,8 +1,8 @@
graph(%1 : Double(1)
%2 : Double(1)) {
- %3 : Double(1) = Add(%1, %2), uses = [%4.i1];
- %4 : Double(1) = Mul(%1, %3), uses = [%5.i0];
- %5 : Double(1) = Tanh(%4), uses = [%6.i0];
- %6 : Double(1) = Sigmoid(%5), uses = [%0.i0];
- return (%6);
+ %4 : Double(1) = add[alpha={1}](%1, %2), uses = [[%5.i1]];
+ %6 : Double(1) = mul(%1, %4), uses = [[%7.i0]];
+ %8 : Double(1) = tanh(%6), uses = [[%9.i0]];
+ %10 : Double(1) = sigmoid(%8), uses = [[%0.i0]];
+ return (%10);
}
diff --git a/test/test_jit.py b/test/test_jit.py
index 982bc1d..4156b98 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -33,7 +33,6 @@
return hy, cy
-@unittest.skip("JIT tests temporarily broken")
class TestJit(TestCase):
maxDiff = None
@@ -45,13 +44,10 @@
return torch.sigmoid(torch.tanh(x * (x + y)))
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
-
torch._C._jit_pass_lint(trace)
- torch._C._jit_pass_onnx(trace)
- torch._C._jit_pass_lint(trace)
-
self.assertExpected(str(trace))
+ @unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
@@ -61,12 +57,11 @@
trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
torch._C._jit_pass_lint(trace)
- torch._C._jit_pass_onnx(trace)
- torch._C._jit_pass_lint(trace)
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
+ @unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_run_lstm_fusion(self):
input = Variable(torch.randn(3, 10).cuda())
@@ -80,6 +75,7 @@
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters(), _assert_compiled=True)
self.assertEqual(z, z2)
+ @unittest.skip("Fuser is broken")
@unittest.skipIf(not torch.cuda.is_available(), "fuser requires CUDA")
def test_fusion_distribute(self):
def f(x, y):
@@ -90,9 +86,6 @@
trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace), 'raw')
- torch._C._jit_pass_onnx(trace)
- torch._C._jit_pass_lint(trace)
- self.assertExpected(str(trace), 'onnx')
torch._C._jit_pass_fuse(trace)
torch._C._jit_pass_lint(trace)
self.assertExpected(str(trace))
@@ -107,8 +100,6 @@
z = (x + y) * (x + y) * (x + y) + t
torch._C._tracer_exit((z,))
torch._C._jit_pass_lint(trace)
- torch._C._jit_pass_onnx(trace)
- torch._C._jit_pass_lint(trace)
torch._C._jit_pass_cse(trace)
self.assertExpected(str(trace))
@@ -280,19 +271,38 @@
self.assertExpected(str(trace))
def test_inplace_flags(self):
+ class InplaceFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.mark_dirty(x)
+ return x.add_(1)
+
+ @staticmethod
+ def backward(ctx, go):
+ return go
+
+ class RegularFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x.add(1)
+
+ @staticmethod
+ def backward(ctx, go):
+ return go
+
x = Variable(torch.Tensor([0]), requires_grad=True)
trace = torch._C._tracer_enter((x,), 0)
- y = x + 2
- y.add_(2)
- y.mul_(4)
- y = y * 2
+ y = RegularFn.apply(x)
+ y = InplaceFn.apply(y)
+ y = InplaceFn.apply(y)
+ y = RegularFn.apply(y)
torch._C._tracer_exit((y,))
ops = [n for n in trace.graph().nodes() if n.kind() != 'Select']
for op in ops:
- self.assertTrue(op.hasAttribute('__inplace'))
+ self.assertTrue(op.hasAttribute('inplace'))
inplace_flags = [False, True, True, False]
for op, is_inplace in zip(ops, inplace_flags):
- self.assertEqual(op.i('__inplace'), is_inplace)
+ self.assertEqual(op.i('inplace'), is_inplace)
def test_inplace_check(self):
class MyInplaceFn(Function):
@@ -548,7 +558,6 @@
assert(n_.i("some_value") == len(node.scalar_args()))
else:
n_ = g2.createClone(node, lambda x: g_to_g2[x])
- assert(n_.kindOf("Offset") == "i")
g_to_g2[node] = g2.appendNode(n_)
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index b3b5a92..8b26553 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -426,15 +426,12 @@
def bernoulli(self):
return Bernoulli.apply(self)
- def __add__(self, other):
- return self.add(other)
- __radd__ = __add__
+ __radd__ = __add__ = _C._VariableBase.add
def __iadd__(self, other):
return self.add_(other)
- def __sub__(self, other):
- return self.sub(other)
+ __sub__ = _C._VariableBase.sub
def __isub__(self, other):
return self.sub_(other)
@@ -442,9 +439,7 @@
def __rsub__(self, other):
return -self + other
- def __mul__(self, other):
- return self.mul(other)
- __rmul__ = __mul__
+ __rmul__ = __mul__ = _C._VariableBase.mul
def __imul__(self, other):
return self.mul_(other)
@@ -454,9 +449,7 @@
return NotImplemented
return self.matmul(other)
- def __div__(self, other):
- return self.div(other)
- __truediv__ = __div__
+ __truediv__ = __div__ = _C._VariableBase.div
def __rdiv__(self, other):
return self.reciprocal() * other
@@ -465,8 +458,7 @@
def __idiv__(self, other):
return self.div_(other)
- def __pow__(self, other):
- return self.pow(other)
+ __pow__ = _C._VariableBase.pow
def __ipow__(self, other):
raise NotImplementedError("in-place pow not implemented")
diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
index d1d6e51..9f8fd92 100644
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ b/torch/csrc/autograd/functions/jit_closure.cpp
@@ -11,6 +11,7 @@
#include "torch/csrc/autograd/python_engine.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/python_function.h"
+#include "torch/csrc/jit/generated/aten_dispatch.h"
#ifdef WITH_CUDA
#include "torch/csrc/jit/fusion_compiler.h"
#endif
@@ -115,20 +116,28 @@
};
};
-// A hack that will let us implement some of the ops we care
-// about before the major Python -> C++ Function migration
struct LambdaFunction : public Function {
+ LambdaFunction(const jit::TensorOp& op)
+ : LambdaFunction(op.num_inputs, op.op) {
+ this->name_ = op.name;
+ }
+
LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
- : fn(fn) {
+ : fn_(fn) {
this->is_executable = true;
this->num_inputs = num_inputs;
}
- virtual variable_list apply(const variable_list& inputs) {
- return fn(inputs);
+ virtual std::string name() override {
+ return name_.size() == 0 ? "LambdaFunction" : name_;
}
- std::function<variable_list(const variable_list&)> fn;
+ virtual variable_list apply(const variable_list& inputs) override {
+ return fn_(inputs);
+ }
+
+ std::string name_;
+ std::function<variable_list(const variable_list&)> fn_;
};
// Wraps a PythonOp and dispatches calls to Functions implemented in Python
@@ -583,7 +592,7 @@
IR_ELSEIF(Concat)
return std::make_shared<torch::autograd::Cat>(value->i(kaxis));
IR_ELSE()
- throw std::runtime_error(std::string("unrecognized NodeKind: ") + symbolToString(node->kind()));
+ return std::make_shared<LambdaFunction>(getTensorOp(node));
IR_END()
}
@@ -671,7 +680,7 @@
// Roots for a call to the engine. The list contains function in this order:
// [ apply input roots | prev stage input roots | constant factory ]
function_list roots;
- std::vector<VariableFlags> var_flags;
+ std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>> var_flags;
// Output node
std::shared_ptr<Function> output;
@@ -703,15 +712,14 @@
};
AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc)
- : AutogradClosure(desc, 0, {}) {}
+ : AutogradClosure(desc, 0) {}
// TODO: there's a lot processing involved in creating a new AutogradClosure instance,
// so it might be worth to keep a pool of unused instances (or at least their attrs)
// for all stages. We can't save saved_vars and saved_handles, but all callbacks
// can be made reusable.
-AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags &&f)
- : Function(std::move(f))
- , desc(desc)
+AutogradClosure::AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage)
+ : desc(desc)
, stage(stage) {
auto & stage_desc = desc->stages[stage];
@@ -777,10 +785,10 @@
// Validate inputs
auto num_inputs = inputs.size();
- if (num_inputs != stage_closure.var_flags.size())
+ if (num_inputs != stage_closure.var_flags.first.size())
throw std::runtime_error("AutogradClosure received an incorrect number of inputs");
for (std::size_t i = 0; i < num_inputs; ++i) {
- auto & flags = stage_closure.var_flags[i];
+ auto & flags = stage_closure.var_flags.first[i];
if (!flags.verify(inputs[i]))
throw std::runtime_error("AutogradClosure received inputs with different flags");
}
@@ -797,16 +805,15 @@
auto& engine = python::PythonEngine::getDefaultEngine();
engine.execute(stage_closure.roots, input_leaves, true, pre_callbacks, post_callbacks);
- // See Note [Null-edge pruning]
- auto relevant_inputs = filter(inputs, [](const Variable& var) { return var.defined() && var.requires_grad(); });
- auto result = wrap_outputs(relevant_inputs, std::move(outputs), [this](FunctionFlags f) -> std::shared_ptr<Function> {
+ // Create the backward function lazily
+ auto make_grad_fn = [this]() -> std::shared_ptr<Function> {
if (this->stage == this->desc->stages.size() - 1) {
std::string msg = "JIT closure compiled only for ";
msg += std::to_string(this->stage);
msg += " backwards";
- return std::make_shared<Error>(std::move(msg), std::move(f));
+ return std::make_shared<Error>(std::move(msg));
}
- auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1, std::move(f)));
+ auto bw_fn = std::shared_ptr<AutogradClosure>(new AutogradClosure(this->desc, this->stage + 1));
// TODO: don't make a full copy of saved_* - copy only the things that bw needs
bw_fn->saved_vars = this->saved_vars;
bw_fn->saved_vars.insert(std::make_move_iterator(this->captured_vars.begin()),
@@ -824,7 +831,33 @@
// was run, so it must have been executable).
bw_fn->is_executable = true;
return bw_fn;
- });
+ };
+
+ // See Note [Null-edge pruning]
+ variable_list result;
+ auto num_outputs = outputs.size();
+ std::shared_ptr<Function> grad_fn;
+ JIT_ASSERT(outputs.size() == stage_closure.var_flags.second.size());
+ for (std::size_t i = 0; i < num_outputs; ++i) {
+ auto & flags = stage_closure.var_flags.second[i];
+ if (flags.requires_grad) {
+ if (!grad_fn) grad_fn = make_grad_fn();
+ result.push_back(make_variable(outputs[i], grad_fn));
+ } else {
+ result.push_back(make_variable(outputs[i], flags.requires_grad, flags.is_volatile));
+ }
+ }
+
+ // If we created grad_fn for any of the outputs, we also need to fill in next_functions
+ if (grad_fn) {
+ for (auto & input : inputs) {
+ if (!input.requires_grad()) continue;
+ grad_fn->next_functions.emplace_back(
+ input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
+ input.output_nr());
+ }
+ }
+
captured_vars.clear();
captured_handles.clear();
outputs.clear();
diff --git a/torch/csrc/autograd/functions/jit_closure.h b/torch/csrc/autograd/functions/jit_closure.h
index 25f2ca2..6d905e6 100644
--- a/torch/csrc/autograd/functions/jit_closure.h
+++ b/torch/csrc/autograd/functions/jit_closure.h
@@ -28,7 +28,7 @@
virtual variable_list apply(const variable_list& inputs) override;
private:
- AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage, FunctionFlags&& f);
+ AutogradClosure(const std::shared_ptr<MultiStageClosure>& desc, std::size_t stage);
variable_list rewrapInputs(const variable_list& inputs);
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index b6335ed..29a7a41 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -711,7 +711,7 @@
sel->inferTypeFrom(output.data());
tracer::setValueTrace(tracing_state, output, sel);
}
- this_expr->i_(k__inplace, is_inplace);
+ this_expr->i_(kinplace, is_inplace);
// See definition in function.cpp.
THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 0836295..f6542c7 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -64,7 +64,7 @@
_(shape) \
_(axes) \
_(group) \
-_(__inplace)
+_(inplace)
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) \
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 0420ad1..45b9408 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -177,7 +177,17 @@
case AttributeKind::t:
{
at::Tensor t = n->t(name);
- if (t.numel() <= max_tensor_display_size) {
+ // 1-elem tensors are usually boxed scalars, so print them like it
+ if (t.numel() == 1) {
+ auto scalar = at::Scalar(t.view({})).local();
+ out << "{";
+ if (scalar.isFloatingPoint()) {
+ out << scalar.toDouble();
+ } else {
+ out << scalar.toLong();
+ }
+ out << "}";
+ } else if (t.numel() <= max_tensor_display_size) {
// TODO: This is awful code. Also it doesn't work on Windows.
std::ostringstream tensor_ss;
tensor_ss << t;
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index 93d9b08..669edba 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -8,6 +8,16 @@
namespace torch { namespace jit {
+namespace {
+
+bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
+ return &lhs.type() == &rhs.type() && lhs.equal(rhs);
+}
+
+bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
+ if (lhs.size() != rhs.size()) return false;
+ return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
+};
// Check whether two nodes have the same attributes in CSE.
@@ -24,6 +34,8 @@
auto lnames = lhs->attributeNames();
auto rnames = rhs->attributeNames();
+ std::sort(lnames.begin(), lnames.end());
+ std::sort(rnames.begin(), rnames.end());
if (lnames != rnames) return false;
for (auto name : lnames) {
@@ -40,8 +52,13 @@
COMPARE_ATTRIBUTEVALUE(is)
COMPARE_ATTRIBUTEVALUE(s)
COMPARE_ATTRIBUTEVALUE(ss)
+ case AttributeKind::t:
+ if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
+ break;
+ case AttributeKind::ts:
+ if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
default:
- // NB: Comparison of nodes with tensor(s) or graph(s) will return false.
+ // NB: Comparison of nodes with graph(s) will return false.
return false;
}
@@ -92,6 +109,8 @@
}
};
+} // anonymous namespace
+
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 89f298a..f566c9a 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -56,7 +56,7 @@
setValueTrace(tracing_state, input, input_node);
input_node->inferTypeFrom(input.data());
}
- tracing_state->var_flags.at(graph->stage()) = detail::getVarFlags(inputs);
+ tracing_state->var_flags.at(graph->stage()).first = detail::getVarFlags(inputs);
}
void exitTrace(const variable_list& inputs, const variable_list& outputs) {
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 62e283d..ce832e3 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -200,7 +200,7 @@
}
}
// TODO: this might not work with the way we handle buffers
- state->var_flags[0] = detail::getVarFlags(inputs);
+ state->var_flags[0].first = detail::getVarFlags(inputs);
state->active = true;
state->inputs = inputs;
return state;
@@ -214,6 +214,7 @@
state->graph->registerOutput(getValueTrace(state, output, true));
}
state->active = false;
+ state->var_flags[state->graph->stage()].second = detail::getVarFlags(outputs);
}
// Marks a backwards subgraph that should be traced as the next stage.
diff --git a/torch/csrc/jit/tracer_state.h b/torch/csrc/jit/tracer_state.h
index 8b383ca..3c2c32d 100644
--- a/torch/csrc/jit/tracer_state.h
+++ b/torch/csrc/jit/tracer_state.h
@@ -64,7 +64,8 @@
// TODO: Perhaps, turn this into an owning reference. The buffers
// are persistent, so this won't lead to a leak.
std::unordered_map<void*, Node*> buffer_map;
- std::vector<std::vector<VariableFlags>> var_flags;
+ // A pair of (input_flags, output_flags) for each stage
+ std::vector<std::pair<std::vector<VariableFlags>, std::vector<VariableFlags>>> var_flags;
std::vector<function_list> output_edges;
std::mutex mutex;
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 26b8591..1d7f442 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -469,7 +469,6 @@
# It's important to always run DCE, because backward can create a lot of unnecessary nodes
_run_pass(torch._C._jit_pass_dce, complete_trace)
- _run_pass(torch._C._jit_pass_onnx, complete_trace)
_run_pass(_passes._check_inplace, complete_trace)
if self.optimize:
_run_pass(torch._C._jit_pass_fuse, complete_trace)
diff --git a/torch/jit/passes/inplace.py b/torch/jit/passes/inplace.py
index 83246cd..0ea9109 100644
--- a/torch/jit/passes/inplace.py
+++ b/torch/jit/passes/inplace.py
@@ -7,5 +7,5 @@
graph = trace.graph()
for node in graph.nodes():
if node.kind() == 'PythonOp':
- if node.i('__inplace'):
+ if node.i('inplace'):
raise RuntimeError("inplace {} not supported in the JIT".format(node.pyname()))