# blob: ff9392b683b99983add2ac92acb355e37e9a08c3 [file] [log] [blame]
import torch
from ..function import Function
from .utils import maybe_unexpand, maybe_unexpand_or_view
# TODO: once Cpp-style functions are implemented we can detach a and b
# before calling forward.
class _CompareOp(Function):
    """Shared implementation for the comparison functions in this module.

    Subclasses provide a ``fn_name`` class attribute; ``forward`` looks up the
    method with that name on the first input and calls it with the second.
    """

    @classmethod
    def forward(cls, ctx, a, b):
        """Run the comparison named by ``cls.fn_name`` and return its result.

        Records on ``ctx`` the size and type of ``a``, whether ``b`` is a
        tensor, and (if it is) the size of ``b``, for later use in
        ``backward``. The returned mask is marked as non-differentiable.
        """
        ctx.a_size = a.size()
        ctx.b_tensor = torch.is_tensor(b)
        # ``b`` may be a non-tensor, in which case there is no size to keep.
        if ctx.b_tensor:
            ctx.b_size = b.size()
        else:
            ctx.b_size = None
        ctx.input_type = type(a)

        compare = getattr(a, cls.fn_name)
        result = compare(b)
        ctx.mark_non_differentiable(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``a`` and ``b`` built from ``grad_output * 0``.

        The zeroed gradient is converted to the input type saved in
        ``forward``. The second element is ``None`` when ``b`` was not a
        tensor.
        """
        zeroed = grad_output * 0
        zeroed = zeroed.type(ctx.input_type)

        # NOTE(review): the helpers from ``.utils`` presumably shape the
        # gradient to the recorded input sizes -- confirm against their
        # definitions.
        grad_a = maybe_unexpand(zeroed, ctx.a_size)
        if ctx.b_tensor:
            grad_b = maybe_unexpand_or_view(zeroed, ctx.b_size)
        else:
            grad_b = None
        return grad_a, grad_b
class Eq(_CompareOp):
    """Comparison op that calls the ``eq`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'eq'
class Ne(_CompareOp):
    """Comparison op that calls the ``ne`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'ne'
class Gt(_CompareOp):
    """Comparison op that calls the ``gt`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'gt'
class Ge(_CompareOp):
    """Comparison op that calls the ``ge`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'ge'
class Lt(_CompareOp):
    """Comparison op that calls the ``lt`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'lt'
class Le(_CompareOp):
    """Comparison op that calls the ``le`` method of its first input."""

    # Method name looked up on the first input by ``_CompareOp.forward``.
    fn_name = 'le'