# blob: 249280ab5584940d4023362994e53cabc2b8bae1 [file] [log] [blame]
from .module import Module
from typing import Union
from torch import Tensor
from torch import Size
class Flatten(Module):
    r"""
    Collapses a contiguous span of dimensions of the input into a single dimension.
    Intended for use inside :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(N, *dims)`
        - Output: :math:`(N, \prod *dims)` (for the default case).

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::

        >>> input = torch.randn(32, 1, 5, 5)
        >>> m = nn.Sequential(
        >>>     nn.Conv2d(1, 32, 5, 1, 1),
        >>>     nn.Flatten()
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 288])
    """
    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        # Both bounds are stored as given and handed to Tensor.flatten unchanged.
        self.end_dim = end_dim
        self.start_dim = start_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim..end_dim`` merged into one."""
        first, last = self.start_dim, self.end_dim
        return input.flatten(first, last)

    def extra_repr(self) -> str:
        """Return the configured dim range for the module's printed representation."""
        bounds = (self.start_dim, self.end_dim)
        return 'start_dim={}, end_dim={}'.format(*bounds)
class Unflatten(Module):
    r"""
    Unflattens one dimension of a tensor, expanding it to a desired shape.
    For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `torch.Size` if :attr:`dim` is an `int`, or a `namedshape`
      (`tuple` of `(name, size)` tuples) if :attr:`dim` is a `str`.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[tuple, torch.Size]): New shape of the unflattened dimension

    Raises:
        TypeError: at construction, if :attr:`unflattened_size` is not a tuple of ints
            (for an `int` :attr:`dim`) or not a tuple of tuples (otherwise).
        IndexError: in :meth:`forward`, if an `int` :attr:`dim` is out of range for the input.

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, (2, 5, 5))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        >>>     nn.Linear(50, 50),
        >>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
        >>> )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """
    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[tuple, Size]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[tuple, Size]) -> None:
        super().__init__()

        # An int dim selects the plain-tensor path; anything else is treated as a
        # dimension name and requires a namedshape.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
            self.named = False
        else:
            self._require_tuple_tuple(unflattened_size)
            self.named = True

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples."""
        if (isinstance(input, tuple)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError("unflattened_size must be tuple of tuples, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
            return
        raise TypeError("unflattened_size must be a tuple of tuples, " +
                        "but found type {}".format(type(input).__name__))

    def _require_tuple_int(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple (or ``torch.Size``) of ints."""
        if (isinstance(input, tuple)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError("unflattened_size must be tuple of ints, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
            return
        raise TypeError("unflattened_size must be a tuple of ints, but found type {}".format(type(input).__name__))

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dimension ``dim`` replaced by ``unflattened_size``.

        Raises:
            IndexError: if ``dim`` is an int outside the valid range for ``input``.
        """
        if self.named:
            return input.unflatten(self.dim, self.unflattened_size)

        dim = int(self.dim)
        ndim = input.dim()
        # A 0-dim tensor still accepts dim 0 / -1, matching the behavior of the
        # slicing below on an empty size list.
        span = max(ndim, 1)
        # Without this check an out-of-range dim would be absorbed by list slicing
        # and either append dimensions silently or yield a misleading view error.
        if dim < -span or dim >= span:
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -span, span - 1, dim))
        if dim < 0:
            dim += ndim

        sizes = list(input.size())
        # Splice the requested shape in place of the single dimension `dim`.
        new_size = sizes[:dim] + list(self.unflattened_size) + sizes[dim + 1:]
        return input.view(new_size)

    def extra_repr(self) -> str:
        """Return the configured ``dim`` and ``unflattened_size`` for the module repr."""
        return 'dim={}, unflattened_size={}'.format(
            self.dim, self.unflattened_size
        )