| import torch |
| import warnings |
| from . import _tensor_str |
| from ._utils import _type, _cuda, _range, _rebuild_tensor |
| import sys |
| |
| |
| class _TensorBase(object): |
| #: bool: True if this is a CUDA tensor |
| is_cuda = False |
| is_sparse = False |
| |
| # NB: This implementation is CPU only; see THPTensor_(new) for the |
| # CUDA case, which handles constructing the tensor on the same GPU |
| # as this tensor. |
| def new(self, *args, **kwargs): |
| """Constructs a new tensor of the same data type.""" |
| return self.__class__(*args, **kwargs) |
| |
| def type_as(self, tensor): |
| """Returns this tensor cast to the type of the given tensor. |
| |
| This is a no-op if the tensor is already of the correct type. This is |
| equivalent to:: |
| |
| self.type(tensor.type()) |
| |
| Params: |
| tensor (Tensor): the tensor which has the desired type |
| """ |
| return self.type(tensor.type()) |
| |
| def cpu(self): |
| """Returns a CPU copy of this tensor if it's not already on the CPU""" |
| return self.type(getattr(torch, self.__class__.__name__)) |
| |
| def double(self): |
| """Casts this tensor to double type""" |
| return self.type(type(self).__module__ + '.DoubleTensor') |
| |
| def float(self): |
| """Casts this tensor to float type""" |
| return self.type(type(self).__module__ + '.FloatTensor') |
| |
| def half(self): |
| """Casts this tensor to half-precision float type""" |
| return self.type(type(self).__module__ + '.HalfTensor') |
| |
| def long(self): |
| """Casts this tensor to long type""" |
| return self.type(type(self).__module__ + '.LongTensor') |
| |
| def int(self): |
| """Casts this tensor to int type""" |
| return self.type(type(self).__module__ + '.IntTensor') |
| |
| def short(self): |
| """Casts this tensor to short type""" |
| return self.type(type(self).__module__ + '.ShortTensor') |
| |
| def char(self): |
| """Casts this tensor to char type""" |
| return self.type(type(self).__module__ + '.CharTensor') |
| |
| def byte(self): |
| """Casts this tensor to byte type""" |
| return self.type(type(self).__module__ + '.ByteTensor') |
| |
| def is_pinned(self): |
| """Returns true if this tensor resides in pinned memory""" |
| storage = self.storage() |
| return storage.is_pinned() if storage else False |
| |
| def pin_memory(self): |
| """Copies the tensor to pinned memory, if it's not already pinned.""" |
| if self.is_cuda: |
| raise TypeError("cannot pin '{0}' only CPU memory can be pinned" |
| .format(self.type())) |
| storage = self.storage() |
| if storage is None: |
| storage = (self.storage_type())() |
| return type(self)().set_(storage.pin_memory()).view_as(self) |
| |
| def share_memory_(self): |
| """Moves the underlying storage to shared memory. |
| |
| This is a no-op if the underlying storage is already in shared memory |
| and for CUDA tensors. Tensors in shared memory cannot be resized. |
| """ |
| self.storage().share_memory_() |
| return self |
| |
| def is_shared(self): |
| """Checks if tensor is in shared memory. |
| |
| This is always ``True`` for CUDA tensors. |
| """ |
| return self.storage().is_shared() |
| |
| def __deepcopy__(self, _memo): |
| memo = _memo.setdefault('torch', {}) |
| if self._cdata in memo: |
| return memo[self._cdata] |
| new_storage = self.storage().__deepcopy__(_memo) |
| new_tensor = self.new() |
| new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) |
| memo[self._cdata] = new_tensor |
| return new_tensor |
| |
| def __reduce__(self): |
| # NOTE: _rebuild_tensor does not call __setstate__ |
| args = self.__getstate__() |
| return (_rebuild_tensor, args) |
| |
| def __getstate__(self): |
| return (self.storage(), |
| self.storage_offset(), |
| tuple(self.size()), |
| self.stride()) |
| |
| def __setstate__(self, state): |
| self.set_(*state) |
| |
| def __repr__(self): |
| return str(self) |
| |
| def __str__(self): |
| # All strings are unicode in Python 3, while we have to encode unicode |
| # strings in Python2. If we can't, let python decide the best |
| # characters to replace unicode characters with. |
| if sys.version_info > (3,): |
| return _tensor_str._str(self) |
| else: |
| if hasattr(sys.stdout, 'encoding'): |
| return _tensor_str._str(self).encode( |
| sys.stdout.encoding or 'UTF-8', 'replace') |
| else: |
| return _tensor_str._str(self).encode('UTF-8', 'replace') |
| |
| def __bool__(self): |
| if self.numel() == 0: |
| return False |
| raise RuntimeError("bool value of non-empty " + torch.typename(self) + |
| " objects is ambiguous") |
| |
| __nonzero__ = __bool__ |
| |
| def __iter__(self): |
| return iter(map(lambda i: self.select(0, i), _range(self.size(0)))) |
| |
| def split(self, split_size, dim=0): |
| """Splits this tensor into a tuple of tensors. |
| |
| See :func:`torch.split`. |
| """ |
| return torch.split(self, split_size, dim) |
| |
| def chunk(self, n_chunks, dim=0): |
| """Splits this tensor into a tuple of tensors. |
| |
| See :func:`torch.chunk`. |
| """ |
| return torch.chunk(self, n_chunks, dim) |
| |
| def matmul(self, other): |
| """Matrix product of two tensors. |
| |
| See :func:`torch.matmul`.""" |
| return torch.matmul(self, other) |
| |
| def tolist(self): |
| """Returns a nested list represenation of this tensor.""" |
| dim = self.dim() |
| if dim == 1: |
| return [v for v in self] |
| elif dim > 0: |
| return [subt.tolist() for subt in self] |
| return [] |
| |
| def view_as(self, tensor): |
| """Returns this tensor viewed as the size as the specified tensor. |
| |
| This is equivalent to:: |
| |
| self.view(tensor.size()) |
| """ |
| return self.view(tensor.size()) |
| |
| def permute(self, *dims): |
| """Permute the dimensions of this tensor. |
| |
| Args: |
| *dims (int...): The desired ordering of dimensions |
| |
| Example: |
| >>> x = torch.randn(2, 3, 5) |
| >>> x.size() |
| torch.Size([2, 3, 5]) |
| >>> x.permute(2, 0, 1).size() |
| torch.Size([5, 2, 3]) |
| """ |
| perm = list(dims) |
| tensor = self |
| n_dims = tensor.dim() |
| assert len(perm) == n_dims, 'Invalid permutation' |
| for i, p in enumerate(perm): |
| if p != i and p != -1: |
| j = i |
| while True: |
| assert 0 <= perm[j] and perm[j] < n_dims, 'Invalid permutation' |
| tensor = tensor.transpose(j, perm[j]) |
| perm[j], j = -1, perm[j] |
| if perm[j] == i: |
| break |
| perm[j] = -1 |
| return tensor |
| |
| def expand_as(self, tensor): |
| """Expands this tensor to the size of the specified tensor. |
| |
| This is equivalent to:: |
| |
| self.expand(tensor.size()) |
| """ |
| return self.expand(tensor.size()) |
| |
| def repeat(self, *sizes): |
| """Repeats this tensor along the specified dimensions. |
| |
| Unlike :meth:`expand`, this function copies the tensor's data. |
| |
| Args: |
| *sizes (torch.Size or int...): The number of times to repeat this tensor along each dimension |
| |
| Example: |
| >>> x = torch.Tensor([1, 2, 3]) |
| >>> x.repeat(4, 2) |
| 1 2 3 1 2 3 |
| 1 2 3 1 2 3 |
| 1 2 3 1 2 3 |
| 1 2 3 1 2 3 |
| [torch.FloatTensor of size 4x6] |
| >>> x.repeat(4, 2, 1).size() |
| torch.Size([4, 2, 3]) |
| """ |
| # If args == (torch.Size,), then we need to unpack the tuple |
| if len(sizes) == 1 and isinstance(sizes[0], torch.Size): |
| sizes = sizes[0] |
| repeats = list(sizes) |
| result = self.new() |
| src = self.contiguous() |
| |
| if len(repeats) < src.dim(): |
| raise ValueError('Number of dimensions of repeat dims can not be ' |
| 'smaller than number of dimensions of tensor') |
| |
| xtensor = src.new().set_(src) |
| xsize = list(xtensor.size()) |
| for i in _range(len(repeats) - src.dim()): |
| xsize = [1] + xsize |
| |
| size = torch.Size([a * b for a, b in zip(xsize, repeats)]) |
| xtensor.resize_(torch.Size(xsize)) |
| result.resize_(size) |
| urtensor = result.new(result) |
| for i in _range(xtensor.dim()): |
| urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i)) |
| for i in _range(urtensor.dim() - xtensor.dim()): |
| xsize = [1] + xsize |
| xtensor.resize_(torch.Size(xsize)) |
| xxtensor = xtensor.expand_as(urtensor) |
| urtensor.copy_(xxtensor) |
| return result |
| |
| def masked_copy_(self, *args, **kwargs): |
| warnings.warn("masked_copy_ is deprecated and renamed to masked_scatter_, and will be removed in v0.3") |
| return self.masked_scatter_(*args, **kwargs) |
| |
| # TODO: add tests for operators |
| def __add__(self, other): |
| return self.add(other) |
| __radd__ = __add__ |
| |
| def __iadd__(self, other): |
| return self.add_(other) |
| |
| def __sub__(self, other): |
| return self.sub(other) |
| |
| def __rsub__(self, other): |
| return self.new().resize_as_(self).fill_(other).add_(-1, self) |
| |
| def __isub__(self, other): |
| return self.sub_(other) |
| |
| def __mul__(self, other): |
| return self.mul(other) |
| __rmul__ = __mul__ |
| |
| def __imul__(self, other): |
| return self.mul_(other) |
| |
| def __matmul__(self, other): |
| if not torch.is_tensor(other): |
| return NotImplemented |
| return self.matmul(other) |
| |
| def __pow__(self, other): |
| return self.pow(other) |
| |
| def __ipow__(self, other): |
| return self.pow_(other) |
| |
| def __div__(self, other): |
| return self.div(other) |
| __truediv__ = __div__ |
| |
| def __rdiv__(self, other): |
| return self.new().resize_as_(self).fill_(other).div_(self) |
| __rtruediv__ = __rdiv__ |
| |
| def __idiv__(self, other): |
| return self.div_(other) |
| |
| def __mod__(self, other): |
| return self.remainder(other) |
| |
| def __neg__(self): |
| return self.neg() |
| |
| def __eq__(self, other): |
| return self.eq(other) |
| |
| def __ne__(self, other): |
| return self.ne(other) |
| |
| def __lt__(self, other): |
| return self.lt(other) |
| |
| def __le__(self, other): |
| return self.le(other) |
| |
| def __gt__(self, other): |
| return self.gt(other) |
| |
| def __ge__(self, other): |
| return self.ge(other) |
| |
| # TODO: add native add or and xor in the libs |
| def __invert__(self): |
| if type(self).__name__ != 'ByteTensor': |
| raise RuntimeError('logical operations are supported on ByteTensors only') |
| return (1 - self) |
| |
| def __hash__(self): |
| return id(self) |
| |
| |
| _TensorBase.type = _type |
| _TensorBase.cuda = _cuda |