from functools import partial, wraps, reduce
import warnings
import torch
from torch.testing import \
(FileCheck, floating_and_complex_types_and)
from torch.testing._internal.common_utils import \
(TestCase, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor)
from torch.testing._internal.common_methods_invocations import \
(op_db)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.autograd.gradcheck import gradcheck, gradgradcheck
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, \
check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
# Tests that apply to all operators
class TestOpInfo(TestCase):
exact_dtype = True
# Verifies that ops have their unsupported dtypes
# registered correctly by testing that each claimed unsupported dtype
# throws a runtime error
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.unsupported)
def test_unsupported_dtypes(self, device, dtype, op):
        # sample_inputs may use a function to generate its inputs that itself fails for the specified dtype
# https://github.com/pytorch/pytorch/issues/49024
with self.assertRaises(RuntimeError):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
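            # A SampleInput (see common_methods_invocations) roughly bundles everything needed to call the op,
            #   e.g. (illustrative):
            #     sample.input  -> iterable of input tensors
            #     sample.args   -> extra positional args
            #     sample.kwargs -> keyword args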
op(*sample.input, *sample.args, **sample.kwargs)
# Verifies that ops have their supported dtypes
# registered correctly by testing that each claimed supported dtype
# does NOT throw a runtime error
@onlyOnCPUAndCUDA
@ops(op_db, dtypes=OpDTypes.supported)
def test_supported_dtypes(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
op(*sample.input, *sample.args, **sample.kwargs)
# gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=[torch.double, torch.cdouble])
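# e.g. "@_gradcheck_ops(op_db)" below runs a test over each op's supported dtypes,
#   restricted to torch.double and torch.cdouble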
class TestGradients(TestCase):
exact_dtype = True
# Copies inputs to inplace operations to avoid inplace modifications
# to leaves requiring gradient
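    # For illustration (hypothetical call): _get_safe_inplace(torch.Tensor.add_)(t, 1)
    #   computes t.clone().add_(1), so the leaf tensor t itself is never mutated.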
def _get_safe_inplace(self, inplace_variant):
@wraps(inplace_variant)
def _fn(t, *args, **kwargs):
return inplace_variant(t.clone(), *args, **kwargs)
return _fn
def _check_helper(self, device, dtype, op, variant, check):
if variant is None:
self.skipTest("Skipped! Variant not implemented.")
if not op.supports_dtype(dtype, torch.device(device).type):
self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
samples = op.sample_inputs(device, dtype, requires_grad=True)
for sample in samples:
if sample.output_process_fn_grad is not None:
out_fn = sample.output_process_fn_grad
def variant_out_fn(*args, **kwargs):
return out_fn(variant(*args, **kwargs))
else:
variant_out_fn = variant
def fn(*inputs):
output = variant_out_fn(*inputs, **sample.kwargs)
return op.output_func(output)
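            # NOTE: gradcheck/gradgradcheck below only perturb the positional inputs they receive;
            #   sample.kwargs are closed over by fn since they are not differentiated.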
if check == 'gradcheck':
self.assertTrue(gradcheck(fn, (*sample.input,) + sample.args,
check_batched_grad=op.check_batched_grad,
check_grad_dtypes=True))
elif check == 'gradgradcheck':
self.assertTrue(gradgradcheck(fn, (*sample.input,) + sample.args,
gen_non_contig_grad_outputs=False,
check_batched_grad=op.check_batched_gradgrad,
check_grad_dtypes=True))
self.assertTrue(gradgradcheck(fn, (*sample.input,) + sample.args,
gen_non_contig_grad_outputs=True,
check_batched_grad=op.check_batched_gradgrad,
check_grad_dtypes=True))
else:
self.assertTrue(False, msg="Unknown check requested!")
def _grad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradcheck')
def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
def _skip_helper(self, op, dtype):
if not op.supports_autograd:
self.skipTest("Skipped! autograd not supported")
if not op.test_complex_grad and dtype.is_complex:
self.skipTest("Skipped! complex grad tests marked to skip.")
# Tests that gradients are computed correctly
@_gradcheck_ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._skip_helper(op, dtype)
self._grad_test_helper(device, dtype, op, op.get_op())
    # Method grad (and gradgrad, see below) tests are disabled since they're
    # costly and redundant with function grad (and gradgrad) tests
# @_gradcheck_ops(op_db)
# def test_method_grad(self, device, dtype, op):
# self._skip_helper(op, dtype)
# self._grad_test_helper(device, dtype, op, op.get_method())
@_gradcheck_ops(op_db)
def test_inplace_grad(self, device, dtype, op):
self._skip_helper(op, dtype)
if not op.test_inplace_grad:
self.skipTest("Skipped! Inplace gradcheck marked to skip.")
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, dtype)
self._gradgrad_test_helper(device, dtype, op, op.get_op())
# Method gradgrad (and grad, see above) tests are disabled since they're
# costly and redundant with function gradgrad (and grad) tests
# @_gradcheck_ops(op_db)
# def test_method_gradgrad(self, device, dtype, op):
# self._skip_helper(op, dtype)
# self._gradgrad_test_helper(device, dtype, op, op.get_method())
@_gradcheck_ops(op_db)
def test_inplace_gradgrad(self, device, dtype, op):
self._skip_helper(op, dtype)
if not op.test_inplace_grad:
self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
# autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
# functionality with original test_jit.py method operator tests
class TestCommon(JitCommonTestCase):
exact_dtype = True
    # Compares a variant's backward result against the expected gradient
    # NOTE: also verifies that the variant's backward fails when the reference backward failed
def check_variant_backward(self, input, forward_result, expected_grad, expected_exception):
variant_exception_during_backwards = False
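        # Summing the forward result reduces it to a scalar, so backward() needs no explicit
        #   grad_outputs; the resulting input.grad is then compared against expected_grad.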
try:
forward_result.sum().backward()
variant_grad = input.grad
input.grad = None
except Exception as e:
if not expected_exception:
self.fail("Unexpected exception during backwards!")
variant_exception_during_backwards = True
if expected_exception != variant_exception_during_backwards:
self.fail("Unexpected success during backwards!")
if not expected_exception:
self.assertEqual(variant_grad, expected_grad)
# Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (method, inplace)
# against eager's gold standard op function variant
@ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op):
test_backward = op.supports_autograd and (op.test_complex_grad or not dtype.is_complex)
samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
for sample in samples:
# Acquires variants to test
method = op.get_method()
inplace = op.get_inplace()
            inplace_ops = [inplace, ]  # list of all inplace ops: inplace variant + alias inplace variants if they exist
aliases = []
for a_op in op.aliases:
aliases.append(a_op.op)
aliases.append(a_op.method_variant)
aliases.append(a_op.inplace_variant)
inplace_ops.append(a_op.inplace_variant)
aliases = tuple(aliases)
inplace_ops = tuple(v for v in inplace_ops if v is not None)
variants = (v for v in (method, inplace) + aliases if v is not None)
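            # For example, for an op like "abs" (alias "absolute") the tested variants would include
            #   Tensor.abs, Tensor.abs_, torch.absolute, Tensor.absolute and Tensor.absolute_.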
# Computes expected forward
# below calls op's function variant
expected_forward = op(*sample.input, *sample.args, **sample.kwargs)
# Computes expected backward
# NOTE: backward may fail for some dtypes
exception_during_backwards = False
expected_grad = None
try:
expected_forward.sum().backward()
expected_grad = sample.input.grad
sample.input.grad = None
except Exception as e:
exception_during_backwards = True
# Test eager consistency
for variant in variants:
# Verifies that inplace operations that promote int->float fail
# on tensors with integer dtypes.
if (variant in inplace_ops and not torch.can_cast(expected_forward.dtype, dtype)):
try:
variant_forward = variant(*(clone_input_helper(input) for input in sample.input),
*sample.args,
**sample.kwargs)
except Exception as e:
continue
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
# Compares variant's forward
# Note: copy the tensor-type inputs when testing inplace operation
variant_forward = variant(*(clone_input_helper(input) if variant in inplace_ops else input
for input in sample.input),
*sample.args,
**sample.kwargs)
self.assertEqual(variant_forward, expected_forward)
# Compares variant's backward
if test_backward and (variant not in inplace_ops or op.test_inplace_grad):
self.check_variant_backward(sample.input, variant_forward,
expected_grad, exception_during_backwards)
# Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (function, method, inplace)
# and runtimes (eager, traced, scripted).
# TODO WARNING: inplace x {traced, scripted} not currently tested
@ops(op_db)
def test_variant_consistency_jit(self, device, dtype, op):
test_backward = op.supports_autograd and (
(dtype.is_complex and op.test_complex_grad) or
(dtype.is_floating_point and (not op.skip_bfloat16_grad or dtype != torch.bfloat16)))
samples = op.sample_inputs(device, dtype, requires_grad=test_backward)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
for sample in samples:
# Acquires variants to test
func = op.get_op()
method = op.get_method()
inplace = op.get_inplace()
variants = {
'function': func, 'method': method,
# TODO: inplace tests currently fail
# 'inplace': inplace,
}
# Test traced and scripted consistency
for func_type, variant in variants.items():
if variant is None:
continue
# Create accessor for script function variant
name = op.name + '_' if func_type == 'inplace' else op.name
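                # (in-place variants follow the trailing underscore naming convention, e.g. "add" -> "add_")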
# run with disable_autodiff_subgraph_inlining(True) to test
# autodiff support. Context manager forces the graph to contain
# DifferentiableGraph nodes if they are present
with disable_autodiff_subgraph_inlining():
# Check scripted forward, grad, and grad grad
script_fn = create_script_fn(self, name, func_type)
check_against_reference(self,
script_fn,
func,
op.output_func,
(*sample.input,) + sample.args,
sample.kwargs,
no_grad=not test_backward)
# Check traced forward, grad, and grad grad
traced_fn = create_traced_fn(self, variant)
check_against_reference(self,
traced_fn,
func,
op.output_func,
(*sample.input,) + sample.args,
sample.kwargs,
no_grad=not test_backward)
# Check alias annotation schema for correctness (make
# sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because the schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
if dtype in [torch.float32, torch.int32]:
check_alias_annotation(name, (*sample.input,) + sample.args, sample.kwargs,
func_type=func_type, aten_name=op.aten_name)
# Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
if dtype is torch.float32:
# Sandcastle doesn't fuse nodes
if IS_SANDCASTLE:
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
fusible_nodes = []
else:
nonfusible_nodes = op.autodiff_nonfusible_nodes
fusible_nodes = op.autodiff_fusible_nodes
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
@ops([op for op in op_db if op.aliases])
def test_jit_alias_remapping(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
# Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings, which are then torch scripted.
        # - args is the list ["t0", "t1", ...] corresponding to the input tensors required by the op
        # - args_annot_kw is the annotated list used for the template function signature, for example,
        #   ["t0", "t1", "s0: float", "s1: bool", "max: float = 1.0", "min: float = 0.0"] ->
        #   def fn(t0, t1, s0: float, s1: bool, max: float = 1.0, min: float = 0.0)
        # - args_kw is the list of args/kwargs used to call the op, the same as args_annot_kw but
        #   without type annotations
args = [f"t{i}" for i in range(len(sample.input))]
args_annot_kw = args + \
[f"s{i}: {type(v).__name__}" for i, v in enumerate(sample.args)] + \
[f"{k}: {type(v).__name__} = {v}" for k, v in sample.kwargs.items()]
args_kw = args + \
[f"s{i}" for i in range(len(sample.args))] + \
[f"{k}={v}" for k, v in sample.kwargs.items()]
# Prepare data for test tracing
sample_args_kwargs = ()
if len(sample.args) > 0:
sample_args_kwargs += (sample.args, )
if len(sample.kwargs) > 0:
sample_args_kwargs += (sample.kwargs, )
original_name = op.name
original_name_inplace = original_name + "_"
expected_dtype = op(*sample.input, *sample.args, **sample.kwargs).dtype
for a_op in op.aliases:
inplace = a_op.inplace_variant
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
# Test scripting:
for variant in variants:
variant_name = variant.__name__
op_name = original_name_inplace if variant is inplace else original_name
if variant in method_or_inplace:
fn_template = '''
def _fn(t0{c}{args_annot_kw}):
return t0.{alias_name}({args_kw})
'''
# remove the first input tensor
script = fn_template.format(
c=", " if len(args_kw[1:]) > 1 else "",
args_annot_kw=", ".join(args_annot_kw[1:]),
args_kw=", ".join(args_kw[1:]),
alias_name=variant_name,
)
else:
fn_template = '''
def _fn({args_annot_kw}):
return variant({args_kw})
'''
script = fn_template.format(
args_annot_kw=", ".join(args_annot_kw),
args_kw=", ".join(args_kw),
)
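                # For illustration, with the hypothetical values above the function template expands to
                #   def _fn(t0, s0: float, alpha: float = 2.0):
                #       return variant(t0, s0, alpha=2.0)
                # while the method/inplace template instead calls t0.<alias_name>(s0, alpha=2.0).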
scripted = torch.jit.CompilationUnit(script)._fn
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
try:
inp = (clone_input_helper(input) for input in sample.input)
scripted(*inp, *sample.args, **sample.kwargs)
except Exception as e:
continue
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
inp = (clone_input_helper(input) for input in sample.input)
scripted(*inp, *sample.args, **sample.kwargs)
inp = (clone_input_helper(input) for input in sample.input)
graph = scripted.graph_for(*inp, *sample.args, **sample.kwargs)
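                # The scripted graph should reference the original op name and never the alias name,
                #   since the JIT remaps aliases to the original op at compile time.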
FileCheck().check(op_name).check_not(variant_name).run(graph)
# Test tracing:
for variant in variants:
variant_name = variant.__name__
op_name = original_name_inplace if variant is inplace else original_name
def _fn(*sample_args, **sample_kwargs):
return variant(*sample_args, **sample_kwargs)
inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
traced = torch.jit.trace(_fn, *inp)
inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
traced(*inp)
inp = (*(clone_input_helper(input) for input in sample.input), ) + sample_args_kwargs
graph = traced.graph_for(*inp)
FileCheck().check(op_name).check_not(variant_name).run(graph)
# Validates ops implement the correct out= behavior
# See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
# for a description of the correct behavior
# TODO: operations that support out= but don't support float
# are not covered by this test.
@ops(op_db, allowed_dtypes=(torch.float,))
def test_out(self, device, dtype, op):
        # TODO: also verify that ops claiming not to support out= actually reject the kwarg
if not op.supports_out:
self.skipTest("Skipped! Op doesn't support out= kwarg.")
# NOTE: only tests on first sample
samples = op.sample_inputs(device, dtype)
sample = samples[0]
# calls it normally to get the expected result
expected = op(*sample.input, *sample.args, **sample.kwargs)
op_out = partial(op, *sample.input, *sample.args, **sample.kwargs)
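        # op_out(out=out) below is then equivalent to
        #   op(*sample.input, *sample.args, **sample.kwargs, out=out)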
# Short-circuits if output is not a single tensor or an
# iterable of tensors
# Returns True if iterable is an iterable of tensors (includes empty iterables)
# and False o.w.
def _is_iterable_of_tensors(iterable):
try:
for t in iter(iterable):
if not isinstance(t, torch.Tensor):
return False
except TypeError as te:
return False
return True
if not isinstance(expected, torch.Tensor) and not _is_iterable_of_tensors(expected):
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
# A wrapper around map that works with single tensors and always
# instantiates the map. Used below to apply transforms to
# single tensor and iterable tensor outputs.
def _apply_out_transform(fn, out):
if isinstance(out, torch.Tensor):
return fn(out)
# assumes (see above) that out is an iterable of tensors
return tuple(map(fn, out))
# Case 0: out= with the correct shape, dtype, and device
# but NaN values for floating point and complex tensors, and
# maximum values for integer tensors.
# Expected behavior: out= values have no effect on the computation.
def _case_zero_transform(t):
try:
info = torch.iinfo(t.dtype)
return torch.full_like(t, info.max)
except TypeError as te:
# for non-integer types fills with NaN
return torch.full_like(t, float('nan'))
out = _apply_out_transform(_case_zero_transform, expected)
op_out(out=out)
self.assertEqual(expected, out)
# Case 1: out= with the correct shape, dtype, and device,
# but noncontiguous.
# Expected behavior: strides are respected.
def _case_one_transform(t):
return make_tensor(t.shape,
dtype=t.dtype,
device=t.device,
discontiguous=True)
# Extracts strides from a tensor or iterable of tensors into a tuple
def _extract_strides(out):
if isinstance(out, torch.Tensor):
return (out.stride(),)
# assumes (see above) that out is an iterable of tensors
return tuple(map(lambda t: t.stride(), out))
out = _apply_out_transform(_case_one_transform, expected)
original_strides = _extract_strides(out)
op_out(out=out)
final_strides = _extract_strides(out)
self.assertEqual(expected, out)
self.assertEqual(original_strides, final_strides)
# Case 2: out= with the correct dtype and device, but the wrong shape
# Expected behavior: resize with a warning.
def _case_two_transform(t):
wrong_shape = list(t.shape)
if len(wrong_shape) == 0:
# Handles scalar tensor case (empty list)
wrong_shape = [2]
else:
wrong_shape[-1] = wrong_shape[-1] + 1
return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)
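        # e.g. a (3, 4) output gets a (3, 5) out= tensor here, and a scalar output gets shape (2,)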
out = _apply_out_transform(_case_two_transform, expected)
with self.assertWarnsRegex(UserWarning, "An output with one or more elements"):
op_out(out=out)
self.assertEqual(expected, out)
# Case 3: out= with the correct dtype and device, but an empty
# tensor.
# Expected behavior: resize without warning.
def _case_three_transform(t):
return make_tensor((0,),
dtype=t.dtype,
device=t.device)
out = _apply_out_transform(_case_three_transform, expected)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
op_out(out=out)
# Verifies no warning is a resize warning
for w in caught:
if "An output with one or more elements" in str(w.message):
self.fail("Resizing an out= argument with no elements threw a resize warning!")
self.assertEqual(expected, out)
# Case 4: out= with correct shape and dtype, but wrong device.
wrong_device = None
if torch.device(device).type != 'cpu':
wrong_device = 'cpu'
elif torch.cuda.is_available():
wrong_device = 'cuda'
if wrong_device is not None:
def _case_four_transform(t):
return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
out = _apply_out_transform(_case_four_transform, expected)
with self.assertRaises(RuntimeError):
op_out(out=out)
# Case 5: out= with correct shape and device, but a dtype
# that output cannot be "safely" cast to (long).
# Expected behavior: error.
# NOTE: this case is filtered by dtype since some ops produce
# bool tensors, for example, which can be safely cast to any
        #   dtype. It is applied when a single output tensor has a floating point or complex
        #   dtype, or, when an op returns multiple tensors, when at least one of them has a
        #   floating point or complex dtype.
_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
(not isinstance(expected, torch.Tensor) and
reduce(lambda cur, t: cur or t.dtype in _dtypes, expected, False))):
def _case_five_transform(t):
return make_tensor(t.shape, dtype=torch.long, device=t.device)
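            # e.g. a float32 output with a torch.long out= tensor must raise, since float -> long
            #   is not a "safe" cast (torch.can_cast(torch.float32, torch.long) is False)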
            out = _apply_out_transform(_case_five_transform, expected)
with self.assertRaises(RuntimeError):
op_out(out=out)
instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())
instantiate_device_type_tests(TestCommon, globals())
if __name__ == '__main__':
run_tests()