import torch
import torch._C as _C
from collections import OrderedDict

from .functions import *
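
# A Variable wraps a tensor and records the operations applied to it, so that
# calling backward() on a scalar result populates .grad on the graph leaves.
# A minimal usage sketch (the names below are illustrative, not part of this
# module):
#
#   x = Variable(torch.ones(2, 2), requires_grad=True)
#   y = x.mul(x).sum()
#   y.backward()   # x.grad now holds dy/dx
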
class Variable(_C._VariableBase):
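    # Tensor methods listed below are metadata queries that do not take part
    # in autograd; __getattr__ forwards them straight to the wrapped tensor
    # in self.data.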
_fallthrough_methods = {
'size',
'stride',
'nelement',
'ndimension',
'element_size',
'is_contiguous',
'is_same_size',
'is_set_to',
'is_size',
'is_signed',
'numel',
'dim',
'get_device',
'is_cuda',
}
@property
def grad(self):
if self.requires_grad:
# TODO: this won't have to be zeroed in the future
self._grad = self._grad or self.data.new(self.data.size()).zero_()
return self._grad
@property
def requires_grad(self):
return self._requires_grad
@requires_grad.setter
def requires_grad(self, value):
if self.creator is not None:
raise RuntimeError("you can only change requires_grad flags of "
"leaf variables")
self._requires_grad = value
def __getattr__(self, name):
if name in self._fallthrough_methods:
return getattr(self.data, name)
raise AttributeError(name)
def __getitem__(self, key):
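        # A ByteTensor-backed Variable is treated as a boolean mask
        # (masked select); any other key is a plain index/slice.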
if (isinstance(key, Variable) and
type(key.data).__name__ == 'ByteTensor'):
return MaskedSelect()(self, key)
return Index(key)(self)
def __setitem__(self, key, value):
if (isinstance(key, Variable) and
type(key.data).__name__ == 'ByteTensor'):
return MaskedFill(value, inplace=True)(self, key)
return SetValue(key, value)(self)
def __deepcopy__(self, memo):
if self.creator is None:
return Variable(self.data.clone(), requires_grad=self.requires_grad,
volatile=self.volatile)
raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
def backward(self, gradient=None, retain_variables=False):
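        """Computes the gradient of this Variable w.r.t. the graph leaves.

        If the Variable holds a single element, ``gradient`` may be omitted
        and a tensor of ones is used; otherwise ``gradient`` must be a tensor
        holding the gradient w.r.t. this Variable. ``retain_variables`` is
        forwarded to the execution engine so that the buffers needed for
        another backward pass are kept.
        """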
if self.volatile:
raise RuntimeError('calling backward on a volatile variable')
if not self.requires_grad:
raise RuntimeError("calling backward on a variable that doesn't require gradient")
if gradient is None:
if self.data.numel() != 1:
raise RuntimeError('backward should be called only on a scalar (i.e. 1-element tensor) or with gradient w.r.t. the variable')
gradient = self.data.new(1).fill_(1)
self._execution_engine.run_backward(self, gradient, retain_variables)
def __repr__(self):
return 'Variable containing:' + self.data.__repr__()
def register_hook(self, name, hook):
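        """Registers a backward hook under ``name``.

        ``hook`` is called with the gradient w.r.t. this Variable each time
        it is computed during a backward pass. The hook can be removed later
        with ``remove_hook(name)``.
        """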
if self.volatile:
raise RuntimeError('registering hook on a volatile variable')
if not self.requires_grad:
raise RuntimeError("registering hook on a variable that doesn't require gradient")
if self.creator is not None:
self.creator.register_hook(name, lambda gi, go: hook(go[self.output_nr]))
else:
self.backward_hooks = self.backward_hooks or OrderedDict()
assert name not in self.backward_hooks, \
"Trying to register a second hook with name {}".format(name)
self.backward_hooks[name] = hook
def remove_hook(self, name):
if self.volatile:
raise RuntimeError("volatile variables don't support hooks")
if self.creator is not None:
self.creator.remove_hook(name)
else:
            assert self.backward_hooks and name in self.backward_hooks, \
                "Trying to remove a nonexistent hook with name {}".format(name)
del self.backward_hooks[name]
def _do_backward(self, grad_output, retain_variables):
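        # Reached when the backward pass arrives at this leaf Variable: run
        # any registered hooks on the incoming gradient, then accumulate it
        # into .grad.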
assert len(grad_output) == 1
assert self._version == 0 and self.creator is None, \
"leaf variable was used in an inplace operation"
if self.backward_hooks:
for hook in self.backward_hooks.values():
hook(grad_output[0])
self.grad.add_(grad_output[0])
return tuple()
def contiguous(self):
self.data = self.data.contiguous()
return self
def clone(self):
return Clone()(self)
def type(self, t):
if t != type(self.data):
return Type(t)(self)
return self
def _get_type(self, name):
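        # Look up a tensor class by name in the module that self.data comes
        # from (torch or torch.cuda), so conversions keep the current device.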
module = torch._import_dotted_name(self.data.__module__)
return getattr(module, name)
def cuda(self, device_id=None):
return CudaTransfer(device_id)(self)
def cpu(self):
return self.type(getattr(torch, type(self.data).__name__))
def double(self):
return self.type(self._get_type('DoubleTensor'))
def float(self):
return self.type(self._get_type('FloatTensor'))
def long(self):
return self.type(self._get_type('LongTensor'))
def int(self):
return self.type(self._get_type('IntTensor'))
def short(self):
return self.type(self._get_type('ShortTensor'))
def char(self):
return self.type(self._get_type('CharTensor'))
def byte(self):
return self.type(self._get_type('ByteTensor'))
def _add(self, other, inplace):
if isinstance(other, Variable):
return Add(inplace)(self, other)
else:
assert not torch.is_tensor(other)
return AddConstant(other, inplace)(self)
def add(self, other):
return self._add(other, False)
def add_(self, other):
return self._add(other, True)
def _sub(self, other, inplace):
if isinstance(other, Variable):
return Sub(inplace=inplace)(self, other)
else:
assert not torch.is_tensor(other)
return SubConstant(other, inplace=inplace)(self)
def sub(self, other):
return self._sub(other, False)
def sub_(self, other):
return self._sub(other, True)
def mul(self, other):
if isinstance(other, Variable):
return Mul()(self, other)
else:
assert not torch.is_tensor(other)
return MulConstant(other)(self)
def mul_(self, other):
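        # In-place multiplication is only defined for plain Python scalars;
        # Variable and tensor operands are rejected below.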
if not isinstance(other, Variable) and not torch.is_tensor(other):
return MulConstant(other, inplace=True)(self)
raise RuntimeError("mul_ only supports scalar multiplication")
def div(self, other):
if isinstance(other, Variable):
return Div()(self, other)
else:
assert not torch.is_tensor(other)
return DivConstant(other)(self)
def div_(self, other):
if not isinstance(other, Variable) and not torch.is_tensor(other):
return DivConstant(other, inplace=True)(self)
raise RuntimeError("div_ only supports scalar multiplication")
def pow(self, other):
if isinstance(other, Variable):
return Pow()(self, other)
else:
assert not torch.is_tensor(other)
return PowConstant(other)(self)
def exp(self):
return Exp()(self)
def exp_(self):
return Exp(inplace=True)(self)
def log(self):
return Log()(self)
def log1p(self):
return Log1p()(self)
def neg(self):
return Negate()(self)
def neg_(self):
return Negate(inplace=True)(self)
def tanh(self):
return Tanh()(self)
def tanh_(self):
return Tanh(True)(self)
def sigmoid(self):
return Sigmoid()(self)
def sigmoid_(self):
return Sigmoid(True)(self)
def sin(self):
return Sin()(self)
def cos(self):
return Cos()(self)
def tan(self):
return Tan()(self)
def asin(self):
return Asin()(self)
def acos(self):
return Acos()(self)
def atan(self):
return Atan()(self)
def sinh(self):
return Sinh()(self)
def cosh(self):
return Cosh()(self)
def abs(self):
return Abs()(self)
def clamp(self, min_val, max_val):
return Clamp(min_val, max_val)(self)
def cinv(self):
return Cinv()(self)
def cmax(self, other):
if isinstance(other, Variable):
return Cmax()(self, other)
else:
return CmaxConstant(other)(self)
def cmin(self, other):
if isinstance(other, Variable):
return Cmin()(self, other)
else:
return CminConstant(other)(self)
def floor(self):
return Floor()(self)
def ceil(self):
return Ceil()(self)
def frac(self):
return Frac()(self)
def sqrt(self):
return Sqrt()(self)
def round(self):
return Round()(self)
def sign(self):
return Sign()(self)
def trunc(self):
return Trunc()(self)
def fmod(self, value):
return Fmod(value)(self)
def remainder(self, value):
return Remainder(value)(self)
def lerp(self, tensor, weight):
return Lerp(weight)(self, tensor)
def rsqrt(self):
return Rsqrt()(self)
def sum(self, dim=None):
return Sum(dim)(self)
def prod(self, dim=None):
return Prod(dim)(self)
def mean(self, dim=None):
return Mean(dim)(self)
def max(self, dim=None):
return Max(dim)(self)
def min(self, dim=None):
return Min(dim)(self)
def mode(self, dim):
return Mode(dim)(self)
def median(self, dim):
return Median(dim)(self)
def kthvalue(self, dim):
return Kthvalue(dim)(self)
def sort(self, dim=None, descending=False):
return Sort(dim, descending)(self)
def topk(self, k, dim=None, largest=True, sorted=True):
return Topk(k, dim, largest, sorted)(self)
def view(self, *sizes):
return View(*sizes)(self)
def view_as(self, tensor):
return View(*tensor.size())(self)
@staticmethod
def _static_blas(cls, args, inplace):
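        # args is (add_tensor[, alpha[, beta]], mat1, mat2); the optional
        # alpha/beta scaling factors default to 1.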
num_args = len(args)
alpha = beta = 1
if num_args > 5:
raise RuntimeError("too many args")
if num_args == 5:
alpha, beta = args[1:3]
if num_args == 4:
alpha = args[1]
return cls(alpha, beta, inplace)(*(args[:1] + args[-2:]))
def _blas(self, cls, args, inplace):
return self._static_blas(cls, (self,) + args, inplace)
def mm(self, matrix):
output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
return self._static_blas(Addmm, (output, 0, 1, self, matrix), False)
def bmm(self, batch):
output = Variable(self.data.new(self.data.size(0), self.data.size(1),
batch.data.size(2)))
        return self._static_blas(Baddbmm, (output, 0, 1, self, batch), False)
def mv(self, vector):
output = Variable(self.data.new(self.data.size(0)))
return self._static_blas(Addmv, (output, 0, 1, self, vector), False)
def ger(self, vector):
output = Variable(self.data.new(self.data.size(0), vector.data.size(0)))
return self._static_blas(Addr, (output, 0, 1, self, vector), False)
def resize(self, *sizes):
return Resize(*sizes)(self)
def resize_as(self, variable):
return Resize(*variable.size())(self)
def addmm(self, *args):
return self._blas(Addmm, args, False)
def addmm_(self, *args):
return self._blas(Addmm, args, True)
def addbmm(self, *args):
return self._blas(Addbmm, args, False)
def addbmm_(self, *args):
return self._blas(Addbmm, args, True)
def baddbmm(self, *args):
return self._blas(Baddbmm, args, False)
def baddbmm_(self, *args):
return self._blas(Baddbmm, args, True)
def addmv(self, *args):
return self._blas(Addmv, args, False)
def addmv_(self, *args):
return self._blas(Addmv, args, True)
def addr(self, *args):
return self._blas(Addr, args, False)
    def addr_(self, *args):
        return self._blas(Addr, args, True)
def dot(self, other):
return Dot()(self, other)
def _addcop(self, op, args):
if len(args) == 3:
# scale, tensor1, tensor2
return op(args[0])(self, *args[1:])
else:
# tensor1, tensor2
return op()(self, *args)
def addcmul(self, *args):
return self._addcop(Addcmul, args)
def addcdiv(self, *args):
return self._addcop(Addcdiv, args)
def norm(self, norm_type=2, dim=None):
return Norm(norm_type, dim)(self)
def dist(self, tensor, norm_type=2):
return Norm(norm_type)(self - tensor)
def index_add(self, dim, index, tensor):
return IndexAdd(dim)(self, index, tensor)
def index_add_(self, dim, index, tensor):
return IndexAdd(dim, True)(self, index, tensor)
def index_copy(self, dim, index, tensor):
return IndexCopy(dim)(self, index, tensor)
def index_copy_(self, dim, index, tensor):
return IndexCopy(dim, True)(self, index, tensor)
def index_fill(self, dim, index, value):
return IndexFill(dim, value)(self, index)
def index_fill_(self, dim, index, value):
return IndexFill(dim, value, True)(self, index)
def index_select(self, dim, index):
return IndexSelect(dim)(self, index)
def masked_copy(self, mask, variable):
return MaskedCopy()(self, mask, variable)
def masked_copy_(self, mask, variable):
return MaskedCopy(True)(self, mask, variable)
def masked_fill(self, mask, value):
return MaskedFill(value)(self, mask)
def masked_fill_(self, mask, value):
return MaskedFill(value, True)(self, mask)
def masked_select(self, mask):
return MaskedSelect()(self, mask)
def expand(self, *sizes):
return Expand(*sizes)(self)
def expand_as(self, tensor):
return Expand(*tensor.size())(self)
def t(self):
return Transpose(0, 1)(self)
def transpose(self, dim1, dim2):
return Transpose(dim1, dim2)(self)
def cat(self, iterable, dim=0):
return Concat(dim)(*iterable)
def select(self, dim, _index):
index = tuple(slice(None, None) for _ in range(dim)) + (_index,)
return Index(index)(self)
def narrow(self, dim, start_index, length):
index = tuple(slice(None, None) for _ in range(dim)) + \
(slice(start_index, start_index+length),)
return Index(index)(self)
def chunk(self, num_chunks, dim=0):
return Chunk(num_chunks, dim)(self)
def squeeze(self, dim=None):
return Squeeze(dim)(self)
def unsqueeze(self, dim):
return Unsqueeze(dim)(self)
def permute(self, *permutation):
return Permute(*permutation)(self)
def diag(self, diagonal_idx=0):
return Diag(diagonal_idx)(self)
def tril(self, diagonal_idx=0):
return Tril(diagonal_idx)(self)
def triu(self, diagonal_idx=0):
return Triu(diagonal_idx)(self)
def __add__(self, other):
return self.add(other)
__radd__ = __add__
def __sub__(self, other):
return self.sub(other)
def __rsub__(self, other):
return SubConstant(other, sub_tensor=True)(self)
def __mul__(self, other):
return self.mul(other)
__rmul__ = __mul__
def __div__(self, other):
return self.div(other)
__truediv__ = __div__
def __rdiv__(self, other):
return DivConstant(other, div_by_tensor=True)(self)
__rtruediv__ = __rdiv__
def __pow__(self, other):
return self.pow(other)
def __rpow__(self, other):
return PowConstant(other, tensor_power=True)(self)
def __neg__(self):
return Negate()(self)
from .engine import ImperativeEngine
Variable._execution_engine = ImperativeEngine()