| import torch |
| from . import TensorPrinting |
| from functools import reduce |
| from itertools import chain |
| import sys |
| import math |
| |
| |
def _infer_sizes(sizes, total):
    """Resolve a single ``-1`` entry in ``sizes`` so the sizes describe ``total`` elements.

    ``sizes`` is modified in place and also returned. When it holds no ``-1``
    it is returned unchanged; the caller must still check that the product of
    the returned sizes really equals ``total``.

    Args:
        sizes: mutable sequence of dimension sizes (``view`` passes a storage
            object). At most one entry may be ``-1``.
        total: number of elements the sizes must describe.

    Returns:
        ``sizes``, with the ``-1`` entry (if any) replaced by the inferred size.

    Raises:
        RuntimeError: if more than one entry is ``-1``, or if the ``-1`` cannot
            be resolved to an integer. That happens when the known sizes do not
            divide ``total`` or when one of them is 0.
    """
    to_infer = -1
    known_product = 1
    for i, size in enumerate(sizes):
        if size == -1:
            if to_infer >= 0:
                raise RuntimeError("only one dimension can be inferred (found more than one -1)")
            to_infer = i
        else:
            known_product *= size
    if to_infer >= 0:
        # A zero among the known sizes makes the -1 ambiguous and would divide by zero.
        # Raise explicitly: an assert would disappear under `python -O`.
        if known_product == 0 or total % known_product != 0:
            raise RuntimeError("Can't make sizes have exactly %d elements" % total)
        # Floor division keeps the size an int; '/' yields a float on Python 3.
        sizes[to_infer] = total // known_product
    return sizes
| |
| |
| class _TensorBase(object): |
| def new(self, *args, **kwargs): |
| return self.__class__(*args, **kwargs) |
| |
| def type(self, t=None): |
| if isinstance(t, str) or t is None: |
| current = self.__module__ + '.' + self.__class__.__name__ |
| if t is None: |
| return current |
| if t == current: |
| return self |
| _, _, typename = t.partition('.') |
| return torch._import_dotted_name(t)(self.size()).copy_(self) |
| else: |
| if t == type(self): |
| return self |
| return t(self.size()).copy_(self) |
| |
| def typeAs(self, t): |
| return self.type(t.type()) |
| |
| def double(self): |
| return self.type(torch.DoubleTensor) |
| |
| def float(self): |
| return self.type(torch.FloatTensor) |
| |
| def long(self): |
| return self.type(torch.LongTensor) |
| |
| def int(self): |
| return self.type(torch.IntTensor) |
| |
| def short(self): |
| return self.type(torch.ShortTensor) |
| |
| def char(self): |
| return self.type(torch.CharTensor) |
| |
| def byte(self): |
| return self.type(torch.ByteTensor) |
| |
| def copy_(self, other): |
| torch._C._tensorCopy(self, other) |
| return self |
| |
| def __deepcopy__(self, _memo): |
| memo = _memo.setdefault('torch', {}) |
| if self._cdata in memo: |
| return memo[self._cdata] |
| new_storage = self.storage().__deepcopy__(_memo) |
| new_tensor = self.new() |
| new_tensor.set_(new_storage, self.storageOffset(), self.size(), self.stride()) |
| memo[self._cdata] = new_tensor |
| return new_tensor |
| |
| def __reduce__(self): |
| return type(self), (self.tolist(),) |
| |
| def __repr__(self): |
| return str(self) |
| |
| def __str__(self): |
| return TensorPrinting.printTensor(self) |
| |
| def __iter__(self): |
| return iter(map(lambda i: self.select(0, i), torch._pyrange(self.size(0)))) |
| |
| def split(self, split_size, dim=0): |
| result = [] |
| dim_size = self.size(dim) |
| num_splits = int(math.ceil(float(dim_size) / split_size)) |
| last_split_size = split_size - (split_size * num_splits - dim_size) |
| def get_split_size(i): |
| return split_size if i < num_splits-1 else last_split_size |
| return [self.narrow(int(dim), int(i*split_size), int(get_split_size(i))) for i |
| in torch._pyrange(0, num_splits)] |
| |
| def chunk(self, n_chunks, dim=0): |
| split_size = math.ceil(float(self.size(dim)) / n_chunks) |
| return self.split(split_size, dim) |
| |
| def tolist(self): |
| dim = self.dim() |
| if dim == 1: |
| return [v for v in self] |
| elif dim > 0: |
| return [subt.tolist() for subt in self] |
| return [] |
| |
| def view(self, *args): |
| dst = self.new() |
| if len(args) == 1 and torch.isStorage(args[0]): |
| sizes = args[0] |
| else: |
| sizes = torch.LongStorage(args) |
| sizes = _infer_sizes(sizes, self.nElement()) |
| |
| if reduce(lambda a,b: a * b, sizes) != self.nElement(): |
| raise RuntimeError('Invalid size for view. Input size: ' + |
| 'x'.join(map(lambda v: str(v), self.size())) + |
| ', output size: ' + |
| 'x'.join(map(lambda v: str(v), sizes)) + '.') |
| |
| assert self.isContiguous(), "expecting a contiguous tensor" |
| dst.set_(self.storage(), self.storageOffset(), sizes) |
| return dst |
| |
| def viewAs(self, tensor): |
| return self.view(tensor.size()) |
| |
| def permute(self, *args): |
| perm = list(args) |
| tensor = self |
| n_dims = tensor.dim() |
| assert len(perm) == n_dims, 'Invalid permutation' |
| for i, p in enumerate(perm): |
| if p != i and p != -1: |
| j = i |
| while True: |
| assert 0 <= perm[j] and perm[j] < n_dims, 'Invalid permutation' |
| tensor = tensor.transpose(j, perm[j]) |
| perm[j], j = -1, perm[j] |
| if perm[j] == i: |
| break |
| perm[j] = -1 |
| return tensor |
| |
| def expandAs(self, tensor): |
| return self.expand(tensor.size()) |
| |
| def expand(self, *args): |
| result = self.new() |
| sizes = args[0] if len(args) == 1 and torch.isLongStorage(args[0]) else torch.LongStorage(args) |
| src = self |
| |
| src_dim = src.dim() |
| src_stride = src.stride() |
| src_size = src.size() |
| |
| if sizes.size() != src_dim: |
| raise ValueError('the number of dimensions provided must equal tensor.dim()') |
| |
| # create a new geometry for tensor: |
| for i, size in enumerate(src_size): |
| if size == 1: |
| src_size[i] = sizes[i] |
| src_stride[i] = 0 |
| elif size != sizes[i]: |
| raise ValueError('incorrect size: only supporting singleton expansion (size=1)') |
| |
| result.set_(src.storage(), src.storageOffset(), |
| src_size, src_stride) |
| return result |
| |
    def repeatTensor(self, *args):
        """Return a new tensor built by tiling copies of this tensor.

        The source size is left-padded with 1s to ``len(repeats)`` dimensions.
        The result's size along each dimension is that padded size times the
        matching repeat count.

        Args:
            *args: repeat counts, either as separate integers or as a single
                ``torch.LongStorage``.

        Returns:
            A new tensor (``self`` is not modified) holding the tiled data.

        Raises:
            ValueError: if fewer repeat counts than ``self.dim()`` are given.
        """
        # If args == (torch.LongStorage,), then we need to unpack the tuple
        if len(args) == 1 and isinstance(args[0], torch.LongStorage):
            args = args[0]
        repeats = list(args)
        result = self.new()
        src = self.contiguous()

        if len(repeats) < src.dim():
            raise ValueError('Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor')

        # xtensor is set_ onto src (so it views src's data) and is reshaped below
        # without copying.
        xtensor = src.new().set_(src)
        xsize = xtensor.size().tolist()
        # Left-pad the source size with 1s so it has one entry per repeat count.
        for i in torch._pyrange(len(repeats)-src.dim()):
            xsize = [1] + xsize

        # Target size: padded source size multiplied element-wise by the repeat counts.
        size = torch.LongStorage([a * b for a, b in zip(xsize, repeats)])
        xtensor.resize_(torch.LongStorage(xsize))
        result.resize_(size)
        # NOTE(review): the final copy_ must write into `result`, so urtensor is
        # expected to share result's storage. This relies on the class constructor
        # called via new(result) aliasing its tensor argument; confirm.
        urtensor = result.new(result)
        # Assuming the usual unfold(dim, size, step) semantics, step == size carves
        # dimension i into non-overlapping blocks of length xtensor.size(i). Each
        # call appends one trailing dimension of that length, giving a grid of tiles.
        for i in torch._pyrange(xtensor.dim()):
            urtensor = urtensor.unfold(i,xtensor.size(i),xtensor.size(i))
        # Prepend 1s so xtensor lines up with urtensor's trailing (tile) dimensions.
        # expandAs can then broadcast it across the leading (grid) dimensions.
        for i in torch._pyrange(urtensor.dim()-xtensor.dim()):
            xsize = [1] + xsize
        xtensor.resize_(torch.LongStorage(xsize))
        xxtensor = xtensor.expandAs(urtensor)
        # Write the broadcast source into every tile (i.e. into result's storage).
        urtensor.copy_(xxtensor)
        return result
| |
| #TODO: add tests for operators |
| def __add__(self, other): |
| return self.add(other) |
| __radd__ = __add__ |
| |
| def __iadd__(self, other): |
| return self.add_(other) |
| |
| def __sub__(self, other): |
| return self.sub(other) |
| |
| def __rsub__(self, other): |
| return self.new().resizeAs_(self).fill_(other).add_(-1, self) |
| |
| def __isub__(self, other): |
| return self.sub_(other) |
| |
| def __mul__(self, other): |
| return self.mul(other) |
| __rmul__ = __mul__ |
| |
| def __imul__(self, other): |
| return self.mul_(other) |
| |
| def __matmul__(self, other): |
| dim_self = self.dim() |
| dim_other = other.dim() |
| # TODO: should this really be dot product? |
| # if dim_self == 1 and dim_other == 1: |
| # return self.dot(other) |
| if dim_self == 2 and dim_other == 1: |
| return torch.mv(self, other) |
| elif dim_self == 2 and dim_other == 2: |
| return torch.mm(self, other) |
| |
| def __div__(self, other): |
| return self.div(other) |
| __truediv__ = __div__ |
| |
| def __rdiv__(self, other): |
| return self.new().resizeAs_(self).fill_(other).div_(self) |
| __rtruediv__ = __rdiv__ |
| |
| def __idiv__(self, other): |
| return self.div_(other) |
| |
| def __mod__(self, other): |
| return self.remainder(other) |
| |
| def __neg__(self): |
| return self.neg() |