# blob: b92d71ce6694fbc7aecf70c6ff3e5f04b0d5fd64 [file] [log] [blame]
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
(run_tests)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, skipCPUIfNoLapack, skipCUDAIfNoMagma, onlyCPU)
from collections.abc import Sequence
# Information for generating an alias test
# NOTE: ending the alias_name with an underscore will interpret the test
# as the test for an inplace method of that name
class AliasInfo:
    """Description of one alias/original operator pair used to generate tests.

    An ``alias_name`` that ends with an underscore marks the entry as the
    test for the inplace method of that name.

    Args:
        alias_name: the name of the alias
        alias_op: the aliased op
        original_name: the name of the original function
        original_op: the original op
        get_input: callable (device)->tensor that returns the first tensor
            argument
        get_args: keyword-only; callable (device)->tuple that returns
            additional positional arguments (defaults to no extra arguments)
        decorators: keyword-only; decorators to apply to the test
    """

    __slots__ = [
        'alias_name',
        'alias_op',
        'original_name',
        'original_op',
        'get_input',
        'get_args',
        'decorators',
    ]

    def __init__(self,
                 alias_name,
                 alias_op,
                 original_name,
                 original_op,
                 get_input,
                 *,
                 get_args=lambda d: (),
                 decorators=()):
        # The alias side of the pair.
        self.alias_name = alias_name
        self.alias_op = alias_op

        # The original side of the pair.
        self.original_name = original_name
        self.original_op = original_op

        # Input factories and test decoration.
        self.get_input = get_input
        self.get_args = get_args
        self.decorators = decorators
# Table of alias/original operator pairs; every entry is turned into tests by
# create_alias_tests. Inplace variants are listed under names ending in '_'.
alias_infos = (
    AliasInfo('linalg_det', torch.linalg.det, 'det', torch.det,
              lambda d: torch.randn(10, 10, device=d),
              decorators=(skipCPUIfNoLapack, skipCUDAIfNoMagma)),
    # NOTE: only runs on CPU because it leaks CUDA memory
    #   (see https://github.com/pytorch/pytorch/issues/43119)
    AliasInfo('ger', torch.ger, 'outer', torch.outer,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('subtract', torch.subtract, 'sub', torch.sub,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('subtract_', torch.Tensor.subtract_, 'sub_', torch.Tensor.sub_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('greater_equal', torch.greater_equal, 'ge', torch.ge,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('greater_equal_', torch.Tensor.greater_equal_, 'ge_', torch.Tensor.ge_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('greater', torch.greater, 'gt', torch.gt,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('greater_', torch.Tensor.greater_, 'gt_', torch.Tensor.gt_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('less_equal', torch.less_equal, 'le', torch.le,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    # The original op must be Tensor.le_ (not the alias itself), otherwise the
    # computation test would compare the alias against itself.
    AliasInfo('less_equal_', torch.Tensor.less_equal_, 'le_', torch.Tensor.le_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('less', torch.less, 'lt', torch.lt,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('less_', torch.Tensor.less_, 'lt_', torch.Tensor.lt_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('not_equal', torch.not_equal, 'ne', torch.ne,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('not_equal_', torch.Tensor.not_equal_, 'ne_', torch.Tensor.ne_,
              lambda d: torch.randn(20, device=d),
              get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    # NOTE: only runs on CPU because it leaks CUDA memory
    #   (see https://github.com/pytorch/pytorch/issues/43119)
    AliasInfo('divide', torch.divide, 'div', torch.div,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
              decorators=(onlyCPU,)),
    AliasInfo('divide_', torch.Tensor.divide_, 'div_', torch.Tensor.div_,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
              decorators=(onlyCPU,)),
    # NOTE: only runs on CPU because it leaks CUDA memory
    #   (see https://github.com/pytorch/pytorch/issues/43119)
    AliasInfo('multiply', torch.multiply, 'mul', torch.mul,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('multiply_', torch.Tensor.multiply_, 'mul_', torch.Tensor.mul_,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d),),
              decorators=(onlyCPU,)),
    AliasInfo('true_divide', torch.true_divide, 'div', torch.div,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
              decorators=(onlyCPU,)),
    AliasInfo('true_divide_', torch.Tensor.true_divide_, 'div_', torch.Tensor.div_,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
              decorators=(onlyCPU,)),
    AliasInfo('swapdims', torch.swapdims, 'transpose', torch.transpose,
              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
    AliasInfo('swapdims_', torch.Tensor.swapdims_, 'transpose_', torch.Tensor.transpose_,
              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
    AliasInfo('swapaxes', torch.swapaxes, 'transpose', torch.transpose,
              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
    AliasInfo('swapaxes_', torch.Tensor.swapaxes_, 'transpose_', torch.Tensor.transpose_,
              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
    # The input here is a sequence of tensors rather than a single tensor.
    AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack,
              lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))),
    AliasInfo('moveaxis', torch.moveaxis, 'movedim', torch.movedim,
              lambda d: torch.randn(20, 3, 2, 1, device=d), get_args=lambda d: (3, 1)),
)
# Placeholder test class for validating that aliases are correctly
# translated when scripted and traced
class TestOpNormalization(JitTestCase):
pass
# Clones the input, which is either a single tensor or a sequence of tensors
def clone_inp(inp):
    """Return a clone of ``inp``.

    If ``inp`` is a ``Sequence``, every element is cloned with
    ``torch.clone`` and the clones are returned as a list; otherwise
    ``inp.clone()`` is returned.
    """
    if not isinstance(inp, Sequence):
        return inp.clone()
    return [torch.clone(element) for element in inp]
# Generates alias tests and adds them to the specified class (cls)
def create_alias_tests(cls):
    """Generate the alias tests and attach them to ``cls``.

    For every entry of ``alias_infos`` two methods are set on ``cls``:

    * ``test_jit_op_alias_normalization_<alias_name>`` checks that the
      scripted and traced graphs mention the original name and not the
      alias name.
    * ``test_alias_computation_<alias_name>`` checks that the alias and the
      original op yield equal results (and leave equal inputs behind).

    Each generated test is wrapped with the entry's ``decorators`` before it
    is attached.

    Args:
        cls: the test class that receives the generated test methods.
    """
    for info in alias_infos:

        # Tests that the JIT remaps aliases to their original ops.
        # `info=info` binds the current entry; without it every generated
        # test would see the last entry of the loop.
        def _test_jit_op_alias_normalization(self, device, info=info):
            # NOTE: these two locals look unused, but the generated script
            # text refers to `op`, and stringified tensor arguments
            # presumably render as `tensor(...)`; they are kept as locals so
            # the compiled script can resolve them (TODO confirm how
            # CompilationUnit resolves names).
            tensor = torch.tensor  # noqa: F841
            op = info.alias_op  # noqa: F841
            is_inplace = info.alias_name.endswith('_')

            # Checks that scripting converts aliases
            # NOTE: the code to test scripting must be generated since
            #   scripting does not support splatting args or directly
            #   calling torch.Tensor methods. The following
            #   splats args after the first tensor by inlining them as constants.
            if is_inplace:
                fn_template = '''
                    def _fn(t):
                        return t.{alias_name}({args})
                '''
                arg_string = ', '.join((str(arg) for arg in info.get_args(device)))
                script = fn_template.format(alias_name=info.alias_name, args=arg_string)
            else:
                is_input_tensor_list = isinstance(info.get_input(device), Sequence)
                # For sequence of Tensors, annotate the type to be List[Tensor]
                if is_input_tensor_list:
                    fn_template = '''
                        def _fn(t: List[Tensor]):
                            return op(t{args})
                    '''
                else:
                    fn_template = '''
                        def _fn(t):
                            return op(t{args})
                    '''
                arg_string = ", " + ', '.join((str(arg) for arg in info.get_args(device)))
                script = fn_template.format(args=arg_string)

            # Compiles script
            scripted = torch.jit.CompilationUnit(script)._fn

            # Acquires and checks the graph remaps the alias
            inp = info.get_input(device)
            scripted(clone_inp(inp))
            graph = scripted.graph_for(clone_inp(inp))
            FileCheck().check(info.original_name).check_not(info.alias_name).run(graph)

            # Checks that tracing converts aliases
            # NOTE: tracing has no problem splatting args
            args = info.get_args(device)

            def _fn(t, info=info, args=args):
                return info.alias_op(t, *args)

            traced = torch.jit.trace(_fn, (clone_inp(inp),))
            traced(clone_inp(inp))
            graph = traced.graph_for(clone_inp(inp))
            FileCheck().check(info.original_name).check_not(info.alias_name).run(graph)

        # Applies decorators
        for decorator in info.decorators:
            _test_jit_op_alias_normalization = decorator(_test_jit_op_alias_normalization)

        test_name = "test_jit_op_alias_normalization_" + info.alias_name
        setattr(cls, test_name, _test_jit_op_alias_normalization)

        # Tests that the alias functions perform the same operation as the original
        def _test_alias_computation(self, device, info=info):
            alias_op = info.alias_op
            original_op = info.original_op

            inp = info.get_input(device)
            args = info.get_args(device)

            # Each op gets its own clone so that inplace variants cannot
            # affect the other op's input.
            alias_input = clone_inp(inp)
            alias_result = alias_op(alias_input, *args)

            # The reference result must come from the original op; calling
            # the alias here would make the comparison pass trivially.
            original_input = clone_inp(inp)
            original_result = original_op(original_input, *args)

            self.assertEqual(alias_input, original_input, atol=0, rtol=0)
            self.assertEqual(alias_result, original_result, atol=0, rtol=0)

        # Applies decorators
        for decorator in info.decorators:
            _test_alias_computation = decorator(_test_alias_computation)

        test_name = "test_alias_computation_" + info.alias_name
        setattr(cls, test_name, _test_alias_computation)
create_alias_tests(TestOpNormalization)
instantiate_device_type_tests(TestOpNormalization, globals())
if __name__ == '__main__':
run_tests()