| from ._utils import _range |
def split(tensor, split_size, dim=0):
    """Split ``tensor`` into pieces of ``split_size`` along dimension ``dim``.

    Every piece is ``split_size`` long along ``dim`` except possibly the last
    one, which holds whatever remains when the dimension size is not
    divisible by ``split_size``. Pieces are obtained with ``tensor.narrow``.

    Args:
        tensor: object exposing ``size(dim)``, ``dim()`` and
            ``narrow(dim, start, length)`` (a tensor).
        split_size: positive length of each piece along ``dim``.
        dim: dimension to split along; negative values count from the end.

    Returns:
        A tuple of the narrowed pieces, in order. It is empty when the
        dimension has size 0.

    Raises:
        ValueError: if ``split_size`` is not positive.
    """
    if split_size <= 0:
        # Previously 0 raised ZeroDivisionError and negatives silently
        # produced an empty tuple; reject both explicitly.
        raise ValueError(
            "split_size must be a positive integer, got {}".format(split_size))
    if dim < 0:
        # Normalize so narrow() always receives a non-negative dimension index.
        dim += tensor.dim()
    dim = int(dim)
    dim_size = int(tensor.size(dim))

    pieces = []
    start = 0
    while start < dim_size:
        # All pieces are split_size long except the final one, which takes
        # only what is left over.
        length = min(split_size, dim_size - start)
        pieces.append(tensor.narrow(dim, int(start), int(length)))
        start += split_size
    return tuple(pieces)
def chunk(tensor, n_chunks, dim=0):
    """Split ``tensor`` into at most ``n_chunks`` pieces along ``dim``.

    The piece length is ``ceil(tensor.size(dim) / n_chunks)``, and the actual
    splitting is delegated to ``split``. As a result, the last piece may be
    shorter, and fewer than ``n_chunks`` pieces may be returned. For example,
    a size-5 dimension split into 4 chunks uses piece length 2 and yields
    3 pieces.

    Args:
        tensor: tensor-like object accepted by ``split``.
        n_chunks: positive upper bound on the number of pieces.
        dim: dimension to split along.

    Returns:
        The tuple produced by ``split``. It is empty when the dimension has
        size 0.

    Raises:
        ValueError: if ``n_chunks`` is not positive.
    """
    if n_chunks <= 0:
        raise ValueError(
            "n_chunks must be a positive integer, got {}".format(n_chunks))
    # Ceiling division. Clamp to 1 so that an empty dimension does not produce
    # a zero split size, which split cannot divide by; with size 1, split
    # returns no pieces for an empty dimension.
    split_size = max((tensor.size(dim) + n_chunks - 1) // n_chunks, 1)
    return split(tensor, split_size, dim)