| from typing import List |
| |
| from torch._six import container_abcs |
| from itertools import repeat |
| |
| |
| def _ntuple(n): |
| def parse(x): |
| if isinstance(x, container_abcs.Iterable): |
| return x |
| return tuple(repeat(x, n)) |
| return parse |
| |
# Ready-made normalizers for one to four values: each expands a non-iterable
# argument into a tuple of that length and returns iterables unchanged.
_single, _pair, _triple, _quadruple = (_ntuple(arity) for arity in (1, 2, 3, 4))
| |
| |
| def _reverse_repeat_tuple(t, n): |
| r"""Reverse the order of `t` and repeat each element for `n` times. |
| |
| This can be used to translate padding arg used by Conv and Pooling modules |
| to the ones used by `F.pad`. |
| """ |
| return tuple(x for x in reversed(t) for _ in range(n)) |
| |
| |
| def _list_with_default(out_size, defaults): |
| # type: (List[int], List[int]) -> List[int] |
| if isinstance(out_size, int): |
| return out_size |
| if len(defaults) <= len(out_size): |
| raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1)) |
| return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])] |