# blob: 98cdbfa1d5bdc633db6e018cbb12140f42c18539 [file] [log] [blame]
from itertools import repeat
from ..._thnn import type2backend
from ..function import Function, InplaceFunction
class Exp(InplaceFunction):
    """Exponential of the input, optionally computed in place.

    The forward result is saved because ``backward`` reuses it as the
    derivative factor.
    """

    def forward(self, i):
        """Return ``i.exp()``, or ``i.exp_()`` (marking ``i`` dirty) when in-place."""
        if not self.inplace:
            output = i.exp()
        else:
            self.mark_dirty(i)
            output = i.exp_()
        self.save_for_backward(output)
        return output

    def backward(self, grad_output):
        """Multiply the saved forward result by ``grad_output``."""
        output = self.saved_tensors[0]
        return output * grad_output
class Log(Function):
    """Logarithm via ``i.log()``; the input is saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.log()``."""
        self.save_for_backward(i)
        return i.log()

    def backward(self, grad_output):
        """Divide ``grad_output`` by the saved input."""
        saved_input = self.saved_tensors[0]
        return grad_output.div(saved_input)
class Log1p(Function):
    """``i.log1p()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.log1p()``."""
        self.save_for_backward(i)
        return i.log1p()

    def backward(self, grad_output):
        """Divide ``grad_output`` by ``saved_input + 1``."""
        saved_input = self.saved_tensors[0]
        denominator = saved_input.add(1)
        return grad_output.div(denominator)
class Tanh(InplaceFunction):
    """Hyperbolic tangent, optionally in place; the output is saved for backward."""

    def forward(self, i):
        """Return ``i.tanh()``, or ``i.tanh_()`` (marking ``i`` dirty) when in-place."""
        if not self.inplace:
            output = i.tanh()
        else:
            self.mark_dirty(i)
            output = i.tanh_()
        self.save_for_backward(output)
        return output

    def backward(self, grad_output):
        """Fill a fresh tensor through the backend's ``Tanh_updateGradInput``."""
        output, = self.saved_tensors
        grad_input = grad_output.new()
        # The backend is looked up by the concrete type of the saved output.
        thnn = type2backend[type(output)]
        # The second argument is passed as None; only the saved output is supplied.
        thnn.Tanh_updateGradInput(
            thnn.library_state, None, grad_output, grad_input, output)
        return grad_input
class Sigmoid(InplaceFunction):
    """Sigmoid, optionally in place; the output is saved for backward."""

    def forward(self, i):
        """Return ``i.sigmoid()``, or ``i.sigmoid_()`` (marking ``i`` dirty) when in-place."""
        if not self.inplace:
            output = i.sigmoid()
        else:
            self.mark_dirty(i)
            output = i.sigmoid_()
        self.save_for_backward(output)
        return output

    def backward(self, grad_output):
        """Fill a fresh tensor through the backend's ``Sigmoid_updateGradInput``."""
        output, = self.saved_tensors
        grad_input = grad_output.new()
        # The backend is looked up by the concrete type of the saved output.
        thnn = type2backend[type(output)]
        # The second argument is passed as None; only the saved output is supplied.
        thnn.Sigmoid_updateGradInput(
            thnn.library_state, None, grad_output, grad_input, output)
        return grad_input
class Sinh(Function):
    """``i.sinh()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.sinh()``."""
        self.save_for_backward(i)
        return i.sinh()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by ``cosh`` of the saved input."""
        saved_input = self.saved_tensors[0]
        return grad_output.mul(saved_input.cosh())
class Cosh(Function):
    """``i.cosh()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.cosh()``."""
        self.save_for_backward(i)
        return i.cosh()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by ``sinh`` of the saved input."""
        saved_input = self.saved_tensors[0]
        return grad_output.mul(saved_input.sinh())
class Abs(Function):
    """``i.abs()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.abs()``."""
        self.save_for_backward(i)
        return i.abs()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by the sign of the saved input."""
        saved_input = self.saved_tensors[0]
        return grad_output.mul(saved_input.sign())
class Clamp(Function):
    """Clamp the input to ``[min_val, max_val]``.

    The backward pass lets the gradient through only where the saved input
    lies inside the closed interval.
    """

    def __init__(self, min_val, max_val):
        super(Clamp, self).__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, i):
        """Save ``i`` and return ``i.clamp(min_val, max_val)``."""
        self.save_for_backward(i)
        return i.clamp(self.min_val, self.max_val)

    def backward(self, grad_output):
        """Zero ``grad_output`` wherever the saved input was outside the bounds."""
        saved_input = self.saved_tensors[0]
        not_below = saved_input.ge(self.min_val)
        not_above = saved_input.le(self.max_val)
        # The product of the two comparison masks marks the in-range elements.
        in_range = not_below * not_above
        return grad_output.mul(in_range.type_as(grad_output))
class Sqrt(Function):
    """``i.sqrt()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.sqrt()``."""
        self.save_for_backward(i)
        return i.sqrt()

    def backward(self, grad_output):
        """Return ``grad_output * saved_input ** -0.5 / 2``."""
        saved_input = self.saved_tensors[0]
        inverse_root = saved_input.pow(-0.5)
        scaled = grad_output.mul(inverse_root)
        return scaled.div(2)
class Sin(Function):
    """``i.sin()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.sin()``."""
        self.save_for_backward(i)
        return i.sin()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by ``cos`` of the saved input."""
        saved_input = self.saved_tensors[0]
        return grad_output.mul(saved_input.cos())
class Cos(Function):
    """``i.cos()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.cos()``."""
        self.save_for_backward(i)
        return i.cos()

    def backward(self, grad_output):
        """Return the negated product of ``grad_output`` and ``sin`` of the saved input."""
        saved_input = self.saved_tensors[0]
        # ``mul`` allocates a new tensor, so negating it in place leaves
        # ``grad_output`` untouched.
        grad_input = grad_output.mul(saved_input.sin())
        return grad_input.neg_()
class Tan(Function):
    """``i.tan()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.tan()``."""
        self.save_for_backward(i)
        return i.tan()

    def backward(self, grad_output):
        """Divide ``grad_output`` by the squared cosine of the saved input."""
        saved_input = self.saved_tensors[0]
        cos_squared = saved_input.cos().pow(2)
        return grad_output.div(cos_squared)
class Asin(Function):
    """``i.asin()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.asin()``."""
        self.save_for_backward(i)
        return i.asin()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by ``1 / sqrt(1 - i * i)`` of the saved input."""
        saved_input = self.saved_tensors[0]
        # ``1 - i * i`` is a fresh temporary, so the in-place sqrt/reciprocal
        # never touch the saved input.
        factor = (1 - saved_input.mul(saved_input)).sqrt_().reciprocal_()
        return grad_output.mul(factor)
class Acos(Function):
    """``i.acos()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.acos()``."""
        self.save_for_backward(i)
        return i.acos()

    def backward(self, grad_output):
        """Return ``-grad_output / sqrt(1 - i * i)`` for the saved input ``i``."""
        saved_input = self.saved_tensors[0]
        # All in-place calls operate on temporaries created by ``1 - ...``
        # and ``mul``; the saved input and ``grad_output`` are not modified.
        factor = (1 - saved_input.mul(saved_input)).sqrt_().reciprocal_()
        grad_input = grad_output.mul(factor)
        return grad_input.neg_()
class Atan(Function):
    """``i.atan()`` with the input saved for the backward pass."""

    def forward(self, i):
        """Save ``i`` and return ``i.atan()``."""
        self.save_for_backward(i)
        return i.atan()

    def backward(self, grad_output):
        """Multiply ``grad_output`` by ``1 / (i * i + 1)`` of the saved input."""
        saved_input = self.saved_tensors[0]
        # ``mul`` yields a new tensor; the in-place add/reciprocal reuse it.
        factor = saved_input.mul(saved_input).add_(1).reciprocal_()
        return grad_output.mul(factor)
class Reciprocal(Function):
    """``i.reciprocal()``; the output (not the input) is saved for backward."""

    def forward(self, i):
        """Compute ``i.reciprocal()``, save it, and return it."""
        output = i.reciprocal()
        self.save_for_backward(output)
        return output

    def backward(self, grad_output):
        """Multiply ``grad_output`` by the negated square of the saved output."""
        output = self.saved_tensors[0]
        # ``mul`` allocates, so the in-place negation does not alter ``output``.
        negative_square = output.mul(output).neg_()
        return grad_output.mul(negative_square)
class Cmax(Function):
    """Maximum of two tensors via ``a.max(b)``.

    ``forward`` stores the ``a > b`` mask on the instance; ``backward`` sends
    the gradient to ``a`` where the mask is set and to ``b`` elsewhere.
    """

    def forward(self, a, b):
        """Record where ``a`` is strictly greater than ``b`` and return ``a.max(b)``."""
        self._max_buffer = a.gt(b).type_as(a)
        return a.max(b)

    def backward(self, grad_output):
        """Return the ``(grad_a, grad_b)`` pair selected by the stored mask."""
        a_mask = self._max_buffer
        b_mask = a_mask.eq(0).type_as(grad_output)
        grad_a = grad_output * a_mask
        grad_b = grad_output * b_mask
        return grad_a, grad_b
class CmaxConstant(Function):
    """Maximum of a tensor and a constant, via ``i.clamp(min=constant)``."""

    def __init__(self, constant):
        super(CmaxConstant, self).__init__()
        self.constant = constant

    def forward(self, i):
        """Record where ``i`` is strictly greater than the constant and clamp from below."""
        lower_bound = self.constant
        self._max_buffer = i.gt(lower_bound).type_as(i)
        return i.clamp(min=lower_bound)

    def backward(self, grad_output):
        """Multiply ``grad_output`` by the mask stored in ``forward``."""
        return grad_output.mul(self._max_buffer)
class Cmin(Function):
    """Minimum of two tensors via ``a.min(b)``.

    ``forward`` stores the ``a < b`` mask on the instance; ``backward`` sends
    the gradient to ``a`` where the mask is set and to ``b`` elsewhere.
    """

    def forward(self, a, b):
        """Record where ``a`` is strictly less than ``b`` and return ``a.min(b)``."""
        self._min_buffer = a.lt(b).type_as(a)
        return a.min(b)

    def backward(self, grad_output):
        """Return the ``(grad_a, grad_b)`` pair selected by the stored mask."""
        a_mask = self._min_buffer
        b_mask = a_mask.eq(0).type_as(grad_output)
        grad_a = grad_output * a_mask
        grad_b = grad_output * b_mask
        return grad_a, grad_b
class CminConstant(Function):
    """Minimum of a tensor and a constant, via ``i.clamp(max=constant)``."""

    def __init__(self, constant):
        super(CminConstant, self).__init__()
        self.constant = constant

    def forward(self, i):
        """Record where ``i`` is strictly less than the constant and clamp from above."""
        upper_bound = self.constant
        self._min_buffer = i.lt(upper_bound).type_as(i)
        return i.clamp(max=upper_bound)

    def backward(self, grad_output):
        """Multiply ``grad_output`` by the mask stored in ``forward``."""
        return grad_output.mul(self._min_buffer)
class _ConstantGrad(Function):
grad_value = 0
def __init__(self, *args):
super(_ConstantGrad, self).__init__()
self.args = args
def forward(self, i):
return getattr(i, type(self).__name__.lower())(*self.args)
def backward(self, grad_output):
grad_input = grad_output.new(*repeat(1, grad_output.dim()))
grad_input = grad_input.fill_(self.grad_value).expand_as(grad_output)
return grad_input.mul(grad_output)
class Floor(_ConstantGrad):
    """Calls ``i.floor(*args)`` in forward; inherits ``grad_value`` of 0."""
class Ceil(_ConstantGrad):
    """Calls ``i.ceil(*args)`` in forward; inherits ``grad_value`` of 0."""
class Round(_ConstantGrad):
    """Calls ``i.round(*args)`` in forward; inherits ``grad_value`` of 0."""
class Sign(_ConstantGrad):
    """Calls ``i.sign(*args)`` in forward; inherits ``grad_value`` of 0."""
class Trunc(_ConstantGrad):
    """Calls ``i.trunc(*args)`` in forward; inherits ``grad_value`` of 0."""
class Frac(_ConstantGrad):
    """Calls ``i.frac(*args)`` in forward; the gradient is scaled by 1."""

    grad_value = 1
class Fmod(_ConstantGrad):
    """Calls ``i.fmod(*args)`` in forward; the gradient is scaled by 1."""

    grad_value = 1
class Remainder(_ConstantGrad):
    """Calls ``i.remainder(*args)`` in forward; the gradient is scaled by 1."""

    grad_value = 1
class Lerp(Function):
    """Interpolation between two tensors with a fixed scalar weight.

    ``forward`` delegates to ``a.lerp(b, weight)``; ``backward`` splits the
    gradient as ``1 - weight`` for ``a`` and ``weight`` for ``b``.
    """

    def __init__(self, weight):
        super(Lerp, self).__init__()
        # Stored as a float regardless of the numeric type passed in.
        self.weight = float(weight)

    def forward(self, a, b):
        """Return ``a.lerp(b, self.weight)``."""
        return a.lerp(b, self.weight)

    def backward(self, grad_output):
        """Return the ``(grad_a, grad_b)`` pair."""
        w = self.weight
        grad_a = grad_output.mul(1 - w)
        grad_b = grad_output.mul(w)
        return grad_a, grad_b
class Rsqrt(InplaceFunction):
    """Reciprocal square root, optionally in place; the output is saved for backward."""

    def forward(self, input):
        """Return ``input.rsqrt()``, or ``input.rsqrt_()`` (marking it dirty) when in-place."""
        if not self.inplace:
            output = input.rsqrt()
        else:
            self.mark_dirty(input)
            output = input.rsqrt_()
        self.save_for_backward(output)
        return output

    def backward(self, grad_output):
        """Return ``output ** 3 / -2 * grad_output`` for the saved output."""
        output = self.saved_tensors[0]
        # ``pow`` allocates the working tensor; the remaining steps reuse it
        # in place, so neither ``output`` nor ``grad_output`` is modified.
        grad_input = output.pow(3)
        grad_input.div_(-2)
        grad_input.mul_(grad_output)
        return grad_input
class Addcmul(InplaceFunction):
    """``add_tensor + scale * mul_tensor1 * mul_tensor2``, optionally in place.

    Args:
        scale: scalar multiplier applied to the product term.
        inplace: when true, ``add_tensor`` is updated with ``addcmul_``.
    """

    def __init__(self, scale=1, inplace=False):
        super(Addcmul, self).__init__(inplace)
        self.scale = scale

    def forward(self, add_tensor, mul_tensor1, mul_tensor2):
        """Save both factors and return the ``addcmul`` result."""
        self.save_for_backward(mul_tensor1, mul_tensor2)
        if self.inplace:
            # ``addcmul_`` overwrites ``add_tensor``; report the mutation the
            # same way the other in-place functions in this file do.
            self.mark_dirty(add_tensor)
            return add_tensor.addcmul_(self.scale, mul_tensor1, mul_tensor2)
        else:
            return add_tensor.addcmul(self.scale, mul_tensor1, mul_tensor2)

    def backward(self, grad_output):
        """Return gradients for ``(add_tensor, mul_tensor1, mul_tensor2)``.

        Entries whose ``needs_input_grad`` flag is false are left as ``None``.
        """
        grad_add = grad_mul1 = grad_mul2 = None
        mul_tensor1, mul_tensor2 = self.saved_tensors

        if self.needs_input_grad[0]:
            # The additive term passes the gradient through unchanged.
            grad_add = grad_output

        if self.needs_input_grad[1]:
            grad_mul1 = grad_output.mul(mul_tensor2).mul(self.scale)

        if self.needs_input_grad[2]:
            grad_mul2 = grad_output.mul(mul_tensor1).mul(self.scale)

        return grad_add, grad_mul1, grad_mul2
class Addcdiv(InplaceFunction):
    """``add_tensor + scale * div_tensor1 / div_tensor2``, optionally in place.

    Args:
        scale: scalar multiplier applied to the quotient term.
        inplace: when true, ``add_tensor`` is updated with ``addcdiv_``.
    """

    def __init__(self, scale=1, inplace=False):
        super(Addcdiv, self).__init__(inplace)
        self.scale = scale

    def forward(self, add_tensor, div_tensor1, div_tensor2):
        """Save numerator and denominator and return the ``addcdiv`` result."""
        self.save_for_backward(div_tensor1, div_tensor2)
        if self.inplace:
            # ``addcdiv_`` overwrites ``add_tensor``; report the mutation the
            # same way the other in-place functions in this file do.
            self.mark_dirty(add_tensor)
            return add_tensor.addcdiv_(self.scale, div_tensor1, div_tensor2)
        else:
            return add_tensor.addcdiv(self.scale, div_tensor1, div_tensor2)

    def backward(self, grad_output):
        """Return gradients for ``(add_tensor, div_tensor1, div_tensor2)``.

        Entries whose ``needs_input_grad`` flag is false are left as ``None``.
        """
        grad_add = grad_div1 = grad_div2 = None
        div_tensor1, div_tensor2 = self.saved_tensors

        if self.needs_input_grad[0]:
            # The additive term passes the gradient through unchanged.
            grad_add = grad_output

        if self.needs_input_grad[1]:
            grad_div1 = grad_output.div(div_tensor2).mul(self.scale)

        if self.needs_input_grad[2]:
            # -scale * grad * numerator / denominator ** 2; ``mul`` allocates
            # the tensor that the following in-place steps then reuse.
            div_tensor2_sq = div_tensor2.mul(div_tensor2)
            grad_div2 = grad_output.mul(div_tensor1).div_(div_tensor2_sq)
            grad_div2.neg_().mul_(self.scale)

        return grad_add, grad_div1, grad_div2
# TODO: atan2 + inplace