# blob: cf78cf1b899b7cca0ef1e7cbf68ccf3f53170089 [file] [log] [blame]
import torch
from . import _tensor_str
from ._utils import _type, _cuda, _range
from functools import reduce
from itertools import chain
import sys
import math
def _infer_sizes(sizes, total):
    """Replaces a single ``-1`` entry in ``sizes`` so the sizes describe ``total`` elements.

    Args:
        sizes (sequence of int): requested sizes; at most one entry may be -1.
        total (int): number of elements the resulting sizes must hold.

    Returns:
        A ``torch.Size`` with the ``-1`` entry replaced by the inferred size,
        or ``sizes`` itself (unchanged) when it contains no ``-1``.

    Raises:
        RuntimeError: if more than one entry is -1, if the inferred size is
            ambiguous because another entry is 0, or if ``total`` is not
            divisible by the product of the other entries.
    """
    to_infer = -1
    # Product of every entry except the one being inferred.
    known_product = 1
    for i, size in enumerate(sizes):
        if size == -1:
            if to_infer >= 0:
                raise RuntimeError("only one dimension can be inferred")
            to_infer = i
        else:
            known_product *= size
    if to_infer < 0:
        return sizes
    if known_product == 0:
        # Any value for the -1 entry would give zero elements, so there is
        # no unique answer (and dividing by zero below would crash).
        raise RuntimeError("cannot infer the size of the -1 dimension because "
                           "another dimension has size 0")
    # An explicit raise instead of ``assert``, which is stripped under -O.
    if total % known_product != 0:
        raise RuntimeError("Can't make sizes have exactly %d elements" % total)
    sizes = list(sizes)
    sizes[to_infer] = total // known_product
    return torch.Size(sizes)
class _TensorBase(object):
    """Pure-Python tensor methods shared by the concrete tensor classes.

    These methods are written on top of lower-level operations (``type``,
    ``storage``, ``set_``, ``size``, ``stride``, ``transpose``, ...) that this
    class does not define itself. It is presumably mixed into the generated
    tensor classes that provide them.
    """

    #: bool: True if this is a CUDA tensor
    is_cuda = False

    def new(self, *args, **kwargs):
        """Constructs a new tensor of the same data type."""
        return self.__class__(*args, **kwargs)

    def type_as(self, tensor):
        """Returns this tensor cast to the type of the given tensor.

        This is a no-op if the tensor is already of the correct type. This is
        equivalent to::

            self.type(tensor.type())

        Args:
            tensor (Tensor): the tensor which has the desired type
        """
        return self.type(tensor.type())

    def cpu(self):
        """Returns a CPU copy of this tensor if it's not already on the CPU"""
        # The CPU class in the ``torch`` namespace has the same class name.
        return self.type(getattr(torch, self.__class__.__name__))

    def double(self):
        """Casts this tensor to double type"""
        return self.type(type(self).__module__ + '.DoubleTensor')

    def float(self):
        """Casts this tensor to float type"""
        return self.type(type(self).__module__ + '.FloatTensor')

    def half(self):
        """Casts this tensor to half-precision float type"""
        return self.type(type(self).__module__ + '.HalfTensor')

    def long(self):
        """Casts this tensor to long type"""
        return self.type(type(self).__module__ + '.LongTensor')

    def int(self):
        """Casts this tensor to int type"""
        return self.type(type(self).__module__ + '.IntTensor')

    def short(self):
        """Casts this tensor to short type"""
        return self.type(type(self).__module__ + '.ShortTensor')

    def char(self):
        """Casts this tensor to char type"""
        return self.type(type(self).__module__ + '.CharTensor')

    def byte(self):
        """Casts this tensor to byte type"""
        return self.type(type(self).__module__ + '.ByteTensor')

    def is_pinned(self):
        """Returns true if this tensor resides in pinned memory"""
        storage = self.storage()
        # Compare against None explicitly, as pin_memory() and view() do.
        # Truthiness of a storage object is not a reliable "has storage"
        # test (presumably an empty storage is falsy).
        return storage.is_pinned() if storage is not None else False

    def pin_memory(self):
        """Copies the tensor to pinned memory, if it's not already pinned.

        Raises:
            TypeError: if this is a CUDA tensor.
        """
        if self.is_cuda:
            raise TypeError("cannot pin '{0}' only CPU memory can be pinned"
                            .format(self.type()))
        storage = self.storage()
        if storage is None:
            # Tensors without storage get an empty storage to pin.
            storage = (self.storage_type())()
        return type(self)().set_(storage.pin_memory()).view_as(self)

    def share_memory_(self):
        """Moves the underlying storage to shared memory.

        This is a no-op if the underlying storage is already in shared memory
        and for CUDA tensors. Tensors in shared memory cannot be resized.
        """
        self.storage().share_memory_()
        return self

    def is_shared(self):
        """Checks if tensor is in shared memory.

        This is always ``True`` for CUDA tensors.
        """
        return self.storage().is_shared()

    def __deepcopy__(self, _memo):
        # Copies are cached in a 'torch' sub-dict of the deepcopy memo, keyed
        # by ``_cdata``. Copying the same tensor twice within one deepcopy
        # call therefore returns the same new tensor.
        memo = _memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.storage().__deepcopy__(_memo)
        new_tensor = self.new()
        new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
        memo[self._cdata] = new_tensor
        return new_tensor

    def __reduce__(self):
        # Pickles the tensor as its nested-list form. Note that tolist()
        # returns [] when dim() == 0.
        return type(self), (self.tolist(),)

    def __repr__(self):
        return str(self)

    def __str__(self):
        # All strings are unicode in Python 3. In Python 2 we have to return
        # an encoded byte string: use stdout's encoding when it is known
        # (falling back to UTF-8), and let Python replace characters that
        # cannot be encoded.
        if sys.version_info > (3,):
            return _tensor_str._str(self)
        encoding = getattr(sys.stdout, 'encoding', None) or 'UTF-8'
        return _tensor_str._str(self).encode(encoding, 'replace')

    def __bool__(self):
        """Empty tensors are falsy; truthiness of non-empty tensors is an error."""
        if self.numel() == 0:
            return False
        raise RuntimeError("bool value of non-empty " + torch.typename(self) +
                           " objects is ambiguous")

    __nonzero__ = __bool__

    def __iter__(self):
        """Iterates over ``self.select(0, i)`` for each index of dimension 0."""
        return iter(map(lambda i: self.select(0, i), _range(self.size(0))))

    def split(self, split_size, dim=0):
        """Splits this tensor into a list of tensors.

        See :func:`torch.split`.
        """
        return torch.split(self, split_size, dim)

    def chunk(self, n_chunks, dim=0):
        """Splits this tensor into a list of tensors.

        See :func:`torch.chunk`.
        """
        return torch.chunk(self, n_chunks, dim)

    def tolist(self):
        """Returns a nested list representation of this tensor."""
        dim = self.dim()
        if dim == 1:
            return [v for v in self]
        elif dim > 0:
            return [subt.tolist() for subt in self]
        return []

    def view(self, *args):
        """Returns a new tensor with the same data but different size.

        The returned tensor shares the same data and must have the same number
        of elements, but may have a different size. A tensor must be
        :meth:`contiguous` to be viewed.

        Args:
            args (torch.Size or int...): Desired size. At most one entry may
                be -1, in which case it is inferred from the others.

        Raises:
            ValueError: if the requested size does not hold exactly
                ``self.nelement()`` elements, or the tensor is not contiguous.

        Example:
            >>> x = torch.randn(4, 4)
            >>> x.size()
            torch.Size([4, 4])
            >>> y = x.view(16)
            >>> y.size()
            torch.Size([16])
            >>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
            >>> z.size()
            torch.Size([2, 8])
        """
        if len(args) == 1 and isinstance(args[0], torch.Size):
            sizes = args[0]
        else:
            sizes = torch.Size(args)
        sizes = _infer_sizes(sizes, self.nelement())
        numel = reduce(lambda a, b: a * b, sizes) if len(sizes) > 0 else 0
        if numel != self.nelement():
            def format_size(size):
                return 'x'.join(str(v) for v in size) if len(size) > 0 else '0'
            raise ValueError(
                "view of size '{0}' is invalid for input of size '{1}'"
                .format(format_size(sizes), format_size(self.size())))
        if not self.is_contiguous():
            raise ValueError("input should be contiguous")
        # Allocate the result only after validation has passed.
        dst = self.new()
        storage = self.storage()
        if storage is not None:
            dst.set_(storage, self.storage_offset(), sizes)
        return dst

    def view_as(self, tensor):
        """Returns this tensor viewed with the same size as the specified tensor.

        This is equivalent to::

            self.view(tensor.size())
        """
        return self.view(tensor.size())

    def permute(self, *dims):
        """Permute the dimensions of this tensor.

        Args:
            *dims (int...): The desired ordering of dimensions. A single
                tuple, list or torch.Size of dimensions is also accepted.

        Raises:
            ValueError: if ``dims`` is not a permutation of
                ``0, 1, ..., self.dim() - 1``.

        Example:
            >>> x = torch.randn(2, 3, 5)
            >>> x.size()
            torch.Size([2, 3, 5])
            >>> x.permute(2, 0, 1).size()
            torch.Size([5, 2, 3])
        """
        if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
            dims = dims[0]
        perm = list(dims)
        tensor = self
        n_dims = tensor.dim()
        # Validate up front. The loop below uses -1 as a "done" marker, so
        # without this check a user-supplied -1 (or any duplicate) would be
        # silently skipped and produce a wrong result.
        if sorted(perm) != list(range(n_dims)):
            raise ValueError('Invalid permutation')
        # Decompose the permutation into cycles and apply each cycle as a
        # chain of transposes, marking visited positions with -1.
        for i, p in enumerate(perm):
            if p != i and p != -1:
                j = i
                while True:
                    tensor = tensor.transpose(j, perm[j])
                    # The right-hand side is evaluated first: marks the old
                    # j as done, then moves j along the cycle.
                    perm[j], j = -1, perm[j]
                    if perm[j] == i:
                        break
                perm[j] = -1
        return tensor

    def expand(self, *sizes):
        """Returns a new view of the tensor with singleton dimension expanded
        to a larger size.

        Expanding a tensor does not allocate new memory, but only creates a
        new view on the existing tensor where a dimension of size one is
        expanded to a larger size by setting the ``stride`` to 0. Any dimension
        of size 1 can be expanded to an arbitrary value without allocating new
        memory.

        Args:
            *sizes (torch.Size or int...): The desired expanded size

        Raises:
            ValueError: if the number of sizes differs from ``self.dim()``, or
                a non-singleton dimension would change size.

        Example:
            >>> x = torch.Tensor([[1], [2], [3]])
            >>> x.size()
            torch.Size([3, 1])
            >>> x.expand(3, 4)
             1  1  1  1
             2  2  2  2
             3  3  3  3
            [torch.FloatTensor of size 3x4]
        """
        result = self.new()
        if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
            sizes = sizes[0]
        else:
            sizes = torch.Size(sizes)
        src = self
        src_dim = src.dim()
        src_stride = list(src.stride())
        src_size = list(src.size())
        if len(sizes) != src_dim:
            raise ValueError('the number of dimensions provided must equal tensor.dim()')
        # create a new geometry for tensor: singleton dims take the target
        # size with stride 0, so every index reads the same element.
        for i, size in enumerate(src_size):
            if size == 1:
                src_size[i] = sizes[i]
                src_stride[i] = 0
            elif size != sizes[i]:
                raise ValueError('incorrect size: only supporting singleton expansion (size=1)')
        result.set_(src.storage(), src.storage_offset(), torch.Size(src_size),
                    tuple(src_stride))
        return result

    def expand_as(self, tensor):
        """Expands this tensor to the size of the specified tensor.

        This is equivalent to::

            self.expand(tensor.size())
        """
        return self.expand(tensor.size())

    def repeat(self, *sizes):
        """Repeats this tensor along the specified dimensions.

        Unlike :meth:`expand`, this function copies the tensor's data.

        Args:
            *sizes (torch.Size or int...): The number of times to repeat this
                tensor along each dimension

        Raises:
            ValueError: if fewer repeat counts than tensor dimensions are given.

        Example:
            >>> x = torch.Tensor([1, 2, 3])
            >>> x.repeat(4, 2)
             1  2  3  1  2  3
             1  2  3  1  2  3
             1  2  3  1  2  3
             1  2  3  1  2  3
            [torch.FloatTensor of size 4x6]
            >>> x.repeat(4, 2, 1).size()
            torch.Size([4, 2, 3])
        """
        # If args == (torch.Size,), then we need to unpack the tuple
        if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
            sizes = sizes[0]
        repeats = list(sizes)
        result = self.new()
        src = self.contiguous()
        if len(repeats) < src.dim():
            raise ValueError('Number of dimensions of repeat dims can not be '
                             'smaller than number of dimensions of tensor')
        # xtensor shares src's data; left-pad its size with 1s so it has one
        # dimension per repeat count.
        xtensor = src.new().set_(src)
        xsize = [1] * (len(repeats) - src.dim()) + list(xtensor.size())
        size = torch.Size([a * b for a, b in zip(xsize, repeats)])
        xtensor.resize_(torch.Size(xsize))
        result.resize_(size)
        # urtensor is a view of result's data that is unfolded into tiles
        # the size of xtensor; each tile is then filled with a copy.
        urtensor = result.new(result)
        for i in _range(xtensor.dim()):
            urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))
        xsize = [1] * (urtensor.dim() - xtensor.dim()) + xsize
        xtensor.resize_(torch.Size(xsize))
        xxtensor = xtensor.expand_as(urtensor)
        urtensor.copy_(xxtensor)
        return result

    def unsqueeze(self, dim):
        """Returns a new tensor with a dimension of size one inserted at the
        specified position.

        The returned tensor shares the same underlying data with this tensor.

        Args:
            dim (int): The index at which to insert the singleton dimension

        Example:
            >>> x = torch.Tensor([1, 2, 3, 4])
            >>> x.unsqueeze(0)
             1  2  3  4
            [torch.FloatTensor of size 1x4]
            >>> x.unsqueeze(1)
             1
             2
             3
             4
            [torch.FloatTensor of size 4x1]
        """
        return self.new(self).unsqueeze_(dim)

    def unsqueeze_(self, dim):
        """In-place version of :meth:`unsqueeze`."""
        sizes = list(self.size())
        strides = list(self.stride())
        # The stride of a size-1 dimension never affects indexing. Using
        # size*stride of the dimension currently at ``dim`` (or 1 at the end)
        # keeps a contiguous tensor's strides in the natural contiguous form,
        # e.g. (2, 3) / (3, 1) -> unsqueeze_(1) -> (2, 1, 3) / (3, 3, 1).
        # (The previous condition ``len(strides) < dim`` could only be true
        # when ``strides[dim]`` was out of range.)
        if 0 <= dim < len(sizes):
            new_stride = sizes[dim] * strides[dim]
        else:
            new_stride = 1
        sizes.insert(dim, 1)
        strides.insert(dim, new_stride)
        return self.set_(self.storage(), self.storage_offset(),
                         torch.Size(sizes), tuple(strides))

    # TODO: add tests for operators
    def __add__(self, other):
        return self.add(other)
    __radd__ = __add__

    def __iadd__(self, other):
        return self.add_(other)

    def __sub__(self, other):
        return self.sub(other)

    def __rsub__(self, other):
        # other - self, computed as a tensor filled with ``other`` minus self.
        return self.new().resize_as_(self).fill_(other).add_(-1, self)

    def __isub__(self, other):
        return self.sub_(other)

    def __mul__(self, other):
        return self.mul(other)
    __rmul__ = __mul__

    def __imul__(self, other):
        return self.mul_(other)

    def __matmul__(self, other):
        """Matrix product for 1-D/2-D operands (dot, mv, vector-matrix, mm)."""
        dim_self = self.dim()
        try:
            dim_other = other.dim()
        except AttributeError:  # not a tensor
            return NotImplemented
        if dim_self == 1 and dim_other == 1:
            return self.dot(other)
        if dim_self == 2 and dim_other == 1:
            return self.mv(other)
        if dim_self == 1 and dim_other == 2:
            return self.unsqueeze(0).mm(other).squeeze(0)
        elif dim_self == 2 and dim_other == 2:
            return self.mm(other)
        raise ValueError("both arguments to __matmul__ need to be 1D or 2D, "
                         "but they are {}D and {}D".format(dim_self, dim_other))

    def __pow__(self, other):
        return self.pow(other)

    def __ipow__(self, other):
        return self.pow_(other)

    def __div__(self, other):
        return self.div(other)
    __truediv__ = __div__

    def __rdiv__(self, other):
        # other / self, computed as a tensor filled with ``other`` divided by self.
        return self.new().resize_as_(self).fill_(other).div_(self)
    __rtruediv__ = __rdiv__

    def __idiv__(self, other):
        return self.div_(other)

    def __mod__(self, other):
        return self.remainder(other)

    def __neg__(self):
        return self.neg()

    # Comparisons are elementwise and return tensors, not bools.
    def __eq__(self, other):
        return self.eq(other)

    def __ne__(self, other):
        return self.ne(other)

    def __lt__(self, other):
        return self.lt(other)

    def __le__(self, other):
        return self.le(other)

    def __gt__(self, other):
        return self.gt(other)

    def __ge__(self, other):
        return self.ge(other)

    # TODO: add native and, or and xor in the libs
    def _check_byte_tensor_operands(self, other):
        # The logical operators below are emulated with arithmetic and
        # comparisons, which assumes 0/1 values; they are restricted to
        # ByteTensor operands.
        if (type(self).__name__ != 'ByteTensor' or
                type(other).__name__ != 'ByteTensor'):
            raise RuntimeError('logical operations are supported on ByteTensors only')

    def __and__(self, other):
        self._check_byte_tensor_operands(other)
        return (self + other).eq(2)

    def __or__(self, other):
        self._check_byte_tensor_operands(other)
        return (self + other).gt(0)

    def __xor__(self, other):
        self._check_byte_tensor_operands(other)
        return (self + other).eq(1)

    def __iand__(self, other):
        self._check_byte_tensor_operands(other)
        return self.mul_(other)

    def __ior__(self, other):
        self._check_byte_tensor_operands(other)
        return self.copy_((self + other).gt(0))

    def __ixor__(self, other):
        self._check_byte_tensor_operands(other)
        return self.copy_((self + other).eq(1))

    def __hash__(self):
        # Defining __eq__ would otherwise make instances unhashable in
        # Python 3; hash by identity so tensors can be dict keys / set members.
        return id(self)
# Attach the ``_type`` and ``_cuda`` helpers imported from ._utils as the
# ``type`` and ``cuda`` methods of every tensor class.
setattr(_TensorBase, 'type', _type)
setattr(_TensorBase, 'cuda', _cuda)