| from torch._six import container_abcs |
| from itertools import repeat |
| |
| |
| def _ntuple(n): |
| def parse(x): |
| if isinstance(x, container_abcs.Iterable): |
| return x |
| return tuple(repeat(x, n)) |
| return parse |
| |
# Ready-made converters for arities 1 through 4, each produced by _ntuple.
_single, _pair, _triple, _quadruple = map(_ntuple, (1, 2, 3, 4))
| |
| |
| def _list_with_default(out_size, defaults): |
| if isinstance(out_size, int): |
| return out_size |
| if len(defaults) <= len(out_size): |
| raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1)) |
| return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])] |