import torch
from ..._thnn import type2backend
from ..function import Function, InplaceFunction
from ..variable import Variable
from .utils import maybe_unexpand, maybe_unexpand_or_view


class Exp(InplaceFunction):

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
            ctx.mark_dirty(i)
            result = i.exp_()
        else:
            result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_variables
        return grad_output * result, None
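
# Exp saves the forward *result* rather than the input: since d/dx exp(x) = exp(x),
# backward can reuse it directly as grad_output * result. The trailing None is the
# gradient slot for the non-differentiable `inplace` flag.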


class Log(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.log()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.div(i)


class Log1p(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.log1p()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.div(i.add(1))
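
# Both backward passes follow directly from d/dx log(x) = 1/x and
# d/dx log1p(x) = 1/(1 + x); unlike Exp, the *input* must be saved here.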


class Tanh(InplaceFunction):

    @staticmethod
    def primspec(i, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Tanh", i)

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
            ctx.mark_dirty(i)
            result = i.tanh_()
        else:
            result = i.tanh()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_variables
        if grad_output.volatile:
            grad_input = Variable(grad_output.data.new(grad_output.size()), volatile=True)
            backend = type2backend[type(result.data)]
            backend.Tanh_updateGradInput(backend.library_state, None, grad_output.data,
                                         grad_input.data, result.data)
        else:
            grad_input = grad_output * (1 - result * result)
        return grad_input, None
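
# Tanh.backward uses the identity d/dx tanh(x) = 1 - tanh(x)^2, phrased in terms
# of the saved result. The volatile branch is a fast path: it skips graph
# construction entirely and calls the THNN backend kernel on raw tensors.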


class Sigmoid(InplaceFunction):

    @staticmethod
    def primspec(i, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Sigmoid", i)

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
            ctx.mark_dirty(i)
            result = i.sigmoid_()
        else:
            result = i.sigmoid()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_variables
        if grad_output.volatile:
            grad_input = Variable(grad_output.data.new(grad_output.size()), volatile=True)
            backend = type2backend[type(result.data)]
            backend.Sigmoid_updateGradInput(backend.library_state, None, grad_output.data,
                                            grad_input.data, result.data)
        else:
            grad_input = grad_output * ((1 - result) * result)
        return grad_input, None
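
# Sigmoid.backward uses d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), again
# in terms of the saved result, with the same volatile THNN fast path as Tanh.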


class Sinh(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.sinh()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * i.cosh()


class Cosh(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.cosh()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * i.sinh()


class Abs(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.abs()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * i.sign()


class Clamp(Function):

    @staticmethod
    def forward(ctx, i, min_val, max_val):
        ctx._mask = (i.ge(min_val) * i.le(max_val))
        return i.clamp(min_val, max_val)

    @staticmethod
    def backward(ctx, grad_output):
        mask = Variable(ctx._mask.type_as(grad_output.data))
        return grad_output * mask, None, None
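
# The mask marks elements inside [min_val, max_val]: the gradient passes through
# unchanged there and is zero where the input was clipped. min_val and max_val
# themselves receive no gradient (the two trailing Nones).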


class Sqrt(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.sqrt()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.mul(i.pow(-0.5)).div_(2)
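
# d/dx sqrt(x) = 1 / (2 * sqrt(x)) = x^(-1/2) / 2, hence pow(-0.5) followed by
# an in-place division by 2.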


class Sin(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.sin()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * i.cos()


class Cos(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.cos()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.mul(i.sin()).neg_()


class Tan(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.tan()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.div(i.cos().pow(2))


class Asin(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.asin()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * (1 - i.mul(i)).sqrt().reciprocal()


class Acos(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.acos()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output.mul((1 - i.mul(i)).sqrt().reciprocal()).neg_()


class Atan(Function):

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i.atan()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_variables
        return grad_output * i.mul(i).add_(1).reciprocal()


class Atan2(Function):

    @staticmethod
    def forward(ctx, y, x):
        ctx.save_for_backward(y, x)
        return y.atan2(x)

    @staticmethod
    def backward(ctx, grad_output):
        y, x = ctx.saved_variables
        denominator = y.mul(y).add(x.mul(x)).reciprocal()
        return grad_output * x.mul(denominator), grad_output * y.neg().mul(denominator)
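
# Atan2 gradients: with r^2 = x^2 + y^2, d/dy atan2(y, x) = x / r^2 and
# d/dx atan2(y, x) = -y / r^2, which is exactly what the shared reciprocal
# `denominator` computes.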


# TODO: make inplace and update grad formulas
class Reciprocal(Function):

    @staticmethod
    def forward(ctx, i):
        result = i.reciprocal()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_variables
        return grad_output * result.mul(result).neg_()
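
# d/dx (1/x) = -1/x^2 = -result^2, so saving the forward result avoids
# recomputing the reciprocal in backward.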


class Cmax(Function):

    @staticmethod
    def forward(ctx, a, b):
        ctx._a_size = a.size()
        ctx._b_size = b.size()
        ctx._mask = a.gt(b)
        return a.max(b)

    @staticmethod
    def backward(ctx, grad_output):
        mask = Variable(ctx._mask.type_as(grad_output.data))
        return (
            maybe_unexpand(grad_output * mask, ctx._a_size),
            maybe_unexpand_or_view(grad_output * Variable(ctx._mask.eq(0).type_as(grad_output.data)), ctx._b_size)
        )
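
# The mask a.gt(b) routes the gradient to `a` where it was strictly larger and
# to `b` everywhere else, so ties send their gradient to `b`. maybe_unexpand /
# maybe_unexpand_or_view reduce the gradient back to the original shapes when
# broadcasting expanded the inputs in forward.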


class CmaxConstant(Function):

    @staticmethod
    def forward(ctx, i, constant):
        ctx._mask = i.gt(constant)
        return i.clamp(min=constant)

    @staticmethod
    def backward(ctx, grad_output):
        mask = Variable(ctx._mask.type_as(grad_output.data))
        return grad_output * mask, None


class Cmin(Function):

    @staticmethod
    def forward(ctx, a, b):
        ctx._a_size = a.size()
        ctx._b_size = b.size()
        ctx._mask = a.lt(b)
        return a.min(b)

    @staticmethod
    def backward(ctx, grad_output):
        mask = Variable(ctx._mask.type_as(grad_output.data))
        return (
            maybe_unexpand(grad_output * mask, ctx._a_size),
            maybe_unexpand_or_view(grad_output * Variable(ctx._mask.eq(0).type_as(grad_output.data)), ctx._b_size)
        )


class CminConstant(Function):

    @staticmethod
    def forward(ctx, i, constant):
        ctx._mask = i.lt(constant)
        return i.clamp(max=constant)

    @staticmethod
    def backward(ctx, grad_output):
        mask = Variable(ctx._mask.type_as(grad_output.data))
        return grad_output * mask, None


class _ConstantGrad(Function):
    grad_value = 0

    @classmethod
    def forward(cls, ctx, *args):
        ctx._num_args = len(args)
        ctx._args0_size = args[0].size()
        return getattr(args[0], cls.__name__.lower())(*args[1:])

    @classmethod
    def backward(cls, ctx, grad_output):
        return (maybe_unexpand(grad_output.mul(cls.grad_value), ctx._args0_size),) + (ctx._num_args - 1) * (None,)
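
# _ConstantGrad covers ops whose derivative with respect to the first input is
# a constant almost everywhere: forward dispatches by lowercased class name
# (e.g. Floor -> args[0].floor()), and backward scales grad_output by
# grad_value and returns None for every remaining argument, so Fmod and
# Remainder produce no gradient for their divisor.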


class Floor(_ConstantGrad):
    pass


class Ceil(_ConstantGrad):
    pass


class Round(_ConstantGrad):
    pass


class Sign(_ConstantGrad):
    pass


class Trunc(_ConstantGrad):
    pass


class Frac(_ConstantGrad):
    grad_value = 1


class Fmod(_ConstantGrad):
    grad_value = 1


class Remainder(_ConstantGrad):
    grad_value = 1


class Lerp(Function):

    @staticmethod
    def forward(ctx, a, b, weight):
        ctx._a_size = a.size()
        ctx._b_size = b.size()
        ctx._weight = float(weight)
        return a.lerp(b, ctx._weight)

    @staticmethod
    def backward(ctx, grad_output):
        return (maybe_unexpand(grad_output.mul(1 - ctx._weight), ctx._a_size),
                maybe_unexpand_or_view(grad_output.mul(ctx._weight), ctx._b_size), None)
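
# lerp(a, b, w) = a + w * (b - a), so d/da = 1 - w and d/db = w; the weight is
# treated as a constant and gets None.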


class Rsqrt(InplaceFunction):

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
            ctx.mark_dirty(i)
            result = i.rsqrt_()
        else:
            result = i.rsqrt()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_variables
        return result.pow(3).div_(-2).mul(grad_output), None
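
# d/dx x^(-1/2) = -x^(-3/2) / 2 = -result^3 / 2, which is why backward cubes
# the saved result and divides by -2.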


class Addcmul(InplaceFunction):

    @staticmethod
    def forward(ctx, add_tensor, mul_tensor1, mul_tensor2, scale=1.0, inplace=False):
        ctx._scale = scale
        ctx._add_tensor_size = add_tensor.size()
        ctx.save_for_backward(mul_tensor1, mul_tensor2)
        if inplace:
            ctx.mark_dirty(add_tensor)
            return add_tensor.addcmul_(scale, mul_tensor1, mul_tensor2)
        else:
            return add_tensor.addcmul(scale, mul_tensor1, mul_tensor2)

    @staticmethod
    def backward(ctx, grad_output):
        grad_add = grad_mul1 = grad_mul2 = None
        mul_tensor1, mul_tensor2 = ctx.saved_variables
        if ctx.needs_input_grad[0]:
            grad_add = maybe_unexpand(grad_output, ctx._add_tensor_size)
        if ctx.needs_input_grad[1]:
            grad_mul1 = maybe_unexpand_or_view(grad_output.mul(mul_tensor2).mul_(ctx._scale), mul_tensor1.size())
        if ctx.needs_input_grad[2]:
            grad_mul2 = maybe_unexpand_or_view(grad_output.mul(mul_tensor1).mul_(ctx._scale), mul_tensor2.size())
        return grad_add, grad_mul1, grad_mul2, None, None
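
# addcmul computes add_tensor + scale * mul_tensor1 * mul_tensor2, so by the
# product rule each mul_tensor receives grad_output scaled by the *other*
# factor and by scale; the add_tensor gradient is grad_output itself.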


class Addcdiv(InplaceFunction):

    @staticmethod
    def forward(ctx, add_tensor, div_tensor1, div_tensor2, scale=1.0, inplace=False):
        ctx._scale = scale
        ctx._add_tensor_size = add_tensor.size()
        ctx.save_for_backward(div_tensor1, div_tensor2)
        if inplace:
            ctx.mark_dirty(add_tensor)
            return add_tensor.addcdiv_(ctx._scale, div_tensor1, div_tensor2)
        else:
            return add_tensor.addcdiv(ctx._scale, div_tensor1, div_tensor2)

    @staticmethod
    def backward(ctx, grad_output):
        grad_add = grad_div1 = grad_div2 = None
        div_tensor1, div_tensor2 = ctx.saved_variables
        if ctx.needs_input_grad[0]:
            grad_add = maybe_unexpand(grad_output, ctx._add_tensor_size)
        if ctx.needs_input_grad[1]:
            grad_div1 = maybe_unexpand_or_view(grad_output.div(div_tensor2).mul_(ctx._scale), div_tensor1.size())
        if ctx.needs_input_grad[2]:
            div_tensor2_sq = div_tensor2.mul(div_tensor2)
            grad_div2 = maybe_unexpand_or_view(grad_output.mul(div_tensor1).div(div_tensor2_sq).mul(-ctx._scale),
                                               div_tensor2.size())
        return grad_add, grad_div1, grad_div2, None, None
# TODO: atan2 + inplace
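
# Usage sketch (illustrative, assuming the legacy 0.2-era Variable API; this
# module's relative imports mean it must be loaded as part of
# torch.autograd._functions rather than run directly):
#
#     from torch.autograd import Variable
#     x = Variable(torch.randn(3), requires_grad=True)
#     y = Tanh.apply(x)          # Function subclasses are invoked via .apply()
#     y.sum().backward()
#     # x.grad now holds grad_output * (1 - tanh(x) ** 2)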