blob: 22704607ddb045ef9cc567263b061928b2575e07 [file] [log] [blame]
import torch
from ._utils import _range
def split(tensor, split_size, dim=0):
    """Split a tensor into consecutive chunks of ``split_size`` along ``dim``.

    Every chunk except possibly the last has exactly ``split_size`` elements
    along ``dim``; the last chunk holds whatever remains. Chunks are produced
    with ``tensor.narrow``.

    Arguments:
        tensor: tensor to split.
        split_size (int): size of each chunk along ``dim``; must be positive.
        dim (int): dimension to split along; negative values count from the
            last dimension.

    Returns:
        tuple of tensors, in order along ``dim``; empty when ``tensor`` has
        size 0 along ``dim``.

    Raises:
        ValueError: if ``split_size`` is not positive.
    """
    if split_size <= 0:
        # A zero size would divide by zero below, and a negative one would make
        # num_splits <= 0 and silently return an empty tuple.
        raise ValueError(
            "split_size must be a positive integer, got {}".format(split_size))
    if dim < 0:
        dim += tensor.dim()
    dim_size = tensor.size(dim)
    # Ceiling division: number of chunks needed to cover dim_size.
    num_splits = (dim_size + split_size - 1) // split_size
    # Whatever remains after the first num_splits - 1 full-size chunks.
    last_split_size = dim_size - split_size * (num_splits - 1)

    def get_split_size(i):
        return split_size if i < num_splits - 1 else last_split_size

    return tuple(
        tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i)))
        for i in _range(0, num_splits))
def chunk(tensor, n_chunks, dim=0):
    """Split a tensor into at most ``n_chunks`` chunks along ``dim``.

    The chunk size is ``ceil(tensor.size(dim) / n_chunks)``. Every chunk but
    possibly the last has that size, so fewer than ``n_chunks`` chunks may be
    returned. For example, a size of 5 split into 4 chunks gives chunk sizes
    2, 2 and 1.

    Arguments:
        tensor: tensor to split.
        n_chunks (int): requested number of chunks; must be positive.
        dim (int): dimension to split along; negative values count from the
            last dimension.

    Returns:
        tuple of tensors as produced by ``split``; empty when ``tensor`` has
        size 0 along ``dim``.

    Raises:
        ValueError: if ``n_chunks`` is not positive.
    """
    if n_chunks <= 0:
        # Zero would divide by zero; a negative count would yield a negative
        # chunk size and silently produce no chunks.
        raise ValueError(
            "n_chunks must be a positive integer, got {}".format(n_chunks))
    if dim < 0:
        dim += tensor.dim()
    # Ceiling division, clamped to 1 so a size-0 dimension does not pass a
    # zero chunk size to split() (which would divide by zero there).
    split_size = max(1, (tensor.size(dim) + n_chunks - 1) // n_chunks)
    return split(tensor, split_size, dim)