Added a flatten module (#22245)
Summary:
https://github.com/pytorch/pytorch/issues/2118
I'm not sure I've implemented this correctly, so I'll expand the tests once we agree the approach is roughly right.
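For reference, a minimal usage sketch (shapes are illustrative, not part of this diff) of how the new module composes with `nn.Sequential`:

```python
import torch
import torch.nn as nn

# Hypothetical example: flatten a conv feature map from dim 1 onward
# so it can feed a subsequent Linear layer.
model = nn.Sequential(
    nn.Conv2d(1, 32, 5, 1, 1),  # (N, 1, 28, 28) -> (N, 32, 26, 26)
    nn.Flatten(),               # (N, 32, 26, 26) -> (N, 21632)
)
out = model(torch.randn(8, 1, 28, 28))
print(out.shape)  # torch.Size([8, 21632])
```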
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22245
Differential Revision: D16508957
Pulled By: Chillee
fbshipit-source-id: a8dc7af999ba698c921006889f71cb1bc5a59d50
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 9ae1cae..3a6a4b1 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -869,3 +869,9 @@
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.utils.rnn.pack_sequence
+
+:hidden:`Flatten`
+~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Flatten
+ :members:
diff --git a/test/common_nn.py b/test/common_nn.py
index c8fe5e1..4ad680e 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -105,6 +105,11 @@
input_size=(2, 3, 4, 5)
),
dict(
+ module_name='Flatten',
+ input_size=(2, 3, 4, 5),
+ reference_fn=lambda i, *_: torch.flatten(i, 1)
+ ),
+ dict(
module_name='Softmax',
constructor_args=(1,),
input_size=(10, 20),
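Outside the test harness, the reference function above amounts to asserting that the module matches `torch.flatten(input, 1)`; a standalone sketch of that check (same input shape as the test entry) would be:

```python
import torch
import torch.nn as nn

# nn.Flatten() with defaults should reproduce torch.flatten(input, 1):
# (2, 3, 4, 5) -> (2, 60), keeping the batch dim intact.
inp = torch.randn(2, 3, 4, 5)
assert torch.equal(nn.Flatten()(inp), torch.flatten(inp, 1))
assert nn.Flatten()(inp).shape == (2, 60)
```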
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 174cd91..3929e5b 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -28,7 +28,8 @@
from .fold import Fold, Unfold
from .adaptive import AdaptiveLogSoftmaxWithLoss
from .transformer import TransformerEncoder, TransformerDecoder, \
- TransformerEncoderLayer, TransformerDecoderLayer, Transformer
+ TransformerEncoderLayer, TransformerDecoderLayer, Transformer
+from .flatten import Flatten

__all__ = [
'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
@@ -50,6 +51,7 @@
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
- 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
+ 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
+ 'Flatten'
]
diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py
new file mode 100644
index 0000000..d15cd75
--- /dev/null
+++ b/torch/nn/modules/flatten.py
@@ -0,0 +1,30 @@
+from .module import Module
+
+class Flatten(Module):
+    r"""
+    Flattens a contiguous range of dims into a single dimension. For use with :class:`~nn.Sequential`.
+
+    Args:
+        start_dim: first dim to flatten (default = 1).
+        end_dim: last dim to flatten (default = -1).
+
+    Shape:
+        - Input: :math:`(N, *dims)`
+        - Output: :math:`(N, \prod *dims)` (for the default case).
+
+    Examples::
+
+        >>> m = nn.Sequential(
+        >>>     nn.Conv2d(1, 32, 5, 1, 1),
+        >>>     nn.Flatten()
+        >>> )
+    """
+ __constants__ = ['start_dim', 'end_dim']
+
+ def __init__(self, start_dim=1, end_dim=-1):
+ super(Flatten, self).__init__()
+ self.start_dim = start_dim
+ self.end_dim = end_dim
+
+ def forward(self, input):
+ return input.flatten(self.start_dim, self.end_dim)
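To illustrate the `start_dim`/`end_dim` arguments beyond the default case (input shape chosen arbitrarily for the demo):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 5)

# Default: flatten everything after the batch dim.
print(nn.Flatten()(x).shape)                        # torch.Size([2, 60])

# Flatten only dims 1..2, leaving the last dim alone.
print(nn.Flatten(start_dim=1, end_dim=2)(x).shape)  # torch.Size([2, 12, 5])

# Flatten all dims, including the batch dim.
print(nn.Flatten(start_dim=0)(x).shape)             # torch.Size([120])
```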