Port PackedSequences functions to C++ (#11224)

Summary:
zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11224

Differential Revision: D9652703

Pulled By: apaszke

fbshipit-source-id: 558e39457e590cad07516e5bb2ecb12789564950
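
For context, a minimal sketch (illustrative only, not part of the patch) of the new native entry points, called here through torch._C._VariableFunctions exactly as torch/nn/utils/rnn.py does below:

    import torch

    x = torch.zeros(5, 3, 4)                # (seq_len, batch, features), padded
    lengths = torch.tensor([5, 3, 1])       # 1D CPU int64, sorted in decreasing order
    data, batch_sizes = torch._C._VariableFunctions._pack_padded_sequence(x, lengths, False)
    print(data.shape)                       # torch.Size([9, 4])  -- sum(lengths) packed steps
    print(batch_sizes)                      # tensor([3, 2, 2, 1, 1])
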
diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp
new file mode 100644
index 0000000..9e66095
--- /dev/null
+++ b/aten/src/ATen/native/PackedSequence.cpp
@@ -0,0 +1,166 @@
+#include "ATen/ATen.h"
+#include "ATen/NativeFunctions.h"
+
+namespace at { namespace native {
+
+void checkLongTensor(const Tensor& tensor) {
+  auto & t = tensor.type();
+  AT_CHECK(tensor.dim() == 1 && t.device_type() == at::kCPU && t.scalarType() == at::kLong,
+           "'lengths' argument should be a 1D CPU int64 tensor");
+}
+
+std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Tensor& _lengths, bool batch_first) {
+  auto input = batch_first ? _input.transpose(0, 1) : _input;
+  auto lengths_t = _lengths.contiguous();
+  checkLongTensor(lengths_t);
+
+  int64_t batch_size = input.size(1);
+  int64_t * lengths = lengths_t.data<int64_t>();
+  AT_CHECK(lengths_t.size(0) == batch_size,
+           "Expected `len(lengths)` to be equal to batch_size, but got ", lengths_t.size(0),
+           " (batch_size=", batch_size, ")");
+  AT_CHECK(lengths[batch_size - 1] > 0,
+           "Length of all samples has to be greater than 0, but found an element "
+           "in 'lengths' that is <= 0");
+
+  std::vector<at::Tensor> steps;
+  steps.reserve(batch_size);
+  at::Tensor batch_sizes_t = at::empty(lengths[0], _lengths.options());
+  int64_t * batch_sizes = batch_sizes_t.data<int64_t>();
+
+  std::vector<int64_t> step_shape; // == [-1, *input.shape[2:]]
+  {
+    auto input_sizes = input.sizes();
+    step_shape.reserve(input_sizes.size());
+    auto s_input_sizes = input_sizes.slice(2);
+    step_shape.push_back(-1);
+    step_shape.insert(step_shape.end(), s_input_sizes.begin(), s_input_sizes.end());
+  }
+
+  // To understand what's going on in this loop imagine that the input is a padded 2D
+  // array that looks like this (x = valid entry, . = padding)
+  //
+  //  1 1 1 1 1
+  //  2 2 2 . .
+  //  2 2 2 . .
+  //  4 . . . .
+  //  4 . . . .
+  //
+  // where the vertical dimension corresponds to time and the horizontal one to batch.
+  // In this example, the lengths array will be equal to [5, 3, 3, 1, 1], and we will
+  // iterate over it in reverse order (from the rightmost column to the left).
+  // We want to avoid eager slicing of the input at every time step, and only slice
+  // at the moments where the length increases. In this example, that happens at the
+  // first, second and fourth time steps. At each such moment we slice out the whole
+  // block of the input that corresponds to this length and hasn't been sliced yet
+  // (each element in the array above is annotated with the time step at which its
+  // block starts). You can think of this as scanning the sequences from the shortest
+  // one: every time we realize there are more elements below in our column, we move
+  // the counter (prev_l) further down the column and append the new block to the output.
+  int64_t prev_l = 0;
+  for (int64_t i = 0; i < batch_size; ++i) {
+    int64_t l = lengths[batch_size - 1 - i];
+    if (l > prev_l) {
+      auto current_batch_size = batch_size - i;
+      steps.push_back(input.slice(0, prev_l, l).slice(1, 0, current_batch_size).contiguous().view(step_shape));
+      for (int64_t j = 0; j < (l - prev_l); ++j) {
+        (*batch_sizes++) = current_batch_size;
+      }
+      prev_l = l;
+    } else if (prev_l > l) {
+      AT_ERROR("'lengths' array has to be sorted in decreasing order");
+    }
+  }
+
+  return std::make_tuple(at::cat(steps), batch_sizes_t);
+}
+
+Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntList input_size, const Tensor& _batch_sizes, bool batch_first) {
+  std::vector<int64_t> input_size_after_t = input_size.vec();
+  if (batch_first) {
+    AT_CHECK(input_size.size() >= 2, "'input_size' must have at least two elements when batch_first is true");
+    std::swap(input_size_after_t[0], input_size_after_t[1]);
+  }
+  auto grad_input = at::zeros(input_size_after_t, grad.options());
+  auto batch_sizes_t = _batch_sizes.contiguous();
+  checkLongTensor(batch_sizes_t);
+
+  int64_t offset = 0;
+  int64_t max_seq_len = batch_sizes_t.size(0);
+  int64_t * batch_sizes = batch_sizes_t.data<int64_t>();
+  for (int64_t i = 0; i < max_seq_len; ++i) {
+    grad_input[i].slice(0, 0, batch_sizes[i]).copy_(grad.slice(0, offset, offset + batch_sizes[i]));
+    offset += batch_sizes[i];
+  }
+
+  if (batch_first) {
+    grad_input = grad_input.transpose(0, 1);
+  }
+
+  return grad_input;
+}
+
+std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, Scalar padding_value, int64_t total_length) {
+  auto batch_sizes_t = _batch_sizes.contiguous();
+  checkLongTensor(batch_sizes_t);
+
+  int64_t * batch_sizes = batch_sizes_t.data<int64_t>();
+  int64_t max_batch_size = batch_sizes[0];
+  int64_t max_real_seq_length = batch_sizes_t.size(0);
+  int64_t max_seq_length = max_real_seq_length;
+  if (total_length > 0) {
+    AT_CHECK(total_length >= max_seq_length,
+             "Expected total_length to be at least the length of the longest "
+             "sequence in input, but got total_length=", total_length, " and "
+             "max sequence length being ", max_seq_length);
+    max_seq_length = total_length;
+  }
+
+  std::vector<int64_t> output_size; // == [max_seq_length, max_batch_size, *data.size()[1:]]
+  {
+    output_size.reserve(data.dim() + 1);
+    output_size.push_back(max_seq_length);
+    output_size.push_back(max_batch_size);
+    auto s_data_size = data.sizes().slice(1);
+    output_size.insert(output_size.end(), s_data_size.begin(), s_data_size.end());
+  }
+  auto output = at::full(output_size, padding_value, data.options());
+
+  // This will be modified at every iteration, but we reserve memory for it now.
+  std::vector<int64_t> tmp_view_size = std::move(output_size); // == [-1, -1, *data.size()[1:]]
+
+  at::Tensor lengths_t = at::empty(max_batch_size, batch_sizes_t.options());
+  int64_t * lengths = lengths_t.data<int64_t>() + max_batch_size - 1;
+  int64_t data_offset = 0;
+  int64_t prev_batch_size = max_batch_size;
+  int64_t prev_i = 0;
+  for (int64_t i = 0; i <= max_real_seq_length; ++i) {
+    int64_t batch_size = i != max_real_seq_length ? batch_sizes[i] : 0;
+    if (batch_size != prev_batch_size) {
+      int64_t l = prev_batch_size * (i - prev_i);
+      // The lines below are equivalent to this:
+      // output[prev_i:i, :prev_batch_size] = tmp.view(i - prev_i, prev_batch_size, *input.shape[2:])
+      auto tmp = data.slice(0, data_offset, data_offset + l);
+      tmp_view_size[0] = i - prev_i;
+      tmp_view_size[1] = prev_batch_size;
+      output.slice(0, prev_i, i).slice(1, 0, prev_batch_size).copy_(tmp.view(tmp_view_size));
+      data_offset += l;
+      prev_i = i;
+    }
+    int64_t dec = prev_batch_size - batch_size;
+    if (dec > 0) {
+      for (int64_t j = 0; j < dec; ++j) {
+        (*lengths--) = i;
+      }
+    }
+    prev_batch_size = batch_size;
+  }
+
+  if (batch_first) {
+    output = output.transpose(0, 1);
+  }
+
+  return std::make_tuple(output, lengths_t);
+}
+
+}} // namespace at::native
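
For readers of the comment in _pack_padded_sequence above, a minimal Python sketch (illustration only, standard library) of how the loop turns a sorted lengths array into slice boundaries and batch_sizes:

    def pack_blocks(lengths):
        # lengths must be sorted in decreasing order, e.g. [5, 3, 3, 1, 1]
        batch_size = len(lengths)
        blocks, batch_sizes = [], []
        prev_l = 0
        for i, l in enumerate(reversed(lengths)):
            if l > prev_l:
                cur = batch_size - i
                # time rows [prev_l, l) of the first `cur` columns form one contiguous block
                blocks.append((prev_l, l, cur))
                batch_sizes.extend([cur] * (l - prev_l))
                prev_l = l
            elif prev_l > l:
                raise ValueError("'lengths' array has to be sorted in decreasing order")
        return blocks, batch_sizes

    print(pack_blocks([5, 3, 3, 1, 1]))
    # ([(0, 1, 5), (1, 3, 3), (3, 5, 1)], [5, 3, 3, 1, 1])
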
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 01dfb2e..4b5e758 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2211,3 +2211,13 @@
 
 - func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih={}, Tensor? b_hh={}) -> Tensor
   variants: function
+
+# PackedSequence utilities
+- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
+  variants: function
+
+- func: _pack_padded_sequence_backward(Tensor grad, IntList input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+  variants: function
+
+- func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int64_t total_length) -> (Tensor, Tensor)
+  variants: function
diff --git a/test/test_jit.py b/test/test_jit.py
index 7d4cbef..d2416a7 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4095,26 +4095,6 @@
         f = io.BytesIO()
         torch.onnx._export(m, (x, seq_lens), f, verbose=False)
 
-    def test_pack_padded_wrong_types(self):
-        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
-
-        class PackPaddedWrapper(torch.nn.Module):
-            def __init__(self):
-                super(PackPaddedWrapper, self).__init__()
-                self.seq_lens = [3, 3, 3, 3]
-
-            __constants__ = ['seq_lens']
-
-            def forward(self, x):
-                return pack_padded_sequence(x, self.seq_lens)
-
-        m = PackPaddedWrapper()
-
-        x = torch.rand(3, 4, 5)
-        f = io.BytesIO()
-        with self.assertRaisesRegex(RuntimeError, 'PackPadded requires `lengths` to be a Tensor'):
-            torch.onnx._export(m, (x,), f)
-
     def test_script_outputs(self):
         with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
             @torch.jit.script
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6127eb7..67d9634 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1259,3 +1259,7 @@
 
 - name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor input_bias, Tensor hidden_bias)
   input_gates, hidden_gates, hx, input_bias, hidden_bias: _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())
+
+# PackedSequence helpers
+- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first)
+  input: _pack_padded_sequence_backward(grad, input.sizes(), result1, batch_first)
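
A small sanity sketch (illustrative) of what the derivative entry above enables: gradients flowing into the packed data are scattered back into the padded input, with zeros left in the padding:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    x = torch.randn(4, 2, 3, requires_grad=True)   # (seq_len, batch, features)
    packed = pack_padded_sequence(x, [4, 2])
    packed.data.sum().backward()
    print(x.grad[:, 1, 0])                         # tensor([1., 1., 0., 0.])
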
diff --git a/torch/nn/_functions/packing.py b/torch/nn/_functions/packing.py
deleted file mode 100644
index a45cb63..0000000
--- a/torch/nn/_functions/packing.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-from torch.autograd import Function
-
-
-class PackPadded(Function):
-    @staticmethod
-    def forward(ctx, input, lengths, batch_first):
-        if batch_first:
-            input = input.transpose(0, 1)
-
-        if lengths[-1] <= 0:
-            raise ValueError("Length of all samples has to be greater than 0, "
-                             "but found an element in 'lengths' that is <= 0")
-
-        steps = []
-        batch_sizes = []
-
-        # lengths is a Tensor, so we must convert to [int] before reversed()
-        lengths_iter = reversed(lengths.tolist())
-
-        batch_size = input.size(1)
-
-        if len(lengths) != batch_size:
-            raise ValueError("Expected `len(lengths)` to be equal to batch_size, but got "
-                             "{} (batch_size={}).".format(len(lengths), batch_size))
-
-        prev_l = 0
-        for i, l in enumerate(lengths_iter):
-            if l > prev_l:
-                c_batch_size = batch_size - i
-                steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *input.size()[2:]))
-                batch_sizes.extend([c_batch_size] * (l - prev_l))
-                prev_l = l
-
-            elif prev_l > l:
-                raise ValueError("'lengths' array has to be sorted in decreasing order")
-
-        ctx.batch_sizes = batch_sizes
-        ctx.batch_first = batch_first
-        ctx.input_size = input.size()
-
-        return torch.cat(steps), torch.LongTensor(batch_sizes)
-
-    @staticmethod
-    def backward(ctx, grad_steps, grad_batch_sizes):
-        grad_input = grad_steps.new(*ctx.input_size).zero_()
-
-        offset = 0
-        for i, bs in enumerate(ctx.batch_sizes):
-            grad_input[i, :bs] = grad_steps[offset:offset + bs]
-            offset += bs
-
-        if ctx.batch_first:
-            grad_input = grad_input.transpose(0, 1)
-
-        return grad_input, None, None
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index bf8a601..b61fdab 100644
--- a/torch/nn/utils/rnn.py
+++ b/torch/nn/utils/rnn.py
@@ -1,11 +1,9 @@
 from collections import namedtuple
+import warnings
 
 import torch
-import torch.onnx
 
 
-from .._functions.packing import PackPadded
-
 PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes'])
 
 
@@ -140,45 +138,14 @@
     Returns:
         a :class:`PackedSequence` object
     """
-    if isinstance(lengths, list):
-        lengths = torch.LongTensor(lengths)
-
-    data, batch_sizes = PackPadded.apply(input, lengths, batch_first)
-
-    return PackedSequence(data, batch_sizes)
-
-
-def _symbolic_pack_padded_sequence(g, input, lengths, batch_first=False, padding_value=0.0):
-    # There currently is no PackPadded operator in ONNX. We rely on an
-    # optimization pass to remove this later. It is an error if all
-    # PackPadded operators cannot be optimized out.
-
-    if not isinstance(input, torch._C.Value):
-        raise RuntimeError("PackPadded requires `input` to be a Tensor")
-    if not isinstance(lengths, torch._C.Value):
-        raise RuntimeError("PackPadded requires `lengths` to be a Tensor")
-
-    def _onnx_symbolic_pack_padded_sequence(g, input, lengths):
-        if batch_first:
-            input = g.op('Transpose', input, perm_i=[1, 0, 2])
-        if not lengths.type().isSubtypeOf(torch._C.DynamicType.get()):
-            raise RuntimeError("Lengths must be a Tensor for ONNX export")
-        # We know it's a TensorType so this check is now safe.
-        if lengths.type().scalarType() != 'Int':
-            raise RuntimeError("ONNX export requires that the lengths passed "
-                               "to pack_padded_sequence must be of type Int")
-        return g.op("prim::PackPadded", input, lengths, outputs=2)
-
-    def pack_padded_sequence_trace_wrapper(input, lengths):
-        return pack_padded_sequence(input, lengths, batch_first=batch_first)
-
-    outputs = g.wrapPyFuncWithSymbolic(
-        pack_padded_sequence_trace_wrapper, [input, lengths], 2,
-        _onnx_symbolic_pack_padded_sequence)
-    return tuple(o for o in outputs)
-
-
-pack_padded_sequence = torch.onnx.symbolic_override(_symbolic_pack_padded_sequence)(pack_padded_sequence)
+    if torch._C._get_tracing_state() and not isinstance(lengths, torch.Tensor):
+        warnings.warn('pack_padded_sequence has been called with a Python list of '
+                      'sequence lengths. The tracer cannot track the data flow of Python '
+                      'values, and it will treat them as constants, likely rendering '
+                      'the trace incorrect for any other combination of lengths.',
+                      category=torch.jit.TracerWarning, stacklevel=2)
+    lengths = torch.as_tensor(lengths, dtype=torch.int64)
+    return PackedSequence(torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first))
 
 
 def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
@@ -214,9 +181,7 @@
         containing the list of lengths of each sequence in the batch.
 
     """
-    var_data, batch_sizes = sequence
-    max_batch_size = int(batch_sizes[0])
-    max_seq_length = batch_sizes.size(0)
+    max_seq_length = sequence.batch_sizes.size(0)
     if total_length is not None:
         if total_length < max_seq_length:
             raise ValueError("Expected total_length to be at least the length "
@@ -224,56 +189,8 @@
                              "total_length={} and max sequence length being {}"
                              .format(total_length, max_seq_length))
         max_seq_length = total_length
-    output = var_data.data.new(max_seq_length, max_batch_size, *var_data.size()[1:]).fill_(padding_value)
-
-    lengths = []
-    data_offset = 0
-    prev_batch_size = int(batch_sizes[0])
-    prev_i = 0
-    for i, batch_size in enumerate(batch_sizes.tolist() + [0]):
-        if batch_size != prev_batch_size:
-            l = prev_batch_size * (i - prev_i)
-            tmp = var_data[data_offset:data_offset + l]
-            output[prev_i:i, :prev_batch_size] = tmp.view(i - prev_i, prev_batch_size, *tmp.size()[1:])
-            data_offset += l
-            prev_i = i
-        dec = prev_batch_size - batch_size
-        if dec > 0:
-            lengths.extend((i,) * dec)
-        prev_batch_size = batch_size
-
-    lengths.reverse()
-
-    if batch_first:
-        output = output.transpose(0, 1)
-    # This Tensor doesn't actually have any history (well,
-    # technically it does; it's just untracked), it is purely here to
-    # make ONNX export easier. That is to say, from an autodiff
-    # standpoint this doesn't make any sense.
-    return output, torch.LongTensor(lengths)
-
-
-def _symbolic_pad_packed_sequence(g, input, batch_first=False, padding_value=0.0, total_length=None):
-    # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
-    # It is only useful/used when training using data_parallel model, so
-    # It shouldn't be relevant for ONNX anyway
-    def _onnx_symbolic_pad_packed_sequence(g, data, batch_sizes):
-        data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
-        if batch_first:
-            data = g.op('Transpose', data, perm_i=[1, 0, 2])
-        return data, lengths
-
-    def pad_packed_sequence_trace_wrapper(data, batch_sizes):
-        return pad_packed_sequence(PackedSequence(data, batch_sizes),
-                                   batch_first=batch_first, padding_value=padding_value)
-
-    data, lengths = g.wrapPyFuncWithSymbolic(
-        pad_packed_sequence_trace_wrapper, [input.data, input.batch_sizes], 2,
-        _onnx_symbolic_pad_packed_sequence)
-    return data, lengths
-
-
-pad_packed_sequence = torch.onnx.symbolic_override(_symbolic_pad_packed_sequence)(pad_packed_sequence)
+    return torch._C._VariableFunctions._pad_packed_sequence(
+        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
 
 
 def pad_sequence(sequences, batch_first=False, padding_value=0):
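
For reference, a round trip through the rewritten wrappers above (illustrative sketch; the expected values assume the example shapes shown):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    x = torch.randn(5, 3, 4)                       # (seq_len, batch, features)
    packed = pack_padded_sequence(x, [5, 3, 1])    # lengths sorted in decreasing order
    print(packed.batch_sizes)                      # tensor([3, 2, 2, 1, 1])

    unpacked, out_lengths = pad_packed_sequence(packed)
    print(unpacked.shape)                          # torch.Size([5, 3, 4])
    print(out_lengths)                             # tensor([5, 3, 1])
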
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index e214ba2..615530e 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1163,3 +1163,31 @@
 def detach(g, input):
     # Erase aten::detach nodes because ONNX is inference only
     return input
+
+
+@parse_args('v', 'v', 'i')
+def _pack_padded_sequence(g, input, lengths, batch_first):
+    # There currently is no PackPadded operator in ONNX. We rely on an
+    # optimization pass to remove this later. It is an error if all
+    # PackPadded operators cannot be optimized out.
+    if batch_first:
+        input = g.op('Transpose', input, perm_i=[1, 0, 2])
+    if not lengths.type().isSubtypeOf(torch._C.DynamicType.get()):
+        raise RuntimeError("Lengths must be a Tensor for ONNX export")
+    # We know it's a TensorType so this check is now safe.
+    # It's really only necessary because those operators expand to something that
+    # only works with int32 types in Caffe2...
+    if lengths.type().scalarType() != 'Int':
+        lengths = _cast_Int(g, lengths, False)
+    return g.op("prim::PackPadded", input, lengths, outputs=2)
+
+
+@parse_args('v', 'v', 'i', 't', 'i')
+def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
+    # Ignore total_length, as it is not supported here. It is only useful when
+    # training with data_parallel models, so it shouldn't be relevant for ONNX
+    # export anyway.
+    data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
+    if batch_first:
+        data = g.op('Transpose', data, perm_i=[1, 0, 2])
+    return data, lengths