use new overload mechanism for rnns (#29614)
Summary:
Uses the new overload mechanism for RNNs, so that Python and TorchScript go through the same code path, with an API in line with the one specified
in https://docs.python.org/3/library/typing.html#typing.overload
This brings the TorchScriptable RNNs closer to the base implementation; fully unifying them should be done in a follow-up PR, but a few remaining limitations make that difficult for now.
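
For reference, a minimal sketch of the pattern (hypothetical MyModule, not taken from this diff): each typed signature is declared as a stub decorated with torch._jit_internal._overload_method, and the final undecorated definition is the single real implementation, mirroring how typing.overload stubs precede the actual definition. The bodies below are trivial pass-throughs just to show the shape of the pattern.

    import torch
    from torch import Tensor
    from torch.nn.utils.rnn import PackedSequence

    class MyModule(torch.nn.Module):
        @torch._jit_internal._overload_method  # noqa: F811
        def forward(self, input):  # noqa: F811
            # type: (Tensor) -> Tensor
            pass

        @torch._jit_internal._overload_method  # noqa: F811
        def forward(self, input):  # noqa: F811
            # type: (PackedSequence) -> PackedSequence
            pass

        def forward(self, input):  # noqa: F811
            # Single implementation; the isinstance check must sit inside the
            # conditional so TorchScript can refine the type in each branch.
            if isinstance(input, PackedSequence):
                return input
            else:
                return input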
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29614
Differential Revision: D18486982
Pulled By: eellison
fbshipit-source-id: aaaea66a4a7f12d2e46199ca254f9e8f7475500e
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index d2f00de..278d769 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -7,7 +7,7 @@
import inspect
import weakref
import warnings
-import torch._C
+import torch
from torch._six import builtins
from torch._utils_internal import get_source_lines_and_file
@@ -676,7 +676,7 @@
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj):
# short-circuit in cases where the object already has a known qualified name
- if isinstance(obj, torch.jit.ScriptFunction):
+ if isinstance(obj, torch._C.ScriptFunction):
return obj.qualified_name
name = obj.__name__
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index e326704..49d5321 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -692,6 +692,8 @@
.def_static("get", &BoolType::get);
py::class_<StringType, Type, std::shared_ptr<StringType>>(m, "StringType")
.def_static("get", &StringType::get);
+ py::class_<NoneType, Type, std::shared_ptr<NoneType>>(m, "NoneType")
+ .def_static("get", &NoneType::get);
py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
.def(
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 120dfbb..32fbd30 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -7,7 +7,7 @@
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional, _qualified_name, Any
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType
+ ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType
from textwrap import dedent
from torch._six import builtins
@@ -238,6 +238,8 @@
return BoolType.get()
elif ann is Any:
return AnyType.get()
+ elif ann is type(None):
+ return NoneType.get()
elif hasattr(ann, "__torch_script_class__"):
return ClassType(_qualified_name(ann))
elif hasattr(ann, "__torch_script_interface__"):
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 662fcc3..a9a0d0d 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -500,8 +500,6 @@
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
"""
- __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
-
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)
@@ -521,8 +519,29 @@
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
- def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
- # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
+ @torch._jit_internal._overload_method # noqa: F811
+ def forward(self, input, hx=None): # noqa: F811
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ pass
+
+ @torch._jit_internal._overload_method # noqa: F811
+ def forward(self, input, hx=None): # noqa: F811
+ # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
+ pass
+
+ def forward(self, input, hx=None): # noqa: F811
+ orig_input = input
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
+ if isinstance(orig_input, PackedSequence):
+ input, batch_sizes, sorted_indices, unsorted_indices = input
+ max_batch_size = batch_sizes[0]
+ max_batch_size = int(max_batch_size)
+ else:
+ batch_sizes = None
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
+ sorted_indices = None
+ unsorted_indices = None
+
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
@@ -543,39 +562,12 @@
self.num_layers, self.dropout, self.training, self.bidirectional)
output = result[0]
hidden = result[1:]
-
- return output, hidden
-
- @torch._jit_internal.export
- def forward_tensor(self, input, hx=None):
- # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
- batch_sizes = None
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
-
- output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
-
- return output, self.permute_hidden(hidden, unsorted_indices)
-
- @torch._jit_internal.export
- def forward_packed(self, input, hx=None):
- # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = batch_sizes[0]
- max_batch_size = int(max_batch_size)
-
- output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
-
- output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
- return output, self.permute_hidden(hidden, unsorted_indices)
-
- @torch._jit_internal.ignore
- def forward(self, input, hx=None):
- if isinstance(input, PackedSequence):
- return self.forward_packed(input, hx)
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
+ if isinstance(orig_input, PackedSequence):
+ output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+ return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
- return self.forward_tensor(input, hx)
+ return output, self.permute_hidden(hidden, unsorted_indices)
class GRU(RNNBase):
@@ -682,23 +674,32 @@
>>> h0 = torch.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)
"""
- __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
-
def __init__(self, *args, **kwargs):
super(GRU, self).__init__('GRU', *args, **kwargs)
- def run_impl(self, input, hx, batch_sizes):
- # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
- if batch_sizes is None:
- result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
- self.dropout, self.training, self.bidirectional, self.batch_first)
- else:
- result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
- self.num_layers, self.dropout, self.training, self.bidirectional)
- return result
+ @torch._jit_internal._overload_method # noqa: F811
+ def forward(self, input, hx=None): # noqa: F811
+ # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
+ pass
- def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
- # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa
+ @torch._jit_internal._overload_method # noqa: F811
+ def forward(self, input, hx=None): # noqa: F811
+ # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
+ pass
+
+ def forward(self, input, hx=None): # noqa: F811
+ orig_input = input
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
+ if isinstance(orig_input, PackedSequence):
+ input, batch_sizes, sorted_indices, unsorted_indices = input
+ max_batch_size = batch_sizes[0]
+ max_batch_size = int(max_batch_size)
+ else:
+ batch_sizes = None
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
+ sorted_indices = None
+ unsorted_indices = None
+
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.zeros(self.num_layers * num_directions,
@@ -710,37 +711,21 @@
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
- result = self.run_impl(input, hx, batch_sizes)
+ if batch_sizes is None:
+ result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
+ self.dropout, self.training, self.bidirectional, self.batch_first)
+ else:
+ result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
+ self.num_layers, self.dropout, self.training, self.bidirectional)
output = result[0]
hidden = result[1]
- return output, hidden
- @torch._jit_internal.export
- def forward_packed(self, input, hx=None):
- # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = batch_sizes[0]
- max_batch_size = int(max_batch_size)
- output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
- output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
- return output, self.permute_hidden(hidden, unsorted_indices)
-
- @torch._jit_internal.export
- def forward_tensor(self, input, hx=None):
- # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
- batch_sizes = None
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
- output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
- return output, self.permute_hidden(hidden, unsorted_indices)
-
- @torch._jit_internal.ignore
- def forward(self, input, hx=None):
- if isinstance(input, PackedSequence):
- return self.forward_packed(input, hx)
+ # xxx: isinstance check needs to be in conditional for TorchScript to compile
+ if isinstance(orig_input, PackedSequence):
+ output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+ return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
- return self.forward_tensor(input, hx)
+ return output, self.permute_hidden(hidden, unsorted_indices)
class RNNCellBase(Module):
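
For context (not part of this patch): with this change a single scripted forward accepts both input types. A minimal usage sketch, assuming a PyTorch build that includes this change:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence

    lstm = torch.jit.script(nn.LSTM(10, 20, 2))

    # Plain Tensor input resolves to the (Tensor, ...) overload.
    seq = torch.randn(5, 3, 10)
    output, (hn, cn) = lstm(seq)

    # PackedSequence input resolves to the PackedSequence overload.
    packed = pack_padded_sequence(seq, torch.tensor([5, 4, 2]))
    packed_output, (hn, cn) = lstm(packed)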