Auto-batching IR transformation for control flow (#9392)
Summary:
Implement the IR transformation for control flow:
- `prim::Constant`: clone to new graph directly
- `prim::NumToTensor`: create a `BatchTensor` from output tensor with `batch_size = 1`
- `prim::TensorToNum`: clone to new graph
- `prim::ListConstruct`: clone to new graph
- `prim::If`: execute both `if_block` and `else_block` and combine results from them using `cond`
- `prim::Loop`:
  - for loop: `cond` stays a plain scalar; carried batched tensors are updated each iteration via the registered `update` batch operator
  - while loop: change the while `cond` to `cond_any` (true while any example is still active), and use the per-example `cond` to update outputs
Test cases: hand-written LSTM, greedy search, beam search (a minimal usage sketch follows).
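A minimal usage sketch, mirroring the new `test_if_else`/`test_while` tests below (the batched inputs are `BatchTensor`s built as in the tests' `rand_batch` helper; `batch_a`/`batch_b` are placeholders for such inputs):

    import torch

    def single_if(a, b):
        # per-example control flow: each example in the batch takes its own branch
        if a > b:
            a = a + b
        else:
            a = a - b
        return a

    # run the same function over a batch of 4 examples
    batch_if = torch.jit.batch(batch_size=4)(single_if)
    # res_batch = batch_if(batch_a, batch_b)   # batch_a, batch_b are BatchTensors

    # inspect the IR produced by the auto-batching transformation
    script_if = torch.jit.script(single_if)
    print(torch.to_batch_graph(script_if.graph))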
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9392
Differential Revision: D8822369
Pulled By: ChunliF
fbshipit-source-id: 8f03c95757d32e8c4580eeab3974fd1bc429a1e5
diff --git a/test/expect/TestBatched.test_for.expect b/test/expect/TestBatched.test_for.expect
new file mode 100644
index 0000000..bcbcffa
--- /dev/null
+++ b/test/expect/TestBatched.test_for.expect
@@ -0,0 +1,22 @@
+graph(%x.1_data : Dynamic
+ %x.1_mask : Dynamic
+ %x.1_dims : Dynamic
+ %y_data : Dynamic
+ %y_mask : Dynamic
+ %y_dims : Dynamic) {
+ %6 : int = prim::Constant[value=10]()
+ %7 : int = prim::Constant[value=1]()
+ %x : Dynamic, %21 : Dynamic, %22 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims)
+ block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) {
+ %13 : int = prim::Constant[value=1]()
+ %14 : Long() = prim::NumToTensor(%13)
+ %alpha : float = prim::TensorToNum(%14)
+ %data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha)
+ %mask : Dynamic = aten::mul(%5_mask, %y_mask)
+ %dims : Dynamic = aten::__or__(%5_dims, %y_dims)
+ %19 : int = prim::Constant[value=1]()
+ %data : Dynamic = aten::where(%mask, %data.1, %5_data)
+ -> (%19, %data, %mask, %dims)
+ }
+ return (%x, %21, %22);
+}
diff --git a/test/expect/TestBatched.test_if_else.expect b/test/expect/TestBatched.test_if_else.expect
new file mode 100644
index 0000000..0698584
--- /dev/null
+++ b/test/expect/TestBatched.test_if_else.expect
@@ -0,0 +1,52 @@
+graph(%a.1_data : Dynamic
+ %a.1_mask : Dynamic
+ %a.1_dims : Dynamic
+ %b_data : Dynamic
+ %b_mask : Dynamic
+ %b_dims : Dynamic) {
+ %6 : Dynamic = aten::gt(%a.1_data, %b_data)
+ %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %9 : int = prim::TensorToNum(%6)
+ %10 : int = prim::Constant[value=1]()
+ %11 : Long() = prim::NumToTensor(%10)
+ %alpha.1 : float = prim::TensorToNum(%11)
+ %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
+ %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %16 : int = prim::Constant[value=1]()
+ %17 : Long() = prim::NumToTensor(%16)
+ %alpha : float = prim::TensorToNum(%17)
+ %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
+ %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %22 : Dynamic = aten::type_as(%7, %6)
+ %cond_mask.1 : Dynamic = aten::mul(%6, %22)
+ %24 : int = aten::dim(%cond_mask.1)
+ %25 : int = prim::Constant[value=1]()
+ %26 : int = aten::eq(%24, %25)
+ %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
+ block0() {
+ %30 : int = aten::dim(%data.1)
+ %31 : int = prim::Constant[value=1]()
+ %32 : int = aten::sub(%30, %31)
+ %33 : int = prim::Constant[value=1]()
+ %data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1)
+ block0(%_ : int, %36 : Dynamic) {
+ %37 : int = aten::dim(%36)
+ %data.2 : Dynamic = aten::unsqueeze(%36, %37)
+ %39 : int = prim::Constant[value=1]()
+ -> (%39, %data.2)
+ }
+ %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+ %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
+ -> (%cond_data.1, %cond_mask.2, %data.3)
+ }
+ block1() {
+ -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+ }
+ %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
+ %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
+ %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
+ return (%res_data, %res_mask, %res_dims);
+}
diff --git a/test/expect/TestBatched.test_if_else_with_scalar.expect b/test/expect/TestBatched.test_if_else_with_scalar.expect
new file mode 100644
index 0000000..c7755a5
--- /dev/null
+++ b/test/expect/TestBatched.test_if_else_with_scalar.expect
@@ -0,0 +1,53 @@
+graph(%a.1_data : Dynamic
+ %a.1_mask : Dynamic
+ %a.1_dims : Dynamic
+ %b_data : Dynamic
+ %b_mask : Dynamic
+ %b_dims : Dynamic) {
+ %6 : float = prim::Constant[value=0.1]()
+ %7 : Float() = prim::NumToTensor(%6)
+ %other : float = prim::TensorToNum(%7)
+ %9 : Dynamic = aten::gt(%a.1_data, %other)
+ %10 : int = prim::TensorToNum(%9)
+ %11 : int = prim::Constant[value=1]()
+ %12 : Long() = prim::NumToTensor(%11)
+ %alpha.1 : float = prim::TensorToNum(%12)
+ %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
+ %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %17 : int = prim::Constant[value=1]()
+ %18 : Long() = prim::NumToTensor(%17)
+ %alpha : float = prim::TensorToNum(%18)
+ %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
+ %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %23 : Dynamic = aten::type_as(%a.1_mask, %9)
+ %cond_mask.1 : Dynamic = aten::mul(%9, %23)
+ %25 : int = aten::dim(%cond_mask.1)
+ %26 : int = prim::Constant[value=1]()
+ %27 : int = aten::eq(%25, %26)
+ %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
+ block0() {
+ %31 : int = aten::dim(%data.1)
+ %32 : int = prim::Constant[value=1]()
+ %33 : int = aten::sub(%31, %32)
+ %34 : int = prim::Constant[value=1]()
+ %data.3 : Dynamic = prim::Loop(%33, %34, %cond_mask.1)
+ block0(%_ : int, %37 : Dynamic) {
+ %38 : int = aten::dim(%37)
+ %data.2 : Dynamic = aten::unsqueeze(%37, %38)
+ %40 : int = prim::Constant[value=1]()
+ -> (%40, %data.2)
+ }
+ %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+ %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
+ -> (%cond_data.1, %cond_mask.2, %data.3)
+ }
+ block1() {
+ -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+ }
+ %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
+ %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
+ %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
+ return (%res_data, %res_mask, %res_dims);
+}
diff --git a/test/expect/TestBatched.test_if_noelse.expect b/test/expect/TestBatched.test_if_noelse.expect
new file mode 100644
index 0000000..1d98fe9
--- /dev/null
+++ b/test/expect/TestBatched.test_if_noelse.expect
@@ -0,0 +1,46 @@
+graph(%a.1_data : Dynamic
+ %a.1_mask : Dynamic
+ %a.1_dims : Dynamic
+ %b_data : Dynamic
+ %b_mask : Dynamic
+ %b_dims : Dynamic) {
+ %6 : Dynamic = aten::gt(%a.1_data, %b_data)
+ %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %9 : int = prim::TensorToNum(%6)
+ %10 : int = prim::Constant[value=1]()
+ %11 : Long() = prim::NumToTensor(%10)
+ %alpha : float = prim::TensorToNum(%11)
+ %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
+ %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %16 : Dynamic = aten::type_as(%7, %6)
+ %cond_mask.1 : Dynamic = aten::mul(%6, %16)
+ %18 : int = aten::dim(%cond_mask.1)
+ %19 : int = prim::Constant[value=1]()
+ %20 : int = aten::eq(%18, %19)
+ %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20)
+ block0() {
+ %24 : int = aten::dim(%data.1)
+ %25 : int = prim::Constant[value=1]()
+ %26 : int = aten::sub(%24, %25)
+ %27 : int = prim::Constant[value=1]()
+ %data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1)
+ block0(%_ : int, %30 : Dynamic) {
+ %31 : int = aten::dim(%30)
+ %data.2 : Dynamic = aten::unsqueeze(%30, %31)
+ %33 : int = prim::Constant[value=1]()
+ -> (%33, %data.2)
+ }
+ %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+ %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
+ -> (%cond_data.1, %cond_mask.2, %data.3)
+ }
+ block1() {
+ -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+ }
+ %res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data)
+ %res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask)
+ %res_dims : Dynamic = aten::__or__(%dims, %a.1_dims)
+ return (%res_data, %res_mask, %res_dims);
+}
diff --git a/test/expect/TestBatched.test_if_noelse_with_scalar.expect b/test/expect/TestBatched.test_if_noelse_with_scalar.expect
new file mode 100644
index 0000000..935bedb
--- /dev/null
+++ b/test/expect/TestBatched.test_if_noelse_with_scalar.expect
@@ -0,0 +1,47 @@
+graph(%a.1_data : Dynamic
+ %a.1_mask : Dynamic
+ %a.1_dims : Dynamic
+ %b_data : Dynamic
+ %b_mask : Dynamic
+ %b_dims : Dynamic) {
+ %6 : float = prim::Constant[value=0.1]()
+ %7 : Float() = prim::NumToTensor(%6)
+ %other : float = prim::TensorToNum(%7)
+ %9 : Dynamic = aten::gt(%a.1_data, %other)
+ %10 : int = prim::TensorToNum(%9)
+ %11 : int = prim::Constant[value=1]()
+ %12 : Long() = prim::NumToTensor(%11)
+ %alpha : float = prim::TensorToNum(%12)
+ %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha)
+ %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %17 : Dynamic = aten::type_as(%a.1_mask, %9)
+ %cond_mask.1 : Dynamic = aten::mul(%9, %17)
+ %19 : int = aten::dim(%cond_mask.1)
+ %20 : int = prim::Constant[value=1]()
+ %21 : int = aten::eq(%19, %20)
+ %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
+ block0() {
+ %25 : int = aten::dim(%data.1)
+ %26 : int = prim::Constant[value=1]()
+ %27 : int = aten::sub(%25, %26)
+ %28 : int = prim::Constant[value=1]()
+ %data.3 : Dynamic = prim::Loop(%27, %28, %cond_mask.1)
+ block0(%_ : int, %31 : Dynamic) {
+ %32 : int = aten::dim(%31)
+ %data.2 : Dynamic = aten::unsqueeze(%31, %32)
+ %34 : int = prim::Constant[value=1]()
+ -> (%34, %data.2)
+ }
+ %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+ %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
+ -> (%cond_data.1, %cond_mask.2, %data.3)
+ }
+ block1() {
+ -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+ }
+ %res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data)
+ %res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask)
+ %res_dims : Dynamic = aten::__or__(%dims, %a.1_dims)
+ return (%res_data, %res_mask, %res_dims);
+}
diff --git a/test/expect/TestBatched.test_while.expect b/test/expect/TestBatched.test_while.expect
new file mode 100644
index 0000000..a32cd39
--- /dev/null
+++ b/test/expect/TestBatched.test_while.expect
@@ -0,0 +1,65 @@
+graph(%a.1_data : Dynamic
+ %a.1_mask : Dynamic
+ %a.1_dims : Dynamic
+ %b_data : Dynamic
+ %b_mask : Dynamic
+ %b_dims : Dynamic) {
+ %6 : int = prim::Constant[value=2147483647]()
+ %7 : Dynamic = aten::gt(%a.1_data, %b_data)
+ %8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+ %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+ %10 : int = prim::TensorToNum(%7)
+ %11 : Dynamic = aten::mul(%7, %8)
+ %12 : Dynamic = aten::sum(%11)
+ %13 : int = prim::Constant[value=0]()
+ %14 : Dynamic = aten::gt(%12, %13)
+ %15 : int = prim::TensorToNum(%14)
+ %64 : Dynamic, %65 : Dynamic, %66 : Dynamic, %a : Dynamic, %62 : Dynamic, %63 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
+ block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
+ %24 : int = prim::Constant[value=1]()
+ %25 : Long() = prim::NumToTensor(%24)
+ %alpha : float = prim::TensorToNum(%25)
+ %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
+ %mask : Dynamic = aten::mul(%6_mask, %b_mask)
+ %dims : Dynamic = aten::__or__(%6_dims, %b_dims)
+ %30 : Dynamic = aten::gt(%data.1, %b_data)
+ %31 : Dynamic = aten::mul(%mask, %b_mask)
+ %32 : Dynamic = aten::__or__(%dims, %b_dims)
+ %33 : int = prim::TensorToNum(%30)
+ %34 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
+ %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %34)
+ %36 : int = aten::dim(%cond_mask.1)
+ %37 : int = prim::Constant[value=1]()
+ %38 : int = aten::eq(%36, %37)
+ %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%38)
+ block0() {
+ %42 : int = aten::dim(%data.1)
+ %43 : int = prim::Constant[value=1]()
+ %44 : int = aten::sub(%42, %43)
+ %45 : int = prim::Constant[value=1]()
+ %data.3 : Dynamic = prim::Loop(%44, %45, %cond_mask.1)
+ block0(%_ : int, %48 : Dynamic) {
+ %49 : int = aten::dim(%48)
+ %data.2 : Dynamic = aten::unsqueeze(%48, %49)
+ %51 : int = prim::Constant[value=1]()
+ -> (%51, %data.2)
+ }
+ %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+ %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
+ -> (%cond_data.1, %cond_mask.2, %data.3)
+ }
+ block1() {
+ -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+ }
+ %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
+ %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
+ %res_dims : Dynamic = aten::__or__(%dims, %6_dims)
+ %57 : Dynamic = aten::mul(%30, %31)
+ %58 : Dynamic = aten::sum(%57)
+ %59 : int = prim::Constant[value=0]()
+ %60 : Dynamic = aten::gt(%58, %59)
+ %61 : int = prim::TensorToNum(%60)
+ -> (%61, %30, %31, %32, %res_data, %res_mask, %res_dims)
+ }
+ return (%a, %62, %63);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index 6e73242..3a78dcb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1197,8 +1197,9 @@
# generate random examples and create an batchtensor with them
def rand_batch(self, *dims):
dims = [dim for dim in dims if dim != ()]
- xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:])) for i in range(dims[0])]
- xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]))
+ xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
+ requires_grad=True) for i in range(dims[0])]
+ xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
return xs, xb
def test_create_batchtensor(self):
@@ -1226,20 +1227,20 @@
def test_batch_elementwise_binary(self):
@torch.jit.batch(batch_size=4)
- def mul(a, b):
- return a * b
+ def add(a, b):
+ return a + b
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
xs2, batch2 = xs, batch
- res_batch = mul(batch, batch2)
- res = [torch.mul(xs[j], xs2[j]) for j in range(4)]
+ res_batch = add(batch, batch2)
+ res = [torch.add(xs[j], xs2[j]) for j in range(4)]
self.assertEqual(res, res_batch.examples())
# test broadcast
xs, batch = self.rand_batch(4, (False, 3), (False, 2))
b = torch.rand(3, 2)
- res_batch = mul(batch, b)
- res = [torch.mul(xs[j], b) for j in range(4)]
+ res_batch = add(batch, b)
+ res = [torch.add(xs[j], b) for j in range(4)]
self.assertEqual(res, res_batch.examples())
def test_batch_mm(self):
@@ -1286,6 +1287,33 @@
xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
matmul_test(xs, batch, xs2, batch2)
+ def test_batch_select(self):
+ @torch.jit.batch(batch_size=4)
+ def select(x):
+ return torch.select(x, 1, 0)
+
+ xs, batch = self.rand_batch(4, (True, 3), (True, 2))
+ res_batch = select(batch)
+ res = [torch.select(xs[j], 1, 0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ xs, batch = self.rand_batch(4, (False, 3), (True, 2))
+ res_batch = select(batch)
+ res = [torch.select(xs[j], 1, 0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_batch_index_select(self):
+ @torch.jit.batch(batch_size=4)
+ def index_select(x, ind):
+ return x.index_select(1, ind)
+
+ xs, batch = self.rand_batch(4, (False, 5), (True, 2))
+ ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
+ ind_batch = BatchTensor(ind, torch.tensor([]).byte())
+ res_batch = index_select(batch, ind_batch)
+ res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
def test_batch_where(self):
@torch.jit.batch(batch_size=4)
def where(c, a, b):
@@ -1302,41 +1330,238 @@
res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
self.assertEqual(res, res_batch.examples())
- @unittest.skip("Need support for scalar arguments")
- def test_lstm_cell(self):
- def LSTMCell(x, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- return h_t
+ def test_batch_argmax(self):
+ @torch.jit.batch(batch_size=4)
+ def argmax(a):
+ return torch.argmax(a, 1)
+
+ xs, batch = self.rand_batch(4, (True, 5), (True, 6))
+ res_batch = argmax(batch)
+ res = [torch.argmax(xs[j], 1) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
@torch.jit.batch(batch_size=4)
- def LSTMCell_batch(x, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- return h_t
+ def argmax(a):
+ return torch.argmax(a, 1, False)
+
+ res_batch = argmax(batch)
+ res = [torch.argmax(xs[j], 1, False) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_batch_topk(self):
+ @torch.jit.batch(batch_size=4)
+ def topk(a):
+ return torch.topk(a, 3, 1)
+
+ xs, batch = self.rand_batch(4, (False, 5), (True, 6))
+
+ # along static dim
+ res_batch = topk(batch)
+ res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
+ res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
+ self.assertEqual(res, res_batch[0].examples())
+ self.assertEqual(res_idx, res_batch[1].examples())
+
+ @torch.jit.batch(batch_size=4)
+ def topk(a):
+ return torch.topk(a, 1, 2)
+
+ # along dynamic dim
+ res_batch = topk(batch)
+ res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
+ res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
+ self.assertEqual(res, res_batch[0].examples())
+ self.assertEqual(res_idx, res_batch[1].examples())
+
+ def test_batch_softmax(self):
+ @torch.jit.batch(batch_size=4)
+ def softmax(a):
+ return torch.softmax(a, 1)
+
+ xs, batch = self.rand_batch(4, (False, 5), (True, 6))
+
+ # along static dim
+ res_batch = softmax(batch)
+ res = [torch.softmax(xs[j], 1) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ @torch.jit.batch(batch_size=4)
+ def softmax(a):
+ return torch.softmax(a, 2)
+
+ # along dynamic dim
+ res_batch = softmax(batch)
+ res = [torch.softmax(xs[j], 2) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_batch_view(self):
+ @torch.jit.batch(batch_size=4)
+ def view(a):
+ return a.view([4, -1, 3])
+
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ res_batch = view(batch)
+ res = [xs[j].view([1, -1, 3]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_batch_cat(self):
+ @torch.jit.batch(batch_size=4)
+ def cat2(a, b):
+ return torch.cat([a, b], 2)
+
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ xs2, batch2 = xs, batch
+ res_batch = cat2(batch, batch2)
+ res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_batch_sum(self):
+ @torch.jit.batch(batch_size=4)
+ def batch_sum(a):
+ return a.sum()
+
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ res_batch = batch_sum(batch)
+ res = [xs[j].sum().unsqueeze(0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ def test_if_else(self):
+ def single_if(a, b):
+ if a > b:
+ a = a + b
+ else:
+ a = a - b
+ return a
+
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
+
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(str(graph))
+
+ def test_if_else_with_scalar(self):
+ def single_if(a, b):
+ if a > 0.1:
+ a = a + b
+ else:
+ a = a - b
+ return a
+
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
+
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(str(graph))
+
+ def test_if_noelse(self):
+ def single_if(a, b):
+ if a > b:
+ a = a + b
+ return a
+
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
+
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(str(graph))
+
+ def test_if_noelse_with_scalar(self):
+ def single_if(a, b):
+ if a > 0.1:
+ a = a + b
+ return a
+
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
+
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(str(graph))
+
+ def test_while(self):
+ def single_while(a, b):
+ while a > b:
+ a = a - b
+ return a
+
+ batch_while = torch.jit.batch(batch_size=4)(single_while)
+
+ a, batch_a = self.rand_batch(4, ())
+ b = [torch.abs(torch.rand(1)) for i in range(4)]
+ batch_b = BatchTensor(b, torch.tensor([]).byte())
+ res_batch = batch_while(batch_a, batch_b)
+ res = [single_while(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_while = torch.jit.script(single_while)
+ graph = torch.to_batch_graph(script_while.graph)
+ self.assertExpected(str(graph))
+
+ def test_for(self):
+ def single_for(x, y):
+ for _ in range(10):
+ x = x + y
+ return x
+
+ batch_for = torch.jit.batch(batch_size=4)(single_for)
+
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_for(batch_a, batch_b)
+ res = [single_for(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
+
+ script_for = torch.jit.script(single_for)
+ graph = torch.to_batch_graph(script_for.graph)
+ self.assertExpected(str(graph))
+
+ def test_lstm(self):
+ def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
+ for i in range(x_all.size(1)):
+ x = x_all.select(1, i)
+ i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+ f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+ o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+ # activations
+ i_t = torch.sigmoid(i_t)
+ f_t = torch.sigmoid(f_t)
+ o_t = torch.sigmoid(o_t)
+ # cell computations
+ c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+ c_t = torch.tanh(c_t)
+ c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+ h_t = torch.mul(o_t, torch.tanh(c_t))
+ h = h_t
+ c = c_t
+ return h
+
+ LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
batch_size, input_size, hidden_size = 4, 3, 2
- xs, batch = self.rand_batch(batch_size, (False, input_size))
+ xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
@@ -1356,10 +1581,161 @@
b_o = torch.rand(hidden_size)
b_c = torch.rand(hidden_size)
- ys = [LSTMCell(xs[j].squeeze(0), hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
- ybs = LSTMCell_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
+ ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
+ ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
+ self.assertEqual(ys, ybs.examples())
+
+ def test_greedy_search(self):
+ def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
+ b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
+ iter_count = torch.zeros_like(iter_num)
+ while(iter_count < iter_num):
+ iter_count += 1
+ # LSTM Cell
+ i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+ f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+ o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+ # activations
+ i_t = torch.sigmoid(i_t)
+ f_t = torch.sigmoid(f_t)
+ o_t = torch.sigmoid(o_t)
+ # cell computations
+ c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+ c_t = torch.tanh(c_t)
+ c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+ h_t = torch.mul(o_t, torch.tanh(c_t))
+ h = h_t
+ c = c_t
+ # calculate feature with max probability
+ s_t = torch.matmul(h_t, w_hs) + b_s
+ p_t = torch.softmax(s_t, 1)
+ i_t = torch.argmax(p_t, 1)
+ x = embed.index_select(1, i_t).squeeze(1)
+ return h
+
+ greedy_batch = torch.jit.batch(batch_size=4)(greedy)
+
+ batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
+ xs, batch = self.rand_batch(batch_size, (False, input_size))
+ hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
+ cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
+ embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
+ iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
+ iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
+
+ # input to hidden weights
+ w_xi = torch.rand(input_size, hidden_size)
+ w_xf = torch.rand(input_size, hidden_size)
+ w_xo = torch.rand(input_size, hidden_size)
+ w_xc = torch.rand(input_size, hidden_size)
+ # hidden to hidden weights
+ w_hi = torch.rand(hidden_size, hidden_size)
+ w_hf = torch.rand(hidden_size, hidden_size)
+ w_ho = torch.rand(hidden_size, hidden_size)
+ w_hc = torch.rand(hidden_size, hidden_size)
+ # bias terms
+ b_i = torch.rand(hidden_size)
+ b_f = torch.rand(hidden_size)
+ b_o = torch.rand(hidden_size)
+ b_c = torch.rand(hidden_size)
+ # hidden to vocab weights, bias
+ w_hs = torch.rand(hidden_size, vocab_size)
+ b_s = torch.rand(vocab_size)
+
+ ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
+ ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
+ self.assertEqual(ys, ybs.examples())
+
+ def test_beam_search(self):
+ def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
+ b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
+ k = 5
+ vocab_size = embed.size(1)
+ iter_count = torch.zeros_like(iter_num)
+ max_len = idx.size(2)
+ while(iter_count < iter_num):
+ iter_count += 1
+ # LSTM Cell
+ i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+ f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+ o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+ # activations
+ i_t = torch.sigmoid(i_t)
+ f_t = torch.sigmoid(f_t)
+ o_t = torch.sigmoid(o_t)
+ # cell computations
+ c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+ c_t = torch.tanh(c_t)
+ c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+ h_t = torch.mul(o_t, torch.tanh(c_t))
+ h = h_t
+ c = c_t
+ # calculate features with max probability
+ s_t = torch.matmul(h_t, w_hs) + b_s
+ s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
+ p_t = torch.softmax(s_t, 1)
+ prob_t, idx_t = torch.topk(p_t, k, 1)
+ if(int(idx_t.dim()) > 1):
+ idx_t_tmp = idx_t.squeeze(0)
+ else:
+ idx_t_tmp = idx_t
+ new_y = torch.fmod(idx_t_tmp, vocab_size)
+ pre_y = idx_t_tmp / vocab_size
+ x = embed.index_select(1, new_y)
+ h = h_t.index_select(1, pre_y)
+ c = c_t.index_select(1, pre_y)
+ iter = int(iter_count[0])
+ idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
+ torch.fmod(idx_t, vocab_size).unsqueeze(-1),
+ idx.narrow(2, iter, max_len - iter)], 2)
+ idx = idx.narrow(2, 0, max_len)
+ return idx
+
+ beam_batch = torch.jit.batch(batch_size=4)(beam)
+
+ k = 5
+ batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
+ max_len = 5
+ xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
+ hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
+ cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
+ embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
+ iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)]
+ iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
+
+ # input to hidden weights
+ w_xi = torch.rand(input_size, hidden_size)
+ w_xf = torch.rand(input_size, hidden_size)
+ w_xo = torch.rand(input_size, hidden_size)
+ w_xc = torch.rand(input_size, hidden_size)
+ # hidden to hidden weights
+ w_hi = torch.rand(hidden_size, hidden_size)
+ w_hf = torch.rand(hidden_size, hidden_size)
+ w_ho = torch.rand(hidden_size, hidden_size)
+ w_hc = torch.rand(hidden_size, hidden_size)
+ # bias terms
+ b_i = torch.rand(1, hidden_size)
+ b_f = torch.rand(1, hidden_size)
+ b_o = torch.rand(1, hidden_size)
+ b_c = torch.rand(1, hidden_size)
+ # hidden to vocab weights, bias
+ w_hs = torch.rand(hidden_size, vocab_size)
+ b_s = torch.rand(1, vocab_size)
+
+ idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
+ torch.zeros([batch_size, 1, max_len]).byte(),
+ torch.tensor([0, 1]).byte())
+ idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]
+
+ ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
+ b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
+ for j in range(batch_size)]
+ ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
self.assertEqual(ys, ybs.examples())
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
index 3d3527b..f78da9b 100644
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ b/torch/csrc/jit/passes/to_batch.cpp
@@ -3,59 +3,530 @@
namespace torch { namespace jit {
-std::unordered_map<std::string, std::shared_ptr<Graph>> ToBatch::batch_operator_table;
+std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> ToBatch::batch_operator_table;
+
+std::shared_ptr<Graph> ToBatch::getBatchOperator(std::string name, int64_t num_inputs){
+ if(batch_operator_table.find(name) == batch_operator_table.end()){
+ throw std::runtime_error("function " + name + " is not supported in batched tensor yet");
+ }
+ auto ops = batch_operator_table.at(name);
+ if(num_inputs == -1) // default function
+ return ops[0];
+ for(auto op : ops){
+ if(size_t(num_inputs) == op->inputs().size())
+ return op;
+ }
+ throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet");
+}
+
+// replace aten operator node with BatchTensor operator graph
+void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ auto func_name = std::string(n->kind().toUnqualString());
+ std::vector<Value*> new_inputs;
+ for(Value *input : n->inputs()){
+ if(rn_env.find(input) == rn_env.end()){ // batched tensor input: expand to (data, mask, dims)
+ auto new_input = batch_map.at(input);
+ new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
+ }
+ else{ // non-tensor (scalar) input
+ new_inputs.push_back(rn_env.at(input));
+ }
+ }
+
+ // transform scalars to tensors before passing them to the batch operator script
+ for(size_t i = 0; i < new_inputs.size(); i++){
+ auto input = new_inputs[i];
+ if(input->type() == IntType::get() || input->type() == FloatType::get()){
+ auto to_tensor_node = res_graph->createNumToTensor(input);
+ res_graph->insertNode(to_tensor_node);
+ new_inputs[i] = to_tensor_node->output();
+ }
+ }
+
+ auto batch_graph = getBatchOperator(func_name, new_inputs.size());
+ auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
+
+ // Assume every output of the inlined operator implementation is either an expanded batched tensor (a {data, mask, dims} triple) or a single non-tensor.
+ if(outputs.size() == 1){
+ // if previous output is scalar, transform new output back to scalar from dynamic
+ if(n->outputs()[0]->type() != outputs[0]->type()){
+ Node* to_scalar_node;
+ if(n->outputs()[0]->type() == IntType::get()){
+ to_scalar_node = res_graph->createTensorToNum(IntType::get(), outputs[0]);
+ }
+ else if(n->outputs()[0]->type() == FloatType::get()){
+ to_scalar_node = res_graph->createTensorToNum(FloatType::get(), outputs[0]);
+ }
+ else{
+ throw std::runtime_error("NYI: scalar type other than int, float is not supported yet");
+ }
+ res_graph->insertNode(to_scalar_node);
+ rn_env[n->outputs()[0]] = to_scalar_node->output();
+ }
+ else
+ rn_env[n->outputs()[0]] = outputs[0];
+ }
+ else{
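+ // multiple outputs: each original output maps to EXP_BTENSOR_SIZE consecutive
+ // values (data, mask, dims) among the inlined batch operator's outputs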
+ for(size_t i = 0; i < n->outputs().size(); i++){
+ auto output = n->outputs()[i];
+ batch_map[output] = std::vector<Value*>(outputs.begin() + i * EXP_BTENSOR_SIZE, outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
+ }
+ }
+}
+
+// clone prim::Constant to the new graph
+// The batching transformation is applied at the output of prim::NumToTensor,
+// so a prim::Constant followed by prim::NumToTensor is eventually turned into a BatchTensor.
+void ToBatch::visitConstant(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ auto* r_node = res_graph->createClone(n, rn_fn);
+ r_node->setStage(n->stage());
+ res_block->appendNode(r_node);
+ rn_env[n->output()] = r_node->output();
+}
+
+// change the returned tensor to an expanded batched tensor, e.g. {data, mask, dims}
+void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ auto* r_node = res_graph->createClone(n, rn_fn);
+ r_node->setStage(n->stage());
+ res_block->appendNode(r_node);
+ auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
+ batch_map[n->output()] = outputs;
+}
+
+// clone prim::TensorToNum to new graph
+void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ if(rn_env.find(n->input()) == rn_env.end()){
+ rn_env[n->input()] = batch_map.at(n->input())[0];
+ }
+ auto* r_node = res_graph->createClone(n, rn_fn);
+ r_node->setStage(n->stage());
+ res_block->appendNode(r_node);
+ rn_env[n->output()] = r_node->output();
+ batch_map[n->output()] = batch_map.at(n->input());
+}
+
+// clone prim::ListConstruct to new graph
+void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ if(n->inputs()[0]->type() == DynamicType::get()){ // TensorList: expand directly
+ std::vector<Value*> inputs;
+ for(Value* input: n->inputs()) {
+ auto res = batch_map.at(input);
+ inputs.insert(inputs.end(), res.begin(), res.end());
+ }
+ batch_map[n->output()] = inputs;
+ }
+ else { // ScalarList: transform to tensor, then transform back
+ for(Value* input : n->inputs()) {
+ if(rn_env.find(input) == rn_env.end()){
+ rn_env[input] = batch_map.at(input)[0];
+ }
+ }
+ auto* r_node = res_graph->createClone(n, rn_fn);
+ r_node->setStage(n->stage());
+ res_block->appendNode(r_node);
+ // transform int[] to tensor
+ auto to_tensor_node = res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
+ to_tensor_node->setStage(n->stage());
+ to_tensor_node->addInput(r_node->output());
+ res_block->appendNode(to_tensor_node);
+ rn_env[n->output()] = to_tensor_node->output();
+ }
+}
+
+// prim::If transformation:
+// elif is not supported
+//
+// transformation example:
+// @torch.jit.batch(batch_size=4)
+// def batch_if(a, b):
+// if a > b:
+// a += b
+// else:
+// a -= b
+// return a
+//
+// original graph:
+// graph(%a.1 : Dynamic
+// %b : Dynamic) {
+// %2 : Dynamic = aten::gt(%a.1, %b)
+// %a : Dynamic = prim::If(%2)
+// block0() {
+// %a.2 : Dynamic = aten::add[alpha={1}](%a.1, %b)
+// -> (%a.2)
+// }
+// block1() {
+// %a.3 : Dynamic = aten::sub[alpha={1}](%a.1, %b)
+// -> (%a.3)
+// }
+// return (%a);
+// }
+//
+// transformed graph:
+// graph(%a.1_data : Dynamic
+// %a.1_mask : Dynamic
+// %a.1_dims : Dynamic
+// %b_data : Dynamic
+// %b_mask : Dynamic
+// %b_dims : Dynamic) {
+// %6 : Dynamic = aten::gt(%a.1_data, %b_data) // calculate condition
+// %7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+// %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+// %9 : int = prim::TensorToNum(%6)
+// %10 : Long() = prim::Constant[value={1}]() // if_block
+// %alpha.1 : float = prim::TensorToNum(%10)
+// %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1)
+// %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+// %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+// %15 : Long() = prim::Constant[value={1}]() // else_block
+// %alpha : float = prim::TensorToNum(%15)
+// %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
+// %mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
+// %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+// %20 : Dynamic = aten::type_as(%7, %6) // combine two outputs (batch_where)
+// %cond_mask.1 : Dynamic = aten::mul(%6, %20)
+// %22 : int = aten::dim(%cond_mask.1)
+// %23 : int = prim::Constant[value=1]()
+// %24 : int = aten::eq(%22, %23)
+// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%24)
+// block0() {
+// %28 : int = aten::dim(%data.1)
+// %29 : int = prim::Constant[value=1]()
+// %30 : int = aten::sub(%28, %29)
+// %31 : int = prim::Constant[value=1]()
+// %data.3 : Dynamic = prim::Loop(%30, %31, %cond_mask.1)
+// block0(%_ : int, %34 : Dynamic) {
+// %35 : int = prim::Constant[value=1]()
+// %36 : int = aten::neg(%35)
+// %data.2 : Dynamic = aten::unsqueeze(%34, %36)
+// %38 : int = prim::Constant[value=1]()
+// -> (%38, %data.2)
+// }
+// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
+// -> (%cond_data.1, %cond_mask.2, %data.3)
+// }
+// block1() {
+// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+// }
+// %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
+// %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
+// %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
+// return (%res_data, %res_mask, %res_dims);
+// }
+void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
+ toBatch(n->blocks()[0], res_block);
+ toBatch(n->blocks()[1], res_block);
+
+ // combine results from two if paths
+ for(size_t i = 0; i < n->outputs().size(); i++){
+ std::vector<Value*> inputs;
+ if(batch_map.find(n->input()) == batch_map.end()){ // cond is scalar
+ inputs.push_back(rn_env.at(n->input()));
+ }
+ else{ // cond is tensor
+ auto cond = batch_map.at(n->input());
+ inputs.insert(inputs.end(), cond.begin(), cond.end());
+ }
+ auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]);
+ inputs.insert(inputs.end(), if_output.begin(), if_output.end());
+ auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
+ inputs.insert(inputs.end(), else_output.begin(), else_output.end());
+ auto outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
+ batch_map[n->outputs()[i]] = outputs;
+ }
+}
+
+// prim::Loop transformation:
+//
+// transformation example:
+// @torch.jit.batch(batch_size=4)
+// def batch_while(a, b):
+// while a > b:
+// a -= b
+// return a
+//
+// original graph:
+// graph(%a.1 : Dynamic
+// %b : Dynamic) {
+// %2 : int = prim::Constant[value={2147483647}]()
+// %3 : Dynamic = aten::gt(%a.1, %b)
+// %a : Dynamic = prim::Loop(%2, %3, %a.1)
+// block0(%4 : Dynamic, %5 : Dynamic) {
+// %a.2 : Dynamic = aten::sub[alpha={1}](%5, %b)
+// %9 : Dynamic = aten::gt(%a.2, %b)
+// -> (%9, %a.2)
+// }
+// return (%a);
+// }
+//
+// transformed graph:
+// graph(%a.1_data : Dynamic
+// %a.1_mask : Dynamic
+// %a.1_dims : Dynamic
+// %b_data : Dynamic
+// %b_mask : Dynamic
+// %b_dims : Dynamic) {
+// %6 : int = prim::Constant[value=2147483647]()
+// %7 : Dynamic = aten::gt(%a.1_data, %b_data)
+// %8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
+// %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
+// %10 : int = prim::TensorToNum(%7)
+// %11 : Dynamic = aten::mul(%7, %8)
+// %12 : Dynamic = aten::sum(%11)
+// %13 : Dynamic = aten::gt[other={0}](%12) // cond_any
+// %14 : int = prim::TensorToNum(%13)
+// %62 : Dynamic, %63 : Dynamic, %64 : Dynamic, %a : Dynamic, %60 : Dynamic, %61 : Dynamic = prim::Loop(%6, %14, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
+// block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
+// %23 : Long() = prim::Constant[value={1}]()
+// %alpha : float = prim::TensorToNum(%23)
+// %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
+// %mask : Dynamic = aten::mul(%6_mask, %b_mask)
+// %dims : Dynamic = aten::__or__(%6_dims, %b_dims)
+// %28 : Dynamic = aten::gt(%data.1, %b_data)
+// %29 : Dynamic = aten::mul(%mask, %b_mask)
+// %30 : Dynamic = aten::__or__(%dims, %b_dims)
+// %31 : int = prim::TensorToNum(%28)
+// %32 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2) // update outputs (batch_where)
+// %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %32)
+// %34 : int = aten::dim(%cond_mask.1)
+// %35 : int = prim::Constant[value=1]()
+// %36 : int = aten::eq(%34, %35)
+// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%36)
+// block0() {
+// %40 : int = aten::dim(%data.1)
+// %41 : int = prim::Constant[value=1]()
+// %42 : int = aten::sub(%40, %41)
+// %43 : int = prim::Constant[value=1]()
+// %data.3 : Dynamic = prim::Loop(%42, %43, %cond_mask.1)
+// block0(%_ : int, %46 : Dynamic) {
+// %47 : int = prim::Constant[value=1]()
+// %48 : int = aten::neg(%47)
+// %data.2 : Dynamic = aten::unsqueeze(%46, %48)
+// %50 : int = prim::Constant[value=1]()
+// -> (%50, %data.2)
+// }
+// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
+// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
+// -> (%cond_data.1, %cond_mask.2, %data.3)
+// }
+// block1() {
+// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
+// }
+// %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
+// %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
+// %res_dims : Dynamic = aten::__or__(%dims, %6_dims)
+// %56 : Dynamic = aten::mul(%28, %29)
+// %57 : Dynamic = aten::sum(%56)
+// %58 : Dynamic = aten::gt[other={0}](%57)
+// %59 : int = prim::TensorToNum(%58)
+// -> (%59, %28, %29, %30, %res_data, %res_mask, %res_dims)
+// }
+// return (%a, %60, %61);
+// }
+void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
+ auto res_graph = res_block->owningGraph();
+ // cond_is_tensor indicates whether the loop condition is a batched tensor
+ // cond_is_tensor = false, e.g. a for loop, where n->inputs()[1] is a plain scalar rather than a batched tensor
+ // cond_is_tensor = true, e.g. a while loop whose cond is a batched tensor;
+ // in that case we add the expanded cond to the inputs of the loop node and its block,
+ // and compute cond_any to serve as the scalar condition of the while loop
+ bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());
+
+ // create prim::Loop node for res_block
+
+ // type of cond in loop should be int type
+ if(rn_env.at(n->inputs()[0])->type() != IntType::get()){
+ auto to_int_node = res_graph->createTensorToNum(IntType::get(), rn_env.at(n->inputs()[0]));
+ res_graph->insertNode(to_int_node);
+ rn_env[n->inputs()[0]] = to_int_node->output();
+ }
+ if(cond_is_tensor){
+ auto cond = batch_map.at(n->inputs()[1]);
+ auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+ auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
+ res_graph->insertNode(to_int_node);
+ rn_env[n->inputs()[1]] = to_int_node->output();
+ }
+ for(size_t i = 2; i < n->inputs().size(); i++){
+ auto input = n->inputs()[i];
+ rn_env[input] = batch_map.at(input)[0];
+ }
+ auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);
+
+ // change inputs of prim::Loop
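+ // resulting layout: trip_count, cond (cond_any when cond is a tensor),
+ // [cond_data, cond_mask, cond_dims when cond is a tensor],
+ // then (data, mask, dims) for each carried value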
+ if(cond_is_tensor){
+ for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ auto cond = batch_map.at(n->inputs()[1]);
+ r_node->insertInput(i + 2, cond[i]);
+ }
+ }
+ for(size_t i = 2; i < n->inputs().size(); i++){
+ for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
+ r_node->insertInput((i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 + j, batch_map.at(n->inputs()[i])[j]);
+ }
+ }
+ r_node->setStage(n->stage());
+ res_block->appendNode(r_node);
+
+ // create block for the Loop node in res_block
+ // block inputs: loop_num, then (cond_data, cond_mask, cond_dims) when cond is a tensor,
+ // then (data, mask, dims) for each carried value
+ auto loop_block = r_node->addBlock();
+
+ // add inputs
+ loop_block->addInput("loop_num");
+ loop_block->inputs()[0]->setType(IntType::get());
+ rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
+ if(cond_is_tensor){
+ for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
+ }
+ }
+ for(size_t i = 1; i < n->blocks()[0]->inputs().size(); i++){
+ auto input = n->blocks()[0]->inputs()[i];
+ auto name = input->uniqueName();
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
+ }
+ batch_map[input] = std::vector<Value*>(loop_block->inputs().slice((i - 1) * EXP_BTENSOR_SIZE + 1 + EXP_BTENSOR_SIZE * cond_is_tensor, EXP_BTENSOR_SIZE).vec());
+ }
+
+ toBatch(n->blocks()[0], loop_block);
+
+ WithInsertPoint guard(loop_block);
+
+ // use the where (tensor cond) / update (scalar cond) batch operator to update carried variables and register them as block outputs
+ for(size_t i = 0; i < n->outputs().size(); i++){
+ std::vector<Value*> inputs, outputs;
+ if(cond_is_tensor){
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ inputs.push_back(loop_block->inputs()[j + 1]);
+ }
+ auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
+ inputs.insert(inputs.end(), data.begin(), data.end());
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
+ }
+ outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
+ }
+ else{
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
+ }
+ auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
+ inputs.insert(inputs.end(), data.begin(), data.end());
+ outputs = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
+ }
+ batch_map[n->outputs()[i]] = outputs;
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ loop_block->registerOutput(outputs[j]);
+ }
+ }
+
+ // update loop conditions
+ if(cond_is_tensor){
+ auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
+ auto cond_any = script::inlineCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+ auto to_int_node = res_graph->createTensorToNum(IntType::get(), cond_any[0]);
+ res_graph->insertNode(to_int_node);
+ loop_block->insertOutput(0, to_int_node->output());
+ for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ loop_block->insertOutput(i + 1, cond[i]);
+ }
+ }
+ else{
+ auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
+ loop_block->insertOutput(0, cond);
+ }
+
+ // change outputs of prim::Loop
+ auto size = r_node->outputs().size();
+ for(size_t i = 0; i < size; i++){
+ for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
+ r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
+ }
+ batch_map[n->outputs()[i]] = r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
+ }
+ // add cond to outputs of loop node
+ if(cond_is_tensor){
+ for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ r_node->insertOutput(i);
+ }
+ }
+}
void ToBatch::toBatch(Block* block, Block* res_block) {
- // change inputs of a graph - expand tensor to {data, mask, dims}
- auto size = block->inputs().size();
- for(size_t i = 0; i < size; i++){
- auto input = block->inputs()[i];
- auto name = input->uniqueName();
- res_block->addInput(name + "_data");
- res_block->addInput(name + "_mask");
- res_block->addInput(name + "_dims");
- batch_map[input] = res_block->inputs().slice(i * 3, 3).vec();
+ WithInsertPoint guard(res_block);
+
+ // change inputs of block - expand each tensor to a batched tensor, e.g. (data, mask, dims)
+ // e.g. a -> a_data, a_mask, a_dims
+ // for blocks inside prim::Loop, inputs are registered separately (in visitLoop) to deal with cond
+ if(!block->owningNode() || block->owningNode()->kind() != prim::Loop){
+ auto size = block->inputs().size();
+ for(size_t i = 0; i < size; i++){
+ auto input = block->inputs()[i];
+ auto name = input->uniqueName();
+ for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+ res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
+ }
+ batch_map[input] = std::vector<Value*>(res_block->inputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec());
+ }
}
for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
auto n = *it;
- // replace tensor operator to BatchTensor operator
if(n->kind().is_aten()){
- auto batch_graph = batch_operator_table.at(n->kind().toUnqualString());
- WithInsertPoint guard(res_block);
- std::vector<Value*> new_inputs;
- for(Value *input : n->inputs()){
- if(batch_map.find(input) != batch_map.end()){
- auto new_input = batch_map.at(input);
- new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
- }
- else{
- throw std::runtime_error("NYI: non-tensor input for aten operator is not supported yet");
- }
- }
- auto outputs = script::inlineCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
- // Assume all outputs from inlined operator implementation are in the triple form.
- for(size_t i = 0; i < n->outputs().size(); i++){
- auto output = n->outputs()[i];
- batch_map[output] = std::vector<Value*>(outputs.begin() + i * 3, outputs.begin() + i * 3 + 3);
- }
+ visitAten(n, block, res_block);
}
else if(n->kind().is_prim()){
- throw std::runtime_error("NYI: node of prim kind is not supported to transform to batch graph yet");
+ switch(n->kind()){
+ case prim::Constant:
+ visitConstant(n, block, res_block);
+ break;
+ case prim::NumToTensor:
+ visitNumToTensor(n, block, res_block);
+ break;
+ case prim::TensorToNum:
+ visitTensorToNum(n, block, res_block);
+ break;
+ case prim::ListConstruct:
+ visitListConstruct(n, block, res_block);
+ break;
+ case prim::If:
+ visitIf(n, block, res_block);
+ break;
+ case prim::Loop:
+ visitLoop(n, block, res_block);
+ break;
+ default:
+ throw std::runtime_error("NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
+ }
+ }
+ else{
+ throw std::runtime_error("NYI: node that is not aten or prim kind is not supported yet");
}
}
- // change outputs of a graph - expand tensor to {data, mask, dims}
- for(Value* output : block->outputs()){
- auto r_output = batch_map.at(output);
- res_block->registerOutput(r_output[0]);
- res_block->registerOutput(r_output[1]);
- res_block->registerOutput(r_output[2]);
+ // change outputs of block - expand each tensor to a batched tensor (data, mask, dims)
+ // for blocks inside prim::Loop, outputs are registered separately to deal with cond and cond_any
+ // for blocks inside prim::If, outputs are registered separately by combining the outputs of the two paths
+ if(!block->owningNode() || (block->owningNode()->kind() != prim::Loop && block->owningNode()->kind() != prim::If)) {
+ for(Value* output : block->outputs()){
+ auto r_output = batch_map.at(output);
+ for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+ res_block->registerOutput(r_output[i]);
+ }
+ }
}
}
std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph>& graph){
// std::cout<<graph->toString()<<std::endl;
- auto res_graph = std::make_shared<Graph>(graph->scope_root());
+ std::shared_ptr<Graph> res_graph = std::make_shared<Graph>(graph->scope_root());
ToBatch to_batch;
to_batch.toBatch(graph->block(), res_graph->block());
// std::cout<<res_graph->toString()<<std::endl;
@@ -66,7 +537,7 @@
auto m = py::handle(module).cast<py::module>();
m.def("to_batch_graph", &to_batch_graph);
m.def("register_batch_operator", [](std::string name, std::shared_ptr<Graph> graph){
- ToBatch::batch_operator_table[name] = graph;
+ ToBatch::batch_operator_table[name].push_back(graph);
});
}
diff --git a/torch/csrc/jit/passes/to_batch.h b/torch/csrc/jit/passes/to_batch.h
index 23c23a0..6545e2a 100644
--- a/torch/csrc/jit/passes/to_batch.h
+++ b/torch/csrc/jit/passes/to_batch.h
@@ -3,14 +3,33 @@
#include "torch/csrc/jit/pybind.h"
#include "torch/csrc/jit/ir.h"
+#include <ATen/ATen.h>
+
namespace torch { namespace jit {
class ToBatch {
private:
+ // number of tensors used to represent an expanded BatchTensor: {data, mask, dims} for now.
+ const size_t EXP_BTENSOR_SIZE = 3;
+ const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
// mapping from tensor in original graph to {data, mask, dims} in new graph
std::unordered_map<Value*, std::vector<Value*>> batch_map;
+ // mapping from a value in the original graph to its non-batched counterpart in the new graph - used in createClone
+ std::unordered_map<Value*, Value*> rn_env;
+ std::function<Value*(Value*)> rn_fn = [this](Value* v) { return rn_env.at(v); };
+
+private:
+ std::shared_ptr<Graph> getBatchOperator(std::string name, int64_t input_num = -1);
+ void visitAten(Node* n, Block* block, Block* res_block);
+ void visitConstant(Node* n, Block* block, Block* res_block);
+ void visitNumToTensor(Node* n, Block* block, Block* res_block);
+ void visitTensorToNum(Node* n, Block* block, Block* res_block);
+ void visitListConstruct(Node* n, Block* block, Block* res_block);
+ void visitIf(Node* n, Block* block, Block* res_block);
+ void visitLoop(Node* n, Block* block, Block* res_block);
+
public:
- static std::unordered_map<std::string, std::shared_ptr<Graph>> batch_operator_table;
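+ // maps an operator name to its batched implementations; overloads for the same
+ // name are disambiguated by input count in getBatchOperator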
+ static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> batch_operator_table;
TORCH_API void toBatch(Block* block, Block* res_block);
};
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 90d49e2..f2b8ea1 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -347,7 +347,35 @@
return 0;
};
}),
-
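+ // helpers for the batching pass: round-trip an int[] through a 1-D tensor
+ // (e.g. prim::ListConstruct handling in to_batch.cpp emits aten::_list_to_tensor)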
+ Operator(
+ "aten::_tensor_to_list(Tensor a) -> int[]",
+ [](Node* node) {
+ return [=](Stack& stack) {
+ at::Tensor t;
+ pop(stack, t);
+ std::vector<int64_t> elems;
+ for(int i = 0; i < t.size(0); i++){
+ elems.push_back(*t[i].toIntData());
+ }
+ push(stack, jit::IntList::create(elems));
+ return 0;
+ };
+ }),
+ Operator(
+ "aten::_list_to_tensor(int[] a) -> Tensor",
+ [](Node* node) {
+ return [=](Stack& stack) {
+ std::vector<int64_t> l;
+ pop(stack, l);
+ auto t = torch::empty(
+ {static_cast<int64_t>(l.size())}, at::dtype(at::kInt));
+ for(size_t i = 0; i < l.size(); i++){
+ t[i] = l[i];
+ }
+ push(stack, t);
+ return 0;
+ };
+ }),
// commutative
DEFINE_ST_OP(mul, at::mul(b, a))
DEFINE_ST_OP(add, at::add(b, a))
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index c0cf4f9..d09e970 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -403,9 +403,12 @@
else:
new_args.append(arg)
res = res_mod(*new_args)
- # assert len(res) / 3 == 0
- # result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
- result = BatchTensor(*res)
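+ # each batched tensor output comes back expanded as a (data, mask, dims) triple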
+ if len(res) % 3 != 0:
+ raise RuntimeError("non-batched-tensor output is not supported yet")
+ result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
+ if len(result) == 1:
+ return result[0]
return result
wrapper.__doc__ = fn.__doc__
return wrapper
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
index bda6a3a..053130d 100644
--- a/torch/jit/batchop.py
+++ b/torch/jit/batchop.py
@@ -1,6 +1,9 @@
import torch
+from torch.jit import BatchTensor
+# TODO: some raise statements below are commented out;
+# once raising exceptions is supported in script, we want to re-enable them
@torch.jit.script
def batch_tanh(data, mask, dims):
data = torch.tanh(data)
@@ -14,14 +17,53 @@
@torch.jit.script
-def batch_add(data1, mask1, dims1, data2, mask2, dims2):
- data = torch.add(data1, data2)
+def batch_relu(data, mask, dims):
+ data = torch.relu(data)
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_neg(data, mask, dims):
+ data = torch.neg(data)
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_neg_scalar(data):
+ return torch.neg(data)
+
+
+@torch.jit.script
+def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
+ alpha = float(alpha_)
+ data = torch.add(data1, data2, alpha)
mask = mask1 * mask2
dims = dims1 or dims2
return data, mask, dims
@torch.jit.script
+def batch_add_scalar(data, mask, dims, other, alpha_):
+ alpha = float(alpha_)
+ data = torch.add(data, other.type_as(data), alpha)
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
+ alpha = float(alpha_)
+ data = torch.sub(data1, data2, alpha)
+ mask = mask1 * mask2
+ dims = dims1 or dims2
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_sub_scalar(data1, data2):
+ return data1 - data2
+
+
+@torch.jit.script
def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
data = torch.mul(data1, data2)
mask = mask1 * mask2
@@ -30,6 +72,17 @@
@torch.jit.script
+def batch_mul_scalar(data1, data2):
+ return data1 * data2
+
+
+@torch.jit.script
+def batch_div(data, mask, dims, other): # div(batchtensor, scalar)
+ data = torch.div(data, other)
+ return data, mask, dims
+
+
+@torch.jit.script
def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
data1 = data1 * mask1.type_as(data1)
data2 = data2 * mask2.type_as(data2)
@@ -88,26 +141,388 @@
# raise ValueError("Cannot select 0 dim in BatchTensor")
data = data.select(dim, index)
if dims[dim - 1]:
- mask = mask.select(dim, 0)
- else:
mask = mask.select(dim, index)
+ else:
+ mask = mask.select(dim, 0)
dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
return data, mask, dims
+@torch.jit.script
+def batch_fmod(data, mask, dims, other_):
+ other = int(other_)
+ data = torch.fmod(data, other)
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_zeros_like(data, mask, dims):
+ res_data = torch.zeros_like(data)
+ return res_data, mask, dims
+
+
+@torch.jit.script
+def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
+ dim = int(dim_)
+ # if dim == 0:
+ # raise ValueError("Cannot index_select along 0 dim in BatchTensor")
+ batch_size = data.size(0) # TODO maybe index_mask will be used at some point
+ res_data = torch.zeros([0])
+ res_mask = torch.zeros([0])
+ for i in range(batch_size):
+ d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
+ if dims[dim - 1]:
+ m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
+ else:
+ m = mask[i].unsqueeze(0)
+ if i == 0:
+ res_data = d
+ res_mask = m
+ else:
+ res_data = torch.cat((res_data, d), 0)
+ res_mask = torch.cat((res_mask, m), 0)
+ return res_data, res_mask, dims
+
+
+@torch.jit.script
+def batch_view_as(data, mask, dims, data1, mask1, dims1):
+ # if data.size(0) != data1.size(0):
+ # raise ValueError("In view_as, tensor and target tensor should have the same batch_size")
+ # if not torch.equal(dims, dims1):
+ # raise ValueError("In batched view_as, dims and target dims should be the same")
+ data = data.view_as(data1)
+ mask = mask.view_as(mask1)
+ dims = dims1
+ return data, mask, dims
+
+
# assume data, data1, data2 have same size
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
- res_data = torch.where(data, data1, data2)
- res_mask = torch.where(data, mask1, mask2)
+ data = data * mask.type_as(data)
+ cond_data = data
+ cond_mask = data
+ if data.dim() == 1:
+ for _ in range(data1.dim() - 1):
+ data = data.unsqueeze(data.dim())
+ cond_data = data.expand_as(data1)
+ cond_mask = data.expand_as(mask1)
+ res_data = torch.where(cond_data, data1, data2)
+ res_mask = torch.where(cond_mask, mask1, mask2)
res_dims = dims1 or dims2
return res_data, res_mask, res_dims
+
+@torch.jit.script
+def batch_where_scalar(cond_, data1, mask1, dims1, data2, mask2, dims2):
+ cond = torch.ones([1], dtype=torch.uint8) * cond_  # broadcast the scalar condition to a uint8 tensor
+ res_data = torch.where(cond, data1, data2)
+ res_mask = torch.where(cond, mask1, mask2)
+ res_dims = torch.where(cond, dims1, dims2)
+ return res_data, res_mask, res_dims
+
+
+@torch.jit.script
+def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
+ data = torch.where(new_mask, new_data, batch_data)
+ return data, new_mask, new_dims # TODO: consider whether to return new_mask and new_dims here
+
+
+@torch.jit.script
+def batch_any(data, mask, dims):
+ return torch.gt(torch.sum(data * mask), 0)
+
+
+@torch.jit.script
+def batch_type_as(data, mask, dims, data1, mask1, dims1):
+ return data.type_as(data1), mask, dims
+
+
+@torch.jit.script
+def batch_gt(data, mask, dims, data1, mask1, dims1):
+ return torch.gt(data, data1), mask * mask1, dims or dims1
+
+
+@torch.jit.script
+def batch_gt_scalar(data1, data2):
+ return torch.gt(data1, data2)
+
+
+@torch.jit.script
+def batch_gt_one_scalar(data, mask, dims, other_):
+ other = float(other_)
+ return torch.gt(data, other), mask, dims
+
+
+@torch.jit.script
+def batch_lt(data, mask, dims, data1, mask1, dims1):
+ return torch.lt(data, data1), mask * mask1, dims or dims1
+
+
+@torch.jit.script
+def batch_eq(data, mask, dims, data1, mask1, dims1):
+ return torch.eq(data, data1), mask * mask1, dims or dims1
+
+
+@torch.jit.script
+def batch_size(data, mask, dims, dim_):
+ dim = int(dim_)
+ return data.size(dim)
+
+
+@torch.jit.script
+def batch_dim(data, mask, dims):
+ return data.dim()
+
+
+@torch.jit.script
+def batch_squeeze(data, mask, dims, dim_):
+ if int(dim_) < 0:
+ dim_ += data.dim()
+ dim = int(dim_)
+ # if dim == 0:
+ # raise ValueError("cannot do squeeze along batch_dim")
+ data = data.squeeze(dim)
+ mask = mask.squeeze(dim)
+ dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_unsqueeze(data, mask, dims, dim_):
+ if int(dim_) < 0:
+ dim_ += data.dim() + 1
+ dim = int(dim_)
+ # if dim == 0:
+ # raise ValueError("cannot do unsqueeze along batch_dim")
+ data = data.unsqueeze(dim)
+ mask = mask.unsqueeze(dim)
+ dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_argmax(data, mask, dims, dim_, keepdim_):
+ dim = int(dim_)
+ keepdim = int(keepdim_)
+ # if dim == 0:
+ # raise ValueError("cannot do argmax along batch_dim")
+ batch_size = data.size(0)
+ res_data = torch.zeros([0])
+ for i in range(batch_size):
+ if dims[dim - 1]:
+ if dim - 1 != 0:
+ m = mask[i].transpose(0, dim - 1)
+ else:
+ m = mask[i]
+ valid_num = m.sum(0, keepdim=True)
+ while(valid_num.dim() >= 1):
+ valid_num = valid_num[0]
+ d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
+ else:
+ d = data[i].unsqueeze(0)
+ d = d.argmax(dim, keepdim)
+ if i == 0:
+ res_data = d
+ else:
+ res_data = torch.cat([res_data, d], 0)
+ if keepdim:
+ mask = mask
+ else:
+ mask = mask.select(dim, 0)
+ dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
+ return res_data, mask, dims
+
+
+@torch.jit.script
+def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
+ k = int(k_)
+ dim = int(dim_)
+ largest = int(largest_)
+ sorted = int(sorted_)
+ # if dim == 0:
+ # raise ValueError("cannot do topk along batch_dim")
+ batch_size = data.size(0)
+ res_data = torch.zeros([0])
+ res_index = torch.zeros([0])
+ for i in range(batch_size):
+ if dims[dim - 1]:
+ if dim - 1 != 0:
+ m = mask[i].transpose(0, dim - 1)
+ else:
+ m = mask[i]
+ valid_num = m.sum(0, keepdim=True)
+ while(valid_num.dim() >= 1):
+ valid_num = valid_num[0]
+ d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
+ else:
+ d = data[i].unsqueeze(0)
+ d, idx = d.topk(k, dim, largest, sorted)
+ if i == 0:
+ res_data = d
+ res_index = idx
+ else:
+ res_data = torch.cat([res_data, d], 0)
+ res_index = torch.cat([res_index, idx], 0)
+ if dims[dim - 1]:
+ mask = mask.narrow(dim, 0, k)
+ return res_data, mask, dims, res_index, mask, dims
+
+
+@torch.jit.script
+def batch_softmax(data, mask, dims, dim_):
+ dim = int(dim_)
+ # if dim == 0:
+ # raise ValueError("cannot do softmax along batch_dim")
+ batch_size = data.size(0)
+ max_len = data.size(dim)
+ res_data = torch.zeros([0])
+ for i in range(batch_size):
+ if dims[dim - 1]:
+ if dim - 1 != 0:
+ m = mask[i].transpose(0, dim - 1)
+ else:
+ m = mask[i]
+ valid_num = m.sum(0, keepdim=True)
+ while(valid_num.dim() >= 1):
+ valid_num = valid_num[0]
+ valid_num = int(valid_num)
+ d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
+ if valid_num < max_len:
+ d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
+ else:
+ d = data[i].unsqueeze(0).softmax(dim)
+ if i == 0:
+ res_data = d
+ else:
+ res_data = torch.cat([res_data, d], 0)
+ return res_data, mask, dims
+
+
+# the size argument in a dynamic dimension has to be -1
+# in a static dimension the size has to be given explicitly; -1 is not supported
+@torch.jit.script
+def batch_view(data, mask, dims, sizes):
+ batch_size = data.size(0)
+ # if(sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1):
+ # raise "first dim in view must be 1, -1, or batch size"
+ # for i in range(dims.size(0)):
+ # if dims[0] == 1 and sizes[i + 1] != -1:
+ # raise "size argument in dynamic dimension has to be -1"
+ sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
+ data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
+ data_sizes = data_sizes_._tensor_to_list()
+ res_data = data.view(data_sizes)
+ mask_sizes_ = data_sizes_.narrow(0, 0, 1)
+ res_dims = data_sizes_.narrow(0, 0, 1)
+ for i_ in range(sizes.size(0) - 1):
+ i = i_ + 1
+ if(sizes[i] == -1):
+ cur_size_ = mask.size(i)
+ cur_dim = 1
+ else:
+ cur_size_ = 1
+ cur_dim = 0
+ mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
+ res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
+ mask_sizes = mask_sizes_._tensor_to_list()
+ res_mask = mask.view(mask_sizes)
+ return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
+
+
+@torch.jit.script
+def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
+ dim = int(dim_)
+ data = torch.cat([data1, data2], dim)
+ if(dims1[dim - 1]):
+ mask = torch.cat([mask1, mask2], dim)
+ else:
+ mask = mask1
+ return data, mask, dims1
+
+
+@torch.jit.script
+def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
+ dim = int(dim_)
+ data = torch.cat([data1, data2, data3], dim)
+ if(dims1[dim - 1]):
+ mask = torch.cat([mask1, mask2, mask3], dim)
+ else:
+ mask = mask1
+ return data, mask, dims1
+
+
+@torch.jit.script
+def batch_narrow(data, mask, dims, dimension_, start_, length_):
+ dimension = int(dimension_)
+ start = int(start_)
+ length = int(length_)
+ # if dimension == 0:
+ # raise ValueError("cannot do narrow along batch_dim")
+ data = data.narrow(dimension, start, length)
+ if dims[dimension - 1]:
+ mask = mask.narrow(dimension, start, length)
+ else:
+ mask = mask.narrow(dimension, 0, 1)
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_sum(data, mask, dims):
+ data = data * mask.type_as(data)
+ for _ in range(dims.size(0)):
+ data = data.sum(1)
+ mask = torch.ones([data.size(0)], dtype=torch.uint8)
+ dims = dims[:0] # empty tensor
+ return data, mask, dims
+
+
+@torch.jit.script
+def batch_from_scalar_tensor(data):
+ data = data.unsqueeze(0)
+ mask = torch.ones([1], dtype=torch.uint8)
+ dims = torch.zeros([0], dtype=torch.uint8)
+ return data, mask, dims
+
torch.register_batch_operator("tanh", batch_tanh.graph)
torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
+torch.register_batch_operator("relu", batch_relu.graph)
+torch.register_batch_operator("neg", batch_neg.graph)
+torch.register_batch_operator("neg", batch_neg_scalar.graph)
torch.register_batch_operator("add", batch_add.graph)
+torch.register_batch_operator("add", batch_add_scalar.graph)
+torch.register_batch_operator("sub", batch_sub.graph)
+torch.register_batch_operator("sub", batch_sub_scalar.graph)
torch.register_batch_operator("mul", batch_mul.graph)
+torch.register_batch_operator("mul", batch_mul_scalar.graph)
+torch.register_batch_operator("div", batch_div.graph)
torch.register_batch_operator("matmul", batch_matmul.graph)
torch.register_batch_operator("mm", batch_mm.graph)
+torch.register_batch_operator("fmod", batch_fmod.graph)
+torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
torch.register_batch_operator("select", batch_select.graph)
+torch.register_batch_operator("index_select", batch_index_select.graph)
+torch.register_batch_operator("view_as", batch_view_as.graph)
torch.register_batch_operator("where", batch_where.graph)
+torch.register_batch_operator("where", batch_where_scalar.graph)
+torch.register_batch_operator("update", batch_update.graph)
+torch.register_batch_operator("any", batch_any.graph)
+torch.register_batch_operator("type_as", batch_type_as.graph)
+torch.register_batch_operator("gt", batch_gt.graph)
+torch.register_batch_operator("gt", batch_gt_scalar.graph)
+torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
+torch.register_batch_operator("lt", batch_lt.graph)
+torch.register_batch_operator("eq", batch_eq.graph)
+torch.register_batch_operator("size", batch_size.graph)
+torch.register_batch_operator("dim", batch_dim.graph)
+torch.register_batch_operator("squeeze", batch_squeeze.graph)
+torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
+torch.register_batch_operator("argmax", batch_argmax.graph)
+torch.register_batch_operator("topk", batch_topk.graph)
+torch.register_batch_operator("softmax", batch_softmax.graph)
+torch.register_batch_operator("view", batch_view.graph)
+torch.register_batch_operator("cat", batch_cat2.graph)
+torch.register_batch_operator("cat", batch_cat3.graph)
+torch.register_batch_operator("narrow", batch_narrow.graph)
+torch.register_batch_operator("sum", batch_sum.graph)
+torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)