blob: ffc0d124f99d2033a3700855fba2f2c5965da8b9 [file] [log] [blame]
import torch
from ..function import Function
class _CompareOp(Function):
def __init__(self, scalar=None):
super(_CompareOp, self).__init__()
self.scalar = scalar
def forward(self, tensor1, tensor2=None):
other = tensor2 if tensor2 is not None else self.scalar
mask = getattr(tensor1, self.fn_name)(other)
self.mark_non_differentiable(mask)
return mask
class Eq(_CompareOp):
fn_name = 'eq'
class Ne(_CompareOp):
fn_name = 'ne'
class Gt(_CompareOp):
fn_name = 'gt'
class Ge(_CompareOp):
fn_name = 'ge'
class Lt(_CompareOp):
fn_name = 'lt'
class Le(_CompareOp):
fn_name = 'le'