from functools import reduce

from ..function import Function
from ..variable import Variable


class Sum(Function):

    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False):
        ctx.dim = dim
        ctx.keepdim = keepdim
        ctx.input_size = input.size()
        if dim is None:
            return input.new((input.sum(),))
        else:
            return input.sum(dim, keepdim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.dim is None:
            return grad_output.expand(ctx.input_size), None, None
        else:
            if ctx.keepdim is False and len(ctx.input_size) != 1:
                # Re-insert the reduced dimension before broadcasting the
                # gradient back to the input shape.
                grad_output = grad_output.unsqueeze(ctx.dim)

            repeats = [1 for _ in ctx.input_size]
            repeats[ctx.dim] = ctx.input_size[ctx.dim]
            return grad_output.repeat(*repeats), None, None
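# A minimal sketch of how this Function is exercised (assuming the Variable
# API of this source tree; the values below are purely illustrative):
#
#     x = Variable(torch.ones(2, 3), requires_grad=True)
#     y = Sum.apply(x, 1, False)      # shape (2,), each entry equal to 3
#     y.backward(torch.ones(2))
#     # x.grad is a (2, 3) tensor of ones: every input element contributes
#     # exactly once to its row sum.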


class Prod(Function):

    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False):
        ctx.dim = dim
        ctx.keepdim = keepdim
        ctx.input_size = input.size()
        if dim is None:
            ctx.result = input.prod()
            ctx.save_for_backward(input)
            return input.new((ctx.result,))
        else:
            output = input.prod(dim, keepdim)
            ctx.save_for_backward(input, output)
            return output

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of prod(x) w.r.t. x_i is prod(x) / x_i, which is only
        # safe when x_i != 0, so zeros are handled separately: with more than
        # one zero in a slice every gradient entry of that slice is zero, and
        # with exactly one zero only that entry gets a non-zero gradient equal
        # to the product of the remaining entries.
        if ctx.dim is None:
            input, = ctx.saved_variables
            zero_idx = (input.data == 0).nonzero()
            if zero_idx.dim() == 0:
                return grad_output.mul(ctx.result).expand_as(input).div(input), None, None
            elif zero_idx.size(0) > 1:
                return (grad_output * 0).expand_as(input), None, None
            else:
                grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())
                zero_idx = tuple(zero_idx[0].cpu())
                to_add = input.data.new(ctx.input_size).zero_()
                to_add[zero_idx] = 1.
                grad_input[zero_idx] = grad_output * (input + Variable(to_add)).prod()
                return grad_input, None, None
        else:
            input, output = ctx.saved_variables
            dim = ctx.dim if ctx.dim >= 0 else ctx.dim + input.dim()
            if ctx.keepdim is False and len(ctx.input_size) != 1:
                grad_output = grad_output.unsqueeze(dim)
                output = output.unsqueeze(dim)

            zero_mask = input == 0
            slice_zero_count = zero_mask.sum(dim, True)
            total_zeros = slice_zero_count.data.sum()
            grad_input = grad_output.mul(output).expand_as(input).div(input)
            if total_zeros == 0:
                return grad_input, None, None

            some_zeros = slice_zero_count.gt(0).expand_as(grad_input)
            grad_input[some_zeros] = 0

            single_zero_idx = slice_zero_count.data.eq(1).nonzero()
            if len(single_zero_idx) == 0:
                return grad_input, None, None

            for idx in single_zero_idx:
                idx_tuple = tuple(idx.cpu())
                input_idx_tuple = idx_tuple[:dim] + (slice(0, None),) + idx_tuple[dim + 1:]

                # slice_mask and input_copy are 1D
                slice_mask = zero_mask[input_idx_tuple]
                input_copy = input[input_idx_tuple].clone()
                zero_idx = slice_mask.data.nonzero()[0, 0]
                input_copy[zero_idx] = 1.

                grad_idx_tuple = idx_tuple[:dim] + (zero_idx,) + idx_tuple[dim + 1:]
                grad_input[grad_idx_tuple] = grad_output[idx_tuple] * input_copy.prod()

            return grad_input, None, None
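# A minimal sketch of the zero handling above (assuming the Variable API of
# this source tree; the values below are purely illustrative):
#
#     x = Variable(torch.Tensor([2., 0., 3.]), requires_grad=True)
#     Prod.apply(x).backward()
#     # x.grad is [0., 6., 0.]: only the single zero entry receives the
#     # product of the remaining entries (2 * 3); all other entries get 0.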


class Mean(Function):

    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False):
        ctx.dim = dim
        ctx.keepdim = keepdim
        ctx.input_size = input.size()
        if dim is None:
            return input.new((input.mean(),))
        else:
            return input.mean(dim, keepdim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.dim is None:
            # Every element contributes 1 / numel(input) to the overall mean.
            grad_input_val = grad_output / reduce(lambda x, y: x * y, ctx.input_size, 1)
            return grad_input_val.expand(ctx.input_size), None, None
        else:
            if ctx.keepdim is False and len(ctx.input_size) != 1:
                grad_output = grad_output.unsqueeze(ctx.dim)

            repeats = [1 for _ in ctx.input_size]
            dim_size = ctx.input_size[ctx.dim]
            repeats[ctx.dim] = dim_size
            return grad_output.repeat(*repeats).div_(dim_size), None, None
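# A minimal numerical-check sketch (assuming torch.autograd.gradcheck is
# available in this source tree; names and shapes are illustrative only):
#
#     x = Variable(torch.randn(4, 5).double(), requires_grad=True)
#     gradcheck(lambda t: Mean.apply(t, 1, True), (x,))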


class _SelectionFunction(Function):
    has_all_reduce = True
    # additional_args is prepended before dim when calling the tensor
    # function. It's a no-op for subclasses other than kthvalue.
    # kthvalue not only requires us to pass a dim, but also to precede it with k.

    @classmethod
    def forward(cls, ctx, input, dim=None, keepdim=False, additional_args=tuple()):
        fn = getattr(input, cls.__name__.lower())
        ctx.dim = dim
        ctx.keepdim = keepdim
        ctx.additional_args = additional_args
        ctx.input_size = input.size()
        if ctx.dim is None and cls.has_all_reduce:
            value = fn(*additional_args)
            ctx.indices_tuple = tuple(input.eq(value).nonzero()[0])
            return input.new((value,))
        else:
            if ctx.dim is None:
                dim = input.dim() - 1
            else:
                dim = ctx.dim

            args = (dim, keepdim)
            if additional_args:
                args = additional_args + args
            output, indices = fn(*args)
            ctx.save_for_backward(indices)
            ctx.mark_non_differentiable(indices)
            return output, indices

    @classmethod
    def backward(cls, ctx, grad_output, grad_indices=None):
        # Selection reductions (max, min, mode, median, kthvalue) are only
        # differentiable w.r.t. the selected elements: the gradient is zero
        # everywhere except at the saved indices, where grad_output is
        # scattered back.
        grad_input = Variable(grad_output.data.new(*ctx.input_size).zero_())
        if ctx.dim is None and cls.has_all_reduce:
            grad_input[ctx.indices_tuple] = grad_output
        else:
            if ctx.dim is None:
                dim = len(ctx.input_size) - 1
            else:
                dim = ctx.dim

            indices, = ctx.saved_variables
            if ctx.keepdim is False and len(ctx.input_size) != 1:
                grad_output = grad_output.unsqueeze(dim)
                grad_indices = grad_indices.unsqueeze(dim)
                indices = indices.unsqueeze(dim)
            grad_input.scatter_(dim, indices, grad_output)
        return grad_input, None, None, None


class Max(_SelectionFunction):
    pass


class Min(_SelectionFunction):
    pass


class Mode(_SelectionFunction):
    has_all_reduce = False


class Median(_SelectionFunction):
    has_all_reduce = False


class Kthvalue(_SelectionFunction):
    has_all_reduce = False

    @classmethod
    def forward(cls, ctx, input, k, dim=None, keepdim=False):
        return super(Kthvalue, cls).forward(ctx, input, dim, keepdim, (k,))
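# A minimal sketch of the selection reductions above (assuming the Variable
# API of this source tree; values and shapes are illustrative only):
#
#     x = Variable(torch.randn(3, 5), requires_grad=True)
#     values, indices = Max.apply(x, 1, False)       # per-row max and argmax
#     values.sum().backward()
#     # x.grad is zero except at each row's argmax position, where the
#     # incoming gradient (here 1) is scattered back.
#
#     kth, kth_idx = Kthvalue.apply(x, 2, 1, False)  # 2nd smallest per row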


class Norm(Function):

    @staticmethod
    def forward(ctx, input, p=2, dim=None, keepdim=False):
        ctx.p = p
        ctx.dim = dim
        ctx.keepdim = keepdim

        if dim is None:
            ctx.norm = input.norm(p)
            ctx.save_for_backward(input)
            return input.new((ctx.norm,))
        else:
            output = input.norm(p, dim, keepdim)
            ctx.save_for_backward(input, output)
            return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.dim is None:
            input, = ctx.saved_variables
            if ctx.p == 2:
                scale_v = (grad_output / ctx.norm).expand_as(input)
                return input.mul(scale_v), None, None, None
            else:
                pow = input.abs().pow(ctx.p - 2)
                scale_v = (grad_output / ctx.norm ** (ctx.p - 1)).expand_as(input)
                return input.mul(pow).mul(scale_v), None, None, None
        else:
            input, output = ctx.saved_variables
            if ctx.keepdim is False and input.dim() != 1:
                grad_output = grad_output.unsqueeze(ctx.dim)
                output = output.unsqueeze(ctx.dim)

            big_grad_output = grad_output.expand_as(input)
            if ctx.p == 2:
                big_output = output.expand_as(input)
                return input.mul(big_grad_output).div(big_output), None, None, None
            else:
                pow = input.abs().pow(ctx.p - 2)
                big_output = output.pow(ctx.p - 1).expand_as(input)
                return input.mul(pow).mul(big_grad_output).div(big_output), None, None, None
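# For reference, the backward above implements the derivative of the p-norm:
#
#     d||x||_p / dx_i = x_i * |x_i|^(p - 2) / ||x||_p^(p - 1)
#
# which for p = 2 reduces to x_i / ||x||_2, matching the two branches.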
# TODO: renorm
# TODO: std
# TODO: var