blob: a54baba7fab639d63c9ec2661d3a3a5858a237ed [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import torch
from functorch import vmap
import torch.utils._pytree as pytree
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from torch.testing._internal.common_methods_invocations import SkipInfo
"""
Usage:
class MyTestCase(TestCase):
@parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
def test_single_param(self, param):
pass
@parameterized('param1', {'sin': torch.sin, 'tan': torch.tan})
@parameterized('param2', {'abs': torch.abs, 'cos': torch.cos})
def test_multiple_param(self, param1, param2):
pass
# The following creates:
# - MyTestCase.test_single_param_abs
# - MyTestCase.test_single_param_cos
# - MyTestCase.test_multiple_param_abs_sin
# - MyTestCase.test_multiple_param_cos_sin
# - MyTestCase.test_multiple_param_abs_tan
# - MyTestCase.test_multiple_param_cos_tan
instantiate_parameterized_methods(MyTestCase)
# This is also composable with PyTorch testing's instantiate_device_type_tests
# Make sure the param is after the device arg
class MyDeviceSpecificTest(TestCase):
@parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
def test_single_param(self, device, param):
pass
# The following creates:
# - MyDeviceSpecificTestCPU.test_single_param_abs_cpu
# - MyDeviceSpecificTestCPU.test_single_param_cos_cpu
# - MyDeviceSpecificTestCUDA.test_single_param_abs_cuda
# - MyDeviceSpecificTestCUDA.test_single_param_cos_cuda
instantiate_parameterized_methods(MyDeviceSpecificTest)
instantiate_device_type_tests(MyDeviceSpecificTest, globals())
# !!!!! warning !!!!!
# 1. The method being parameterized over MUST NOT HAVE A DOCSTRING. We'll
# error out nicely if this happens.
# 2. All other decorators MUST USE functools.wraps (they must propagate the docstring)
# `@parameterized` works by storing some metadata in place of the docstring.
# This takes advantage of how other decorators work (other decorators usually
# propagate the docstring via functools.wraps).
# 3. We might not compose with PyTorch testing's @dtypes and @precision
# decorators. But that is easily fixable. TODO.
# This is intended to compose with PyTorch testing's instantiate_device_type_tests (see the example above).
"""
# Name reserved for @parameterized metadata.
# NOTE(review): unused in this file; the metadata is actually stored in the
# method's __doc__ (see get_param_meta). Possibly vestigial or imported
# elsewhere; confirm before removing.
PARAM_META = '_torch_parameterized_meta'
class ParamMeta:
    """Stack of ``(arg_name, case_dict)`` entries recorded by @parameterized.

    An instance is stored in place of a test method's docstring; every
    @parameterized decorator applied to the method pushes one entry.
    """

    def __init__(self):
        self.stack = []

    def push(self, elt):
        """Record one ``(arg_name, case_dict)`` entry."""
        self.stack.append(elt)

    def pop(self, elt=None):
        """Remove and return the most recently pushed entry.

        ``elt`` is ignored; it is kept optional only so that existing callers
        that pass an argument continue to work.
        """
        return self.stack.pop()
def has_param_meta(method):
    """Return True if ``method`` carries @parameterized metadata.

    The metadata lives in the method's ``__doc__`` slot as a ParamMeta; a
    missing attribute, ``None`` or an ordinary docstring all yield False.
    """
    return isinstance(getattr(method, '__doc__', None), ParamMeta)
def get_param_meta(method):
    """Return the ParamMeta stored in ``method.__doc__``, creating it if absent.

    Raises:
        RuntimeError: if ``method`` already has an ordinary docstring, because
            the docstring slot is where the metadata is kept.
    """
    if getattr(method, '__doc__', None) is None:
        # No docstring yet: claim the slot for a fresh, empty ParamMeta.
        method.__doc__ = ParamMeta()
    meta = method.__doc__
    if not isinstance(meta, ParamMeta):
        raise RuntimeError('Tried to use @parameterized on a method that has '
                           'a docstring. This is not supported. Please remove '
                           'the docstring.')
    return meta
def parameterized(arg_name, case_dict):
    """Decorator factory that records one parameter axis on a test method.

    Args:
        arg_name: keyword-argument name under which the chosen case is passed.
        case_dict: mapping from case name (used as a suffix of the generated
            test name) to the value passed for ``arg_name``.

    The decorated function itself is returned; only its ParamMeta gains an
    entry. instantiate_parameterized_methods later expands it.
    """
    def record_axis(fn):
        get_param_meta(fn).push((arg_name, case_dict))
        return fn
    return record_axis
def parameterized_with_device(arg_name, case_dict):
    """Like ``parameterized``, but also flags ``fn`` with ``_has_device``.

    The flag makes _set_parameterized_method generate a wrapper that takes an
    explicit leading ``device`` positional argument.
    """
    add_axis = parameterized(arg_name, case_dict)

    def record_axis(fn):
        add_axis(fn)
        fn._has_device = True
        return fn
    return record_axis
def _set_parameterized_method(test_base, fn, instantiated_cases, extension_name):
    """Attach one concrete instantiation of a parameterized test to ``test_base``.

    Args:
        test_base: class that receives the new method.
        fn: the @parameterized-decorated function.
        instantiated_cases: list of ``(arg_name, case)`` pairs; each case is
            injected as a keyword argument on every call.
        extension_name: suffix appended to ``fn.__name__`` to form the name of
            the new method.
    """
    new_name = f'{fn.__name__}_{extension_name}'

    def wrapped_no_device(self, *args, **kwargs):
        for arg_name, case in instantiated_cases:
            kwargs[arg_name] = case
        return fn(self, *args, **kwargs)

    def wrapped_with_device(self, device, *args, **kwargs):
        for arg_name, case in instantiated_cases:
            kwargs[arg_name] = case
        return fn(self, device, *args, **kwargs)

    if getattr(fn, '_has_device', False):
        wrapped = wrapped_with_device
    else:
        wrapped = wrapped_no_device
    # Carry over attributes that other decorators stored on fn (for example
    # unittest.expectedFailure's marker, or the attributes PyTorch testing's
    # @dtypes / @precision decorators set). Without this, those decorators
    # silently stop applying to the generated tests. __doc__ is deliberately
    # not copied: it holds the ParamMeta, not a real docstring.
    wrapped.__dict__.update(fn.__dict__)
    wrapped.__name__ = new_name
    setattr(test_base, new_name, wrapped)
def to_tuples(dct):
    """Return the items of ``dct`` as a list of ``(key, value)`` tuples, in order."""
    return list(dct.items())
def instantiate_parameterized_methods(test_base):
    """Expand every @parameterized method of ``test_base`` into concrete tests.

    For each attribute carrying a ParamMeta, one new method is added per
    element of the cartesian product of its recorded case dicts. Each is named
    ``<method>_<case name>_<case name>...``. The original method is then
    deleted from ``test_base``.
    """
    # Snapshot the names up front: the loop body adds and deletes attributes.
    for attr_name in tuple(dir(test_base)):
        method = getattr(test_base, attr_name)
        if has_param_meta(method):
            names, dicts = zip(*get_param_meta(method).stack)
            per_arg_cases = [to_tuples(case_dict) for case_dict in dicts]
            for combo in itertools.product(*per_arg_cases):
                labels, values = zip(*combo)
                _set_parameterized_method(
                    test_base, method, list(zip(names, values)), '_'.join(labels))
            # Remove the base fn: it expects the parameter kwargs and cannot
            # run as a test on its own.
            delattr(test_base, attr_name)
def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    """Reference (non-vectorized) implementation of ``vmap(op)``.

    Calls ``op`` once per batch index. Each positional argument is sliced with
    ``select`` along its entry in ``in_dims``; arguments whose in_dim is None
    are passed whole. The per-index results are stacked along ``out_dim``.

    Args:
        op: callable to map.
        in_dims: one entry per positional argument, either the dimension to
            slice or None.
        out_dim: dimension along which per-index outputs are stacked.
        batch_size: number of slices to take.
        *batched_args: positional arguments for ``op``.
        **kwarg_values: keyword arguments forwarded unchanged to every call.

    Returns:
        A tensor if ``op`` returns a tensor. Otherwise, a list holding one
        stacked tensor per element of ``op``'s (sequence) output.
    """
    outs = []
    for idx in range(batch_size):
        idx_args = [a.select(in_dim, idx) if in_dim is not None else a
                    for a, in_dim in zip(batched_args, in_dims)]
        outs.append(op(*idx_args, **kwarg_values))
    if isinstance(outs[0], torch.Tensor):
        # Honor out_dim here too, matching the sequence branch below.
        return torch.stack(outs, out_dim)
    # Sequence outputs: stack the k-th element of every call together.
    return [torch.stack(list(per_output), out_dim) for per_output in zip(*outs)]
def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=3, *, bdims=(0,)):
    """Yield every way of batching the tensor arguments of one sample.

    Each tensor argument is either left unbatched (in_dim None) or replaced by
    ``batch_size`` copies of itself, stacked along each dimension in ``bdims``.
    Non-tensor arguments are always passed through unbatched. Combinations in
    which nothing is batched are skipped, including the case of no arguments
    at all.

    Args:
        arg_values: positional arguments of the sample.
        kwarg_values: keyword arguments; yielded unchanged.
        batch_size: size of the inserted batch dimension.
        bdims: positions at which the batch dimension may be inserted. The
            default ``(0,)`` reproduces the original behavior.

    Yields:
        ``(batched_args, in_dims, kwarg_values)`` with ``batched_args`` and
        ``in_dims`` as tuples aligned with ``arg_values``.
    """
    def add_batch_dim(arg, bdim):
        # Unsqueeze first so repeat tiles along the new dim at `bdim`; calling
        # repeat on the original tensor is only correct when bdim == 0.
        shape = [1] * (arg.dim() + 1)
        shape[bdim] = batch_size
        return arg.unsqueeze(bdim).repeat(shape), bdim

    batch_choices = []
    for arg in arg_values:
        if isinstance(arg, torch.Tensor):
            choices = [add_batch_dim(arg, bdim) for bdim in bdims]
            choices.append((arg, None))
            batch_choices.append(tuple(choices))
        else:
            batch_choices.append(((arg, None),))
    if not batch_choices:
        # No arguments: nothing can be batched.
        return

    for batched_values in itertools.product(*batch_choices):
        batched_args, in_dims = zip(*batched_values)
        if all(in_dim is None for in_dim in in_dims):
            continue
        yield batched_args, in_dims, kwarg_values
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values):
    """Yield ``(expected, actual)`` pairs checking vmap(op) against a loop.

    For every batching of the sample produced by get_exhaustive_batched_inputs,
    two pairs are yielded:

    1. ``loop`` (the per-example reference) vs. ``vmap(op)``.
    2. A nested-vmap check. The outer vmap maps over the batched arguments.
       The inner vmap maps over an extra all-ones tensor ``x``, which is added
       to op's output(s). Element ``[b, i]`` of the result is therefore
       ``op(args_b) + 1``.

    Args:
        op: callable under test.
        arg_values: positional sample arguments.
        kwarg_values: keyword sample arguments (never batched).
    """
    out_dim = 0
    batch_size = 3
    inner_size = 3  # length of the extra tensor mapped by the inner vmap

    # Tests case where we dispatch to a batching rule with no bdims
    # Should now be covered by https://github.com/facebookresearch/functorch/pull/63
    def f(x, *args, **kwargs):
        out = op(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            return out + x.to(out.device)
        return [o + x.to(o.device) for o in out]

    def expected_nested(v):
        # v has the outer batch at dim 0: (batch_size, *rest). Nested vmap with
        # out_dims=0 puts the outer batch first and the inner one second, so
        # the expectation is v broadcast along a new dim 1 of size inner_size,
        # plus one.
        return v.unsqueeze(1) + torch.ones(inner_size, *v.shape[1:], device=v.device)

    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size)
    for batched_args, in_dims, batched_kwargs in generator:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **batched_kwargs)
        # To debug, trace the batched call with functorch's make_fx.
        batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **batched_kwargs)
        yield (loop_out, batched_out)

        vmap1_dims = (0,) + (None,) * len(in_dims)
        vmap2_dims = (None,) + tuple(in_dims)
        nested_out = vmap(vmap(f, in_dims=vmap1_dims), in_dims=vmap2_dims)(
            torch.ones(inner_size), *batched_args, **batched_kwargs)
        yield (pytree.tree_map(expected_nested, loop_out), nested_out)
def opinfo_in_dict(opinfo, d):
    """Return True if ``d`` contains the OpInfo's name or ``name.variant``."""
    if opinfo.name in d:
        return True
    return f'{opinfo.name}.{opinfo.variant_test_name}' in d
def xfail(op_name, variant_name=None, *, device_type=None, dtypes=None, expected_failure=True):
    """Bundle one entry for skipOps' ``to_skip`` list.

    Returns the tuple
    ``(op_name, variant_name, device_type, dtypes, expected_failure)``,
    which skipOps unpacks in this order.
    """
    spec = (op_name, variant_name, device_type, dtypes, expected_failure)
    return spec
def skipOps(test_case_name, base_test_name, to_skip):
    """Append a SkipInfo decorator to every OpInfo matched by ``to_skip``.

    Args:
        test_case_name: name of the TestCase class the SkipInfo targets.
        base_test_name: name of the test method the SkipInfo targets.
        to_skip: iterable of tuples as built by ``xfail``:
            ``(op_name, variant_name, device_type, dtypes, expected_failure)``.
            A ``variant_name`` of None matches every variant of ``op_name``.
            ``expected_failure`` is forwarded to SkipInfo; with False, SkipInfo
            presumably skips the test instead of marking it as an expected
            failure (confirm against the PyTorch version in use).

    Returns:
        An identity decorator, so this can be used as ``@skipOps(...)``. The
        real work is the side effect on ``opinfo.decorators``.
    """
    all_opinfos = functorch_lagging_op_db + additional_op_db
    for entry in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = entry
        if variant_name is None:
            # match all variants
            matching_opinfos = [o for o in all_opinfos if o.name == op_name]
        else:
            matching_opinfos = [o for o in all_opinfos
                                if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {entry}"
        for opinfo in matching_opinfos:
            opinfo.decorators = (*opinfo.decorators,
                                 SkipInfo(test_case_name, base_test_name,
                                          device_type=device_type, dtypes=dtypes,
                                          expected_failure=expected_failure))

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped