# blob: ac7ca9b55829723ce1bc8807f5342bee43084e43
import os
import tempfile
from string import Template
import copy
import unittest
import warnings
import inspect
import torch
from torch._six import PY2
import common_utils as common
import common_nn
from common_cuda import TEST_CUDA
import torch.utils.cpp_extension
from cpp_api_parity import sample_module, torch_nn_modules, TorchNNTestParams, CppArg, parse_parity_tracker_table
# Load the markdown parity-tracker table; entries are looked up below as
# parity_table['torch.nn'][module_full_name] -> (has_impl_parity, ...) to decide
# whether a module's generated tests are expected to pass or fail.
parity_table_path = os.path.join(os.path.dirname(__file__), 'cpp_api_parity/parity-tracker.md')
parity_table = parse_parity_tracker_table(parity_table_path)
# C++ source prepended to every generated parity test: the error-message macro
# plus tensor / IValue equality helpers used by the generated TORCH_CHECKs.
TORCH_NN_MODULE_COMMON_TEST_HARNESS = """\n
#include <torch/script.h>
const char * const parity_test_error_msg_prefix = "Parity test failed: ";
#define GENERATE_PARITY_TEST_ERROR_MSG(name, cpp_value, python_value) \
parity_test_error_msg_prefix, \
name, " in C++ has value: ", cpp_value, ", which does not match the corresponding value in Python: ", python_value \
bool check_tensor_equality(const torch::Tensor& tensor1, const torch::Tensor& tensor2) {
return tensor1.sizes().vec() == tensor2.sizes().vec() && \
tensor1.device() == tensor2.device() && \
tensor1.dtype() == tensor2.dtype() && \
tensor1.allclose(tensor2);
}
bool check_ivalue_equality(const c10::IValue& ivalue1, const c10::IValue& ivalue2) {
if (ivalue1.tagKind() != ivalue2.tagKind()) {
AT_ERROR("Value type mismatch: ", "ivalue1: ", ivalue1.tagKind(), ", ivalue2: ", ivalue2.tagKind());
}
if (ivalue1.isInt()) {
return ivalue1.toInt() == ivalue2.toInt();
} else if (ivalue1.isDouble()) {
return ivalue1.toDouble() == ivalue2.toDouble();
} else if (ivalue1.isBool()) {
return ivalue1.toBool() == ivalue2.toBool();
} else if (ivalue1.isString()) {
return ivalue1.toString() == ivalue2.toString();
} else if (ivalue1.isTensor()) {
return check_tensor_equality(ivalue1.toTensor(), ivalue2.toTensor());
} else {
AT_ERROR("Unsupported value type: ", ivalue1.tagKind());
}
}
"""
# C++ statement template: assert that one named parameter (value and
# requires_grad flag) matches between the Python ScriptModule and the C++ module.
CHECK_MODULE_PARAM_EQUALITY = Template("""\
TORCH_CHECK(
check_tensor_equality(${script_module_prefix}.get_parameter("${param_name}"), ${cpp_module_prefix}->${param_name}),
GENERATE_PARITY_TEST_ERROR_MSG(
"`${cpp_module_prefix}->${param_name}`",
${cpp_module_prefix}->${param_name},
${script_module_prefix}.get_parameter("${param_name}")));
TORCH_CHECK(
${script_module_prefix}.get_parameter("${param_name}").requires_grad() == ${cpp_module_prefix}->${param_name}.requires_grad(),
GENERATE_PARITY_TEST_ERROR_MSG(
"`${cpp_module_prefix}->${param_name}.requires_grad()`",
${cpp_module_prefix}->${param_name}.requires_grad(),
${script_module_prefix}.get_parameter("${param_name}").requires_grad()));
""")
# C++ statement template: assert that one named buffer matches between the two modules.
CHECK_MODULE_BUFFER_EQUALITY = Template("""\
TORCH_CHECK(
check_tensor_equality(${script_module_prefix}.get_buffer("${buffer_name}"), ${cpp_module_prefix}->${buffer_name}),
GENERATE_PARITY_TEST_ERROR_MSG(
"`${cpp_module_prefix}->${buffer_name}`",
${cpp_module_prefix}->${buffer_name},
${script_module_prefix}.get_buffer("${buffer_name}")));
""")
# C++ statement template: assert that one non-parameter / non-buffer attribute
# (wrapped into an IValue on the C++ side) matches between the two modules.
CHECK_MODULE_ATTR_EQUALITY = Template("""\
TORCH_CHECK(
check_ivalue_equality(
${script_module_prefix}.get_attribute("${attr_name}"), c10::IValue(${cpp_module_prefix}->${attr_name})),
GENERATE_PARITY_TEST_ERROR_MSG(
"`${cpp_module_prefix}->${attr_name}`",
${cpp_module_prefix}->${attr_name},
${script_module_prefix}.get_attribute("${attr_name}")));
""")
# C++ function template: constructing the module with the generated options
# expression is itself the test — it fails to compile/run on arg mismatch.
TORCH_NN_MODULE_TEST_CTOR_ARGS = Template("""\n
void ${module_name}_test_ctor_args() {
${module_qualified_name} m_init_by_cpp(${module_option});
}
""")
# C++ function template: load the Python-initialized module, construct the C++
# module with the same seed, then run the generated `extra_stmts` equality checks.
TORCH_NN_MODULE_TEST_INIT = Template("""\n
void ${module_variant_name}_test_init(
const std::string& saved_module_path,
const std::string& device) {
torch::jit::script::Module m_init_by_python = torch::jit::load(saved_module_path);
torch::manual_seed(2);
${module_qualified_name} m_init_by_cpp${cpp_constructor_args};
m_init_by_cpp->to(device);
${extra_stmts}
}
""")
# C++ function template: run forward on the deserialized module and compare the
# output tensor against the one computed in Python.
TORCH_NN_MODULE_TEST_FORWARD = Template("""\n
void ${module_variant_name}_test_forward(
const std::string& saved_module_path,
const std::string& device,
torch::Tensor python_output,
${input_arg_declarations}) {
torch::manual_seed(2);
${module_qualified_name} module${cpp_constructor_args};
torch::load(module, saved_module_path);
module->to(device);
auto cpp_output = module(${input_args});
TORCH_CHECK(
check_tensor_equality(cpp_output, python_output),
GENERATE_PARITY_TEST_ERROR_MSG(
"forward output",
cpp_output,
python_output));
${extra_stmts}
}
""")
# C++ function template: backprop through sum(forward(...)) and compare each
# parameter's gradient against the Python gradients, which arrive packed as the
# parameters of a second serialized "grad module" (see setup_backward_test).
TORCH_NN_MODULE_TEST_BACKWARD = Template("""\n
void ${module_variant_name}_test_backward(
const std::string& saved_module_path,
const std::string& saved_grad_module_path,
const std::string& device,
${input_arg_declarations}) {
${module_qualified_name} python_grad_module${cpp_constructor_args};
torch::load(python_grad_module, saved_grad_module_path);
torch::manual_seed(2);
${module_qualified_name} module${cpp_constructor_args};
torch::load(module, saved_module_path);
module->to(device);
auto cpp_output = module(${input_args});
cpp_output.sum().backward();
for (size_t i = 0; i < module->parameters().size(); i++) {
auto named_param = module->named_parameters()[i];
auto grad = python_grad_module->parameters()[i];
TORCH_CHECK(
check_tensor_equality(named_param->grad(), grad),
GENERATE_PARITY_TEST_ERROR_MSG(
"gradient of `" + named_param.key() + "`",
named_param->grad(),
grad));
}
${extra_stmts}
}
""")
# `nn.Module.__dict__` bookkeeping attributes that should not get generated
# attribute-equality checks (they have no C++ counterpart to compare against).
TORCH_NN_MODULE_IGNORED_ATTRS = {
    '_backend', '_parameters', '_buffers', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks',
    '_state_dict_hooks', '_load_state_dict_pre_hooks', '_modules', 'training', 'has_parity',
}
class TestCppApiParity(common.TestCase):
    """Auto-generated parity tests between Python `torch.nn` modules and their
    C++ `torch::nn` counterparts.

    For each module test variant, C++ source is generated from the templates
    above, JIT-compiled via `torch.utils.cpp_extension.load_inline`, and run to
    check that constructor args, initialized state, forward outputs, and
    backward gradients match between the two implementations.
    """

    def _python_arg_to_cpp_arg(self, python_arg):
        """Map a Python argument to a `CppArg(type, value)` pair holding the
        equivalent C++ type name and value literal.

        Raises RuntimeError for unsupported argument types.
        """
        if type(python_arg) == int:
            return CppArg(type='int64_t', value=str(python_arg))
        elif type(python_arg) == float:
            return CppArg(type='double', value=str(python_arg))
        elif type(python_arg) == bool:
            return CppArg(type='bool', value=str(python_arg).lower())
        elif type(python_arg) == str:
            return CppArg(type='std::string', value='"{}"'.format(python_arg))
        elif type(python_arg) == torch.Tensor:
            # Only the shape matters: the value is used for C++ argument
            # declarations, not for numerical comparison.
            return CppArg(
                type='torch::Tensor',
                value='torch::empty({})'.format(str(list(python_arg.shape)).replace('[', '{').replace(']', '}')))
        else:
            # BUGFIX: previously formatted the undefined name
            # `python_default_value`, so this error path raised NameError
            # instead of the intended message.
            raise RuntimeError(
                "{} is not a supported arg type for C++ module methods".format(type(python_arg)))

    def _compile_cpp_code_inline(self, name, cpp_sources, functions):
        """Just-in-time compile the C++ test code and return the extension module."""
        cpp_module = torch.utils.cpp_extension.load_inline(
            name=name,
            cpp_sources=cpp_sources,
            functions=functions,
            verbose=False,
        )
        return cpp_module

    # This tests that Python and C++ torch.nn modules have matching constructor arg names and types.
    def _test_torch_nn_module_ctor_args(self, module_name):
        python_module_class = getattr(torch.nn, module_name)
        module_metadata = torch_nn_modules.module_metadata_map[module_name]
        cpp_default_constructor_args_str = module_metadata.cpp_default_constructor_args
        if PY2:
            init_arg_spec = inspect.getargspec(python_module_class.__init__)
        else:
            init_arg_spec = inspect.getfullargspec(python_module_class.__init__)
        # ROBUSTNESS: `defaults` is None (not an empty tuple) when `__init__`
        # takes no keyword arguments; normalize so `len(None)` cannot crash and
        # the slices below stay correct (`args[1:-0]` would be the empty slice).
        init_kwargs_defaults = init_arg_spec.defaults or ()
        num_kwargs = len(init_kwargs_defaults)
        positional_arg_names = init_arg_spec.args[1:len(init_arg_spec.args) - num_kwargs]
        python_default_constructor_arg_names = [x for x in positional_arg_names if x != 'has_parity']
        cpp_default_constructor_arg_values = cpp_default_constructor_args_str.strip('()').split(',')
        self.assertEqual(
            len(cpp_default_constructor_arg_values),
            len(python_default_constructor_arg_names),
            "The constructor of `torch::nn::{}` in C++ ".format(module_name) +
            "must take the exact same number of non-keyword arguments " +
            "as the constructor of `torch.nn.{}` in Python. ".format(module_name) +
            "However, currently the C++ constructor expects {} non-keyword argument(s) ".format(
                len(cpp_default_constructor_arg_values)) +
            "while the Python constructor expects {} non-keyword argument(s): {}".format(
                len(python_default_constructor_arg_names),
                python_default_constructor_arg_names))
        # Build a C++ options expression mirroring the Python keyword defaults,
        # e.g. `torch::nn::FooOptions(1, 2).bar(true)`.
        cpp_module_option = 'torch::nn::{}Options{}'.format(module_name, cpp_default_constructor_args_str)
        init_kwargs = init_arg_spec.args[len(init_arg_spec.args) - num_kwargs:]
        for arg_name, python_default_value in zip(init_kwargs, init_kwargs_defaults):
            cpp_module_option += '.{}({})'.format(arg_name, self._python_arg_to_cpp_arg(python_default_value).value)
        cpp_sources = TORCH_NN_MODULE_COMMON_TEST_HARNESS + module_metadata.cpp_sources
        cpp_sources += TORCH_NN_MODULE_TEST_CTOR_ARGS.substitute(
            module_name=module_name,
            module_qualified_name='torch::nn::{}'.format(module_name),
            module_option=cpp_module_option)
        cpp_test_name = module_name + '_test_ctor_args'
        cpp_module = self._compile_cpp_code_inline(
            name=cpp_test_name, cpp_sources=cpp_sources, functions=cpp_test_name)
        # Successfully constructing the C++ module with the generated options is the test.
        getattr(cpp_module, cpp_test_name)()

    def _test_torch_nn_module_variant(self, test_params):
        """Run init / forward / backward parity checks for one module test variant."""
        def generate_test_cpp_sources(test_params, template, extra_stmts):
            # Fill one of the TORCH_NN_MODULE_TEST_* templates for this variant.
            example_inputs = test_params.example_inputs
            input_arg_types = [self._python_arg_to_cpp_arg(arg).type for arg in example_inputs]
            input_args = ['arg{}'.format(str(i)) for i in range(len(input_arg_types))]
            input_arg_declarations = ['{} {}'.format(arg_type, arg_name) for arg_type, arg_name in zip(input_arg_types, input_args)]
            test_cpp_sources = template.substitute(
                module_variant_name=test_params.module_variant_name,
                module_qualified_name='torch::nn::{}'.format(test_params.module_name),
                cpp_constructor_args=test_params.cpp_constructor_args,
                input_arg_declarations=',\n'.join(input_arg_declarations),
                input_args=',\n'.join(input_args),
                extra_stmts=extra_stmts)
            return test_cpp_sources

        def setup_init_test(test_params):
            # We are generating the attribute equality checks manually here,
            # because it is not possible to have a `.attributes()` API that returns
            # non-parameter / non-buffer attributes in a C++ torch::nn module.
            def generate_attr_equality_checks(module,
                                              script_module_prefix='m_init_by_python',
                                              cpp_module_prefix='m_init_by_cpp'):
                stmts = []
                for name, sub_module in module.named_children():
                    sub_script_module_prefix = '{}.get_module("{}")'.format(script_module_prefix, name)
                    sub_cpp_module_prefix = '{}->{}'.format(cpp_module_prefix, name)
                    # BUGFIX: accumulate the submodule checks instead of
                    # rebinding `stmts` — the plain `=` dropped the checks of
                    # all but the last child and broke the
                    # `num_attrs_recursive` count asserted below.
                    stmts += generate_attr_equality_checks(sub_module, sub_script_module_prefix, sub_cpp_module_prefix)
                for name, param in module._parameters.items():
                    stmts.append(CHECK_MODULE_PARAM_EQUALITY.substitute(
                        script_module_prefix=script_module_prefix,
                        cpp_module_prefix=cpp_module_prefix,
                        param_name=name))
                for name, buffer in module._buffers.items():
                    stmts.append(CHECK_MODULE_BUFFER_EQUALITY.substitute(
                        script_module_prefix=script_module_prefix,
                        cpp_module_prefix=cpp_module_prefix,
                        buffer_name=name))
                for name, attr in module.__dict__.items():
                    if name not in TORCH_NN_MODULE_IGNORED_ATTRS:
                        stmts.append(CHECK_MODULE_ATTR_EQUALITY.substitute(
                            script_module_prefix=script_module_prefix,
                            cpp_module_prefix=cpp_module_prefix,
                            attr_name=name))
                return stmts

            device = test_params.device
            python_module_class = test_params.python_module_class
            python_constructor_args = test_params.python_constructor_args
            # Seed so the C++ side (seeded identically in the template) gets
            # the same random initialization.
            torch.manual_seed(2)
            module = python_module_class(*python_constructor_args).to(device)
            extra_stmts = generate_attr_equality_checks(module)
            assert len(extra_stmts) == test_params.num_attrs_recursive
            extra_stmts_str = ''.join(extra_stmts)
            return (([module], device),
                    generate_test_cpp_sources(
                        test_params=test_params, template=TORCH_NN_MODULE_TEST_INIT, extra_stmts=extra_stmts_str))

        def setup_forward_test(test_params):
            device = test_params.device
            python_module_class = test_params.python_module_class
            python_constructor_args = test_params.python_constructor_args
            example_inputs = test_params.example_inputs
            torch.manual_seed(2)
            module = python_module_class(*python_constructor_args).to(device)
            python_output = module(*example_inputs)
            return (([module], device, python_output, example_inputs),
                    generate_test_cpp_sources(
                        test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD, extra_stmts=''))

        def setup_backward_test(test_params):
            device = test_params.device
            python_module_class = test_params.python_module_class
            python_constructor_args = test_params.python_constructor_args
            example_inputs = test_params.example_inputs
            torch.manual_seed(2)
            module = python_module_class(*python_constructor_args).to(device)
            python_output = module(*example_inputs)
            python_output.sum().backward()
            # JIT tracing does not save a module's parameters' gradients into ScriptModule.
            # Instead, we create another module `grad_module` with the same structure as `module`,
            # and use `grad_module`'s parameters to save `module`'s corresponding parameters'
            # gradients. Then, we trace both `module` and `grad_module`, serialize them and
            # pass them into C++ for parity testing.
            grad_module = copy.deepcopy(module)
            for param, grad_param in zip(module.parameters(), grad_module.parameters()):
                if param.grad is not None:
                    grad_param.data = param.grad
            return (([module, grad_module], device, example_inputs),
                    generate_test_cpp_sources(
                        test_params=test_params, template=TORCH_NN_MODULE_TEST_BACKWARD, extra_stmts=''))

        def trace_module(module, example_inputs):
            # JIT tracing does not automatically save a module's non-parameter / non-buffer attributes
            # into a ScriptModule's slots, which means we can't access them via `get_attributes()` in C++.
            # Here, we manually register these attributes into the ScriptModule so that we can access them
            # via `get_attributes()` in C++.
            def register_attrs(module, script_module):
                for sub_module, sub_script_module in zip(module.children(), script_module.children()):
                    register_attrs(sub_module, sub_script_module)
                for key, value in module.__dict__.items():
                    if key not in TORCH_NN_MODULE_IGNORED_ATTRS:
                        script_module._c._register_attribute(
                            key, torch.jit.annotations.ann_to_type(type(value)), value)

            # We use JIT tracing to serialize Python module state, so that we can load it into C++
            traced_script_module = torch.jit.trace(module, example_inputs)
            register_attrs(module, traced_script_module)
            return traced_script_module

        def serialize_module_into_file(script_module):
            # The file is closed (not deleted) so it can be re-opened on Windows;
            # the caller removes it manually in test_methods's `finally`.
            module_file = tempfile.NamedTemporaryFile(delete=False)
            script_module.save(module_file.name)
            module_file.close()
            return module_file.name

        def test_methods(test_params):
            # Generate, compile, and run the init/forward/backward C++ tests.
            module_variant_name = test_params.module_variant_name
            example_inputs = test_params.example_inputs
            args_map = {}
            cpp_sources = TORCH_NN_MODULE_COMMON_TEST_HARNESS + test_params.cpp_sources
            torch_nn_test_methods = [
                ('init', setup_init_test),
                ('forward', setup_forward_test),
                ('backward', setup_backward_test),
            ]
            for method_name, setup_test in torch_nn_test_methods:
                args_map[method_name], test_cpp_sources = setup_test(test_params)
                cpp_sources += test_cpp_sources
            cpp_module = self._compile_cpp_code_inline(
                name=test_params.module_variant_name,
                cpp_sources=cpp_sources,
                functions=['{}_test_{}'.format(
                    test_params.module_variant_name,
                    method_name) for method_name, _ in torch_nn_test_methods])
            for method_name, _ in torch_nn_test_methods:
                # Each setup function returned ([modules to serialize], *extra args).
                args = args_map[method_name]
                modules = args[0]
                script_modules = [trace_module(module, example_inputs) for module in modules]
                module_file_names = [serialize_module_into_file(script_module) for script_module in script_modules]
                # Flatten: serialized module paths first, then the remaining
                # args (lists, e.g. example_inputs, are splatted).
                cpp_args = module_file_names[:]
                for arg in args[1:]:
                    if isinstance(arg, list):
                        cpp_args += arg
                    else:
                        cpp_args.append(arg)
                try:
                    cpp_test_name = '{}_test_{}'.format(module_variant_name, method_name)
                    cpp_test_fn = getattr(cpp_module, cpp_test_name)
                    if not test_params.has_parity:
                        # Known-divergent modules must fail with the parity message.
                        with self.assertRaisesRegex(RuntimeError, "Parity test failed"):
                            cpp_test_fn(*cpp_args)
                    else:
                        cpp_test_fn(*cpp_args)
                finally:
                    # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
                    # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
                    # we close the file after creation and try to remove it manually.
                    for module_file_name in module_file_names:
                        try:
                            os.remove(module_file_name)
                        except OSError as e:
                            warnings.warn("Unable to remove {}, got error: {}".format(module_file_name, str(e)))

        test_methods(test_params)
def _compute_module_name(test_params_dict):
fullname = test_params_dict.get('fullname', None)
if fullname:
# NOTE: This doesn't work for some of the `wrap_functional` module tests such as "interpolate_nearest_1d",
# because in that case the module `interpolate` is not in `torch.nn` but rather in `torch.nn.functional`.
# We will fix this when we have parity tests for `torch.nn.functional` modules.
module_name = fullname.split('_')[0]
else:
module_name = test_params_dict.get('module_name')
return module_name
def _process_test_params(test_params_dict, module_metadata, device):
    """Normalize one test dict into a `TorchNNTestParams` record for `device`.

    Note: mutates `test_params_dict` by filling in a default 'constructor'
    (the `torch.nn` module class) when none was provided, because
    `common_nn.TestBase` is constructed from the dict's entries.
    """
    module_name = _compute_module_name(test_params_dict)
    # IDIOM: removed the unused local `desc` (was read from the dict but never used).
    python_module_class = getattr(torch.nn, module_name)
    test_params_dict['constructor'] = test_params_dict.get('constructor', python_module_class)
    test = common_nn.TestBase(**test_params_dict)
    # get_name() returns 'test_<Name>...'; strip the 'test_' prefix and suffix
    # the device name for non-CPU variants so test names stay unique.
    module_variant_name = test.get_name()[5:] + (('_' + device) if device != 'cpu' else '')
    example_inputs = test._get_input()
    # Normalize example inputs to a list of tensors.
    if type(example_inputs) == tuple:
        example_inputs = list(example_inputs)
    elif type(example_inputs) == torch.Tensor:
        example_inputs = [example_inputs]
    else:
        raise RuntimeError("Unexpected input type: {}".format(type(example_inputs)))
    # Only move inputs onto the device when it is actually usable
    # (CUDA variants are skipped later when TEST_CUDA is False).
    if device != 'cuda' or TEST_CUDA:
        example_inputs = [x.to(device) for x in example_inputs]
    return TorchNNTestParams(
        module_name=module_name,
        module_variant_name=module_variant_name,
        python_constructor_args=test.constructor_args,
        cpp_constructor_args=test_params_dict.get('cpp_constructor_args'),
        example_inputs=example_inputs,
        has_parity=test_params_dict.get('has_parity', True),
        python_module_class=python_module_class,
        cpp_sources=module_metadata.cpp_sources,
        num_attrs_recursive=module_metadata.num_attrs_recursive,
        device=device,
    )
def has_test(test_name):
    """Return True when a test named `test_name` is already attached to TestCppApiParity."""
    return hasattr(TestCppApiParity, test_name)
def add_test(test_name, test_fn):
    """Attach `test_fn` to TestCppApiParity; refuse duplicate test names."""
    if not has_test(test_name):
        setattr(TestCppApiParity, test_name, test_fn)
        return
    raise RuntimeError("Found two tests with the same name: " + test_name)
devices = ['cpu', 'cuda']

# Maps a generated test name to its TorchNNTestParams; the generated test
# function looks its params up here via `self._testMethodName`.
torch_nn_test_params_map = {}

# All module test dicts that parity tests are generated from.
all_module_tests = sample_module.module_tests + \
    common_nn.module_tests + \
    common_nn.new_module_tests + \
    common_nn.criterion_tests + \
    common_nn.new_criterion_tests
# Generate one ctor-args test per module plus one init/forward/backward variant
# test per (test dict, device) pair, and attach them to TestCppApiParity.
for test_params_dict in all_module_tests:
    # We skip all `torch.nn.functional` tests for now
    if 'FunctionalModule' in str(test_params_dict.get('constructor', '')):
        continue

    module_name = _compute_module_name(test_params_dict)
    assert hasattr(torch.nn, module_name), \
        "`torch.nn` doesn't have module `{}`. ".format(module_name) + \
        "If you are adding a new test, please set `fullname` using format `ModuleName_desc`, " + \
        "or set `module_name` using format `ModuleName`."

    # Every tested module must have an entry in the parity tracker table.
    module_full_name = 'torch.nn.' + module_name
    if module_full_name not in parity_table['torch.nn']:
        raise RuntimeError(
            'Module `{}` is not found in Python / C++ API parity table. Please update parity table at {}.'.format(
                module_full_name, parity_table_path))
    has_impl_parity, _ = parity_table['torch.nn'][module_full_name]

    def add_ctor_args_test_for_module(module_name, has_impl_parity):
        # Attach the constructor-args parity test for this module (once per module).
        ctor_args_test_name = 'test_torch_nn_{}_ctor_args'.format(module_name)

        def ctor_args_test(self):
            # The module name is recovered from the test's own name so the same
            # function body works for every generated test.
            self._test_torch_nn_module_ctor_args(
                module_name=self._testMethodName.replace('test_torch_nn_', '').replace('_ctor_args', ''))

        if not has_impl_parity:
            ctor_args_test = unittest.expectedFailure(ctor_args_test)

        # We only run one constructor args test per module
        if not has_test(ctor_args_test_name):
            add_test(ctor_args_test_name, ctor_args_test)

    def add_variant_test_for_module(module_name, test_params_dict, has_impl_parity):
        # Attach one init/forward/backward parity test per device for this test dict.
        module_metadata = torch_nn_modules.module_metadata_map[module_name]
        for device in devices:
            test_params = _process_test_params(
                test_params_dict=test_params_dict,
                module_metadata=module_metadata,
                device=device)
            test_name = 'test_torch_nn_{}'.format(test_params.module_variant_name)
            torch_nn_test_params_map[test_name] = test_params

            def test_fn(self):
                # Params are looked up by test name, so the closure carries no state.
                self._test_torch_nn_module_variant(test_params=torch_nn_test_params_map[self._testMethodName])

            if device == 'cuda':
                test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn)
            if not has_impl_parity:
                test_fn = unittest.expectedFailure(test_fn)

            add_test(test_name, test_fn)

    add_ctor_args_test_for_module(module_name, has_impl_parity)
    add_variant_test_for_module(module_name, test_params_dict, has_impl_parity)
# Assert that there exists auto-generated tests for SampleModule:
# one variant test per sample test dict per device, plus one ctor-args test.
assert len([name for name in TestCppApiParity.__dict__ if 'SampleModule' in name]) == \
    len(sample_module.module_tests) * len(devices) + 1

if __name__ == "__main__":
    common.run_tests()