| # Torch |
| from torch._six import PY2 |
| from torch.autograd import Variable |
| from torch.autograd.function import _nested_map |
| from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401 |
| from torch.onnx import OperatorExportTypes |
| import torch |
| import torch.cuda |
| import torch.jit |
| import torch.jit._logging |
| import torch.jit.frontend |
| import torch.jit.quantized |
| import zipfile |
| import functools |
| |
| # Testing utils |
| from common_utils import TestCase, IS_WINDOWS, \ |
| freeze_rng_state, TemporaryFileName, enable_profiling_mode, ProfilingMode |
| |
| # Standard library |
| from contextlib import contextmanager |
| from functools import reduce |
| from itertools import chain |
| from torch._six import StringIO |
| |
| import inspect |
| import io |
| import math |
| import os |
| import pickle |
| import sys |
| import tempfile |
| import textwrap |
| |
| RUN_CUDA = torch.cuda.is_available() |
| RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 |
| |
| def execWrapper(code, glob, loc): |
| if PY2: |
| exec(code) in glob, loc |
| else: |
| exec(code, glob, loc) |
| |
| def do_input_map(fn, input): |
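    # Apply `fn` to every tensor inside an arbitrarily nested input structure
    # (e.g. tuples and lists), preserving the structure itself, e.g.:
    #   do_input_map(lambda t: t.detach(), (x, [y, z]))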
| return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input) |
| |
| def clear_class_registry(): |
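    # Reset the JIT's class registries so that script class types defined in
    # one test do not leak into the next (called from JitTestCase.tearDown).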
| torch._C._jit_clear_class_registry() |
| torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() |
| |
| |
| class JitTestCase(TestCase): |
| _do_cuda_memory_leak_check = True |
| _restored_warnings = False |
| |
| class capture_stdout(list): |
| """ |
| Replace sys.stdout with a temporary StringIO |
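
        Captured output is appended to this list (the class subclasses list)
        when the context exits. For example:

            with self.capture_stdout() as captured:
                print("hello")
            # captured[0] now holds the printed text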
| """ |
| def __enter__(self): |
| self.sys_stdout = sys.stdout |
| self.stringio = StringIO() |
| sys.stdout = self.stringio |
| return self |
| |
| def __exit__(self, *args): |
| self.append(str(self.stringio.getvalue())) |
| del self.stringio |
| sys.stdout = self.sys_stdout |
| |
| def setHooks(self): |
| torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook) |
| |
| def clearHooks(self): |
| torch._C._jit_set_emit_hooks(None, None) |
| |
| def setUp(self): |
| super(JitTestCase, self).setUp() |
        # unittest overrides all warning filters and forces all of them to show
        # up after we install our own filter to silence warnings coming from
        # inside PyTorch. This ensures that our filter still takes precedence.
| if not JitTestCase._restored_warnings: |
| torch.jit.TracerWarning.ignore_lib_warnings() |
| JitTestCase._restored_warnings = True |
| self.setHooks() |
| |
| def tearDown(self): |
| super(JitTestCase, self).tearDown() |
        # needs to be cleared because Python might be unloaded before
        # the callback gets destructed
| self.clearHooks() |
| clear_class_registry() |
| |
| def _isHookExceptionOk(self, e): |
| se = str(e) |
| allowed = ("Could not export Python function", |
| "closures are not exportable") |
| for a in allowed: |
| if a in se: |
| return True |
| return False |
| |
| def _compared_saved_loaded(self, m): |
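        """
        Check that saving `m`, loading it back, and saving the loaded copy
        again produces identical generated code files, i.e. that export and
        import round-trip cleanly.
        """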
| if PY2: |
| # Disable for Python 2, which does not allow manipulation of multiple objects |
| # returned by zipfile.open(). |
| # See: https://docs.python.org/2.7/library/zipfile.html#zipfile.ZipFile.open |
| return |
| |
| def extract_files(buffer): |
| # crack open the zip format to get at the main module code |
| archive = zipfile.ZipFile(buffer) |
| # check that we have no duplicate names |
| self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) |
| files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) |
| # unwrap all the code files into strings |
| code_files = filter(lambda x: x.endswith('.py'), files) |
| code_files = map(lambda f: archive.open(f), code_files) |
| code_files = map(lambda file: "".join([line.decode() for line in file]), code_files) |
| |
            # unpickle all the debug files
| debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) |
| debug_files = map(lambda f: archive.open(f), debug_files) |
| debug_files = map(lambda f: pickle.load(f), debug_files) |
| return code_files, debug_files |
| |
| # disable the hook while we parse code, otherwise we will re-enter the hook |
| with torch.jit._disable_emit_hooks(): |
| try: |
| # short-circuit if this is an empty function or module |
| if len(m.code) == 0: |
| return |
| if isinstance(m, torch._C.ScriptModule): |
| if len(m._method_names()) == 0: |
| return |
| |
| # save the module to a buffer |
| buffer = io.BytesIO() |
| torch.jit.save(m, buffer) |
| # copy the data in the buffer so we can restore it later. This |
| # is because py2 and py3 have different semantics with zipfile |
| # and it's easier to just work with a fresh copy each time. |
| buffer_copy = buffer.getvalue() |
| |
| code_files, debug_files = extract_files(buffer) |
| |
| except RuntimeError as e: |
| if not self._isHookExceptionOk(e): |
| raise |
| else: |
| return |
| |
            # import the model again (from the copy we made of the original)
| buffer2 = io.BytesIO(buffer_copy) |
| imported = torch.jit.load(buffer2) |
| |
| # save it again |
| saved_module_buffer_2 = io.BytesIO() |
| torch.jit.save(imported, saved_module_buffer_2) |
| |
| saved_module_buffer_2.seek(0) |
| code_files_2, debug_files_2 = extract_files(saved_module_buffer_2) |
| |
| for a, b in zip(code_files, code_files_2): |
| self.assertMultiLineEqual(a, b) |
| |
| if isinstance(m, torch._C.ScriptModule): |
| self.assertTrue(torch._C._ivalue_tags_match(m, imported._c)) |
| |
| |
| def emitFunctionHook(self, func): |
        # skip the jitter check if func has a name that is invalid for export
        # (e.g. lambdas or aten builtins), or if inlining is disabled
| inline_everything = torch._C._jit_get_inline_everything_mode() |
| if func.name == "<lambda>" or "aten::" in func.name or not inline_everything: |
| return |
| self._compared_saved_loaded(func) |
| |
| def emitModuleHook(self, module): |
| self._compared_saved_loaded(module) |
| |
| |
| def getExportImportCopy(self, m, also_test_file=True, map_location=None): |
| buffer = io.BytesIO() |
| torch.jit.save(m, buffer) |
| buffer.seek(0) |
| imported = torch.jit.load(buffer, map_location=map_location) |
| |
| if not also_test_file: |
| return imported |
| |
| with TemporaryFileName() as fname: |
| torch.jit.save(imported, fname) |
| return torch.jit.load(fname, map_location=map_location) |
| |
| def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None): |
| buffer = io.BytesIO() |
| m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None) |
| torch.jit.save(m, buffer) |
| m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) |
| buffer.seek(0) |
| imported = torch.jit.load(buffer, map_location=map_location) |
| imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) |
| |
| if not also_test_file: |
| return imported |
| |
| # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile |
        # opens the file, and it cannot be opened multiple times on Windows. To support Windows,
| # close the file after creation and try to remove it manually |
| f = tempfile.NamedTemporaryFile(delete=False) |
| try: |
| f.close() |
| imported.save(f.name) |
| result = torch.jit.load(f.name, map_location=map_location) |
| finally: |
| os.unlink(f.name) |
| |
| result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) |
| return result |
| |
| def assertGraphContains(self, graph, kind): |
| self.assertTrue(any(n.kind() == kind for n in graph.nodes())) |
| |
| def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False): |
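        """
        Assert that `graph` contains exactly `num_kind_nodes` nodes of kind
        `kind`. With consider_subgraphs=True, nodes in nested subgraphs are
        counted as well (via a string scan of the printed graph).
        """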
| def perform_assert(graph, kind, actual, expected, consider_subgraphs): |
| if actual == expected: |
| return |
| subgraph = 'including' if consider_subgraphs else 'excluding' |
| raise AssertionError( |
| '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format( |
| graph, actual, kind, subgraph, expected)) |
| |
| if consider_subgraphs: |
| strgraph = str(graph) |
| count = strgraph.count(kind) - strgraph.count('with {}'.format(kind)) |
| perform_assert(graph, kind, count, num_kind_nodes, |
| consider_subgraphs) |
| return |
| |
| nodes = [node for node in graph.nodes() |
| if node.kind() == kind] |
| perform_assert(graph, kind, len(nodes), num_kind_nodes, |
| consider_subgraphs) |
| |
| def assertExpectedONNXGraph(self, g, *args, **kwargs): |
| g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX) |
| self.assertExpectedGraph(g, *args, **kwargs) |
| |
| def assertExpectedGraph(self, trace, *args, **kwargs): |
| if isinstance(trace, torch._C.Graph): |
| graph = trace |
| else: |
| graph = trace.graph() |
| |
| torch._C._jit_pass_lint(graph) |
| torch._C._jit_pass_dce(graph) |
| torch._C._jit_pass_lint(graph) |
| graph = torch._C._jit_pass_canonicalize(graph) |
| torch._C._jit_pass_lint(graph) |
| self.assertExpected(str(graph), *args, **kwargs) |
| |
| def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): |
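        """
        Assert whether autodiff produced the expected graph structure: every
        kind in `nonfusible_nodes` must appear in some
        prim::DifferentiableGraph subgraph, and every kind in `fusible_nodes`
        must appear in some prim::FusionGroup nested inside one of them.
        """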
| diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') |
| diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] |
| |
        # Every non-fusible node must show up in one of the DifferentiableGraph
        # subgraphs.
| found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\ |
| or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes]) |
| |
        # Every fusible node must show up in one of the FusionGroups inside a
        # DifferentiableGraph subgraph.
| fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) |
| fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] |
| found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\ |
| or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes]) |
| |
| self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes) |
| |
| def run_pass(self, name, trace): |
| if isinstance(trace, torch._C.Graph): |
| graph = trace |
| set_graph = False |
| else: |
| set_graph = True |
| graph = trace.graph() |
| |
| torch._C._jit_pass_lint(graph) |
| result = getattr(torch._C, '_jit_pass_' + name)(graph) |
| if result is not None: |
| graph = result |
| torch._C._jit_pass_lint(graph) |
| |
| if set_graph: |
| trace.set_graph(graph) |
| return graph |
| |
| def get_frame_vars(self, frames_up): |
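        """
        Return the locals and globals visible `frames_up` frames above the
        caller, so that code compiled from a string can resolve the same
        names the original function could.
        """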
| frame = inspect.currentframe() |
| i = 0 |
| while i < frames_up + 1: |
| frame = frame.f_back |
| i += 1 |
| defined_vars = {} |
| defined_vars.update(frame.f_locals) |
| defined_vars.update(frame.f_globals) |
| return defined_vars |
| |
| def checkScriptRaisesRegex(self, script, inputs, exception, regex, |
| outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING): |
| """ |
        Checks that a given function throws the correct exception when executed
        with normal Python, the string frontend, and the AST frontend.
| """ |
| |
| with enable_profiling_mode(): |
| # normal python |
| with self.assertRaisesRegex(exception, regex): |
| script(*inputs) |
| # string frontend |
| with self.assertRaisesRegex(exception, regex): |
| source = textwrap.dedent(inspect.getsource(script)) |
| cu = torch.jit.CompilationUnit(source) |
| ge = getattr(cu, script.__name__) |
| # profiling run |
| with self.assertRaisesRegex(exception, regex): |
| ge(*inputs) |
| # optimized run |
| ge(*inputs) |
| # python AST frontend |
| with self.assertRaisesRegex(exception, regex): |
| ge = torch.jit.script(script) |
| # profiling run |
| with self.assertRaisesRegex(exception, regex): |
| ge(*inputs) |
| # optimized run |
| ge(*inputs) |
| |
| def checkScript(self, |
| script, |
| inputs, |
| name='func', |
| optimize=True, |
| inputs_requires_grad=False, |
| capture_output=False, |
| frames_up=1, |
| profiling=ProfilingMode.PROFILING): |
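        """
        Compile `script` (a Python function or a source string) with the JIT,
        run both the eager and scripted versions on `inputs`, and check that
        their outputs (and, with capture_output=True, their stdout) match.
        Returns the scripted function.
        """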
| with torch.jit.optimized_execution(optimize): |
| with enable_profiling_mode(): |
| if isinstance(script, str): |
| # Compile the string to a Script function |
| cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) |
| |
| # Execute the Python function so we can run it later and get its |
| # outputs |
| |
| frame = self.get_frame_vars(frames_up) |
| the_locals = {} |
| execWrapper(script, glob=frame, loc=the_locals) |
| frame.update(the_locals) |
| |
| python_fn = frame[name] |
| scripted_fn = getattr(cu, name) |
| else: |
| |
| # Check the string frontend first |
| source = textwrap.dedent(inspect.getsource(script)) |
                    self.checkScript(
                        source,
                        inputs,
                        script.__name__,
                        optimize=optimize,
                        inputs_requires_grad=inputs_requires_grad,
                        capture_output=capture_output,
                        profiling=profiling,
                        frames_up=2)
| |
| # Continue checking the Python frontend |
| scripted_fn = torch.jit.script(script, _frames_up=1) |
| python_fn = script |
| |
| if inputs_requires_grad: |
| recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) |
| else: |
| recording_inputs = inputs |
| |
| if capture_output: |
| with self.capture_stdout() as script_stdout: |
| script_outputs = scripted_fn(*recording_inputs) |
| with self.capture_stdout() as opt_script_stdout: |
| opt_script_outputs = scripted_fn(*recording_inputs) |
| with self.capture_stdout() as _python_stdout: |
| python_outputs = python_fn(*inputs) |
| if not IS_WINDOWS: |
| self.assertExpected(script_stdout[0], subname='stdout') |
| self.assertEqual(python_outputs, opt_script_outputs) |
| else: |
| # profiling run |
| script_outputs = scripted_fn(*recording_inputs) |
| # optimized run |
| opt_script_outputs = scripted_fn(*recording_inputs) |
| python_outputs = python_fn(*inputs) |
| self.assertEqual(python_outputs, script_outputs) |
| self.assertEqual(script_outputs, opt_script_outputs) |
| return scripted_fn |
| |
| def checkTrace(self, func, reference_tensors, input_tensors=None, |
| drop=None, allow_unused=False, verbose=False, |
| inputs_require_grads=True, check_tolerance=1e-5, export_import=True, |
| _force_outplace=False): |
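        """
        Trace `func`, optionally round-trip it through save/load, and check
        that its outputs, gradients, and second-order gradients match those of
        the eager function on both no-grad and gradient-recording inputs.
        Returns the traced function.
        """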
| |
| # TODO: check gradients for parameters, not just inputs |
| def allSum(vs): |
| # drop allows us to remove some values from ever being used |
| # to test unused outputs |
| if drop is not None: |
| vs = vs[:-drop] |
| # we don't want all the grad for all the outputs to be the same |
| # so we multiply each by a constant |
| return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None) |
| if input_tensors is None: |
| input_tensors = reference_tensors |
| |
| def flatten_inputs(inputs): |
| def input_reduce(input, fn, acc): |
| if isinstance(input, torch.Tensor): |
| fn(input, acc) |
| elif isinstance(input, dict): |
| reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc) |
| else: |
| reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc) |
| return acc |
            return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))
| |
| nograd_inputs = reference_tensors |
| if inputs_require_grads: |
| recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors) |
| flattened_recording_inputs = flatten_inputs(recording_inputs) |
| else: |
| recording_inputs = reference_tensors |
| |
        # check_trace is set to False because check_trace runs under no_grad,
        # and checkTrace already does all of the checks against the Python
        # function anyway.
| ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance, |
| _force_outplace=_force_outplace, check_trace=False) |
| |
| if export_import: |
| ge = self.getExportImportCopy(ge) |
| |
| if verbose: |
| print(ge.graph) |
| |
| # test no gradients case |
| outputs = func(*nograd_inputs) |
| outputs_ge = ge(*nograd_inputs) |
| self.assertEqual(outputs, outputs_ge) |
| |
| # test gradients case |
| outputs = func(*recording_inputs) |
| if inputs_require_grads: |
| grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs, |
| allow_unused=allow_unused) |
| |
| outputs_ge = ge(*recording_inputs) |
| if inputs_require_grads: |
| grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs, |
| allow_unused=allow_unused) |
| self.assertEqual(outputs, outputs_ge) |
| if inputs_require_grads: |
| self.assertEqual(grads, grads_ge) |
| |
| # test the grad grad case |
| outputs = func(*recording_inputs) |
| l1 = allSum(outputs) |
| if inputs_require_grads: |
| grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True, |
| allow_unused=allow_unused) |
| if inputs_require_grads: |
| l2 = (allSum(grads) * l1) |
| grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused) |
| |
| if inputs_require_grads: |
| recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors) |
| flattened_recording_inputs = flatten_inputs(recording_inputs) |
| |
| outputs_ge = ge(*recording_inputs) |
| l1_ge = allSum(outputs_ge) |
| if inputs_require_grads: |
| grads_ge = torch.autograd.grad( |
| l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused) |
| |
| if inputs_require_grads: |
| l2_ge = (allSum(grads_ge) * l1_ge) |
| grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused) |
| |
| self.assertEqual(outputs, outputs_ge) |
| if inputs_require_grads: |
| self.assertEqual(grads, grads_ge) |
| for g2, g2_ge in zip(grads2, grads2_ge): |
| if g2 is None and g2_ge is None: |
| continue |
| self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4)) |
| |
| return ge |
| |
| def createFunctionFromGraph(self, trace): |
| graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() |
| return torch._C._create_function_from_graph("forward", graph) |
| |
| def assertExportImport(self, trace, inputs): |
| m = self.createFunctionFromGraph(trace) |
| self.assertExportImportModule(m, inputs) |
| |
| def assertExportImportModule(self, m, inputs): |
| m_import = self.getExportImportCopy(m) |
| a = self.runAndSaveRNG(m, inputs) |
| b = self.runAndSaveRNG(m_import, inputs) |
| self.assertEqual(a, b) |
| |
| def runAndSaveRNG(self, func, inputs, kwargs=None): |
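        """
        Run `func` with the RNG state frozen (and restored afterwards) so
        that repeated runs see the same random numbers.
        """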
| kwargs = kwargs if kwargs else {} |
| with freeze_rng_state(): |
| results = func(*inputs, **kwargs) |
| return results |
| |
| def checkModule(self, nn_module, args): |
| """ |
        Check that an nn.Module's results in script mode match eager mode and
        that the module can be exported and re-imported
| """ |
| sm = torch.jit.script(nn_module) |
| |
| with freeze_rng_state(): |
| eager_out = nn_module(*args) |
| |
| with freeze_rng_state(): |
| script_out = sm(*args) |
| |
| self.assertEqual(eager_out, script_out) |
| self.assertExportImportModule(sm, args) |
| |
| return sm |
| |
| @contextmanager |
| def inline_everything_mode(should_inline): |
| old = torch._C._jit_get_inline_everything_mode() |
| torch._C._jit_set_inline_everything_mode(should_inline) |
| try: |
| yield |
| finally: |
| torch._C._jit_set_inline_everything_mode(old) |
| |
| |
# note: not re-entrant; do not nest uses of this context manager
| @contextmanager |
| def disable_autodiff_subgraph_inlining(enabled=True): |
| torch._C._debug_set_autodiff_subgraph_inlining(not enabled) |
| try: |
| yield |
| finally: |
| torch._C._debug_set_autodiff_subgraph_inlining(True) |
| |
| def _inline_everything(fn): |
| @functools.wraps(fn) |
| def wrapper(*args, **kwargs): |
| with inline_everything_mode(True): |
| fn(*args, **kwargs) |
| return wrapper |
| |
# This exists temporarily for forward-compatibility reasons.
| # TODO(suo) remove |
| def _tmp_donotuse_dont_inline_everything(fn): |
| @functools.wraps(fn) |
| def wrapper(*args, **kwargs): |
| with inline_everything_mode(False): |
| fn(*args, **kwargs) |
| return wrapper |
| |
# make it easy to quickly define/trace a function for these tests
| def _trace(*args, **kwargs): |
| def wrapper(func): |
| return torch.jit.trace(func, args, **kwargs) |
| return wrapper |
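
# Example (illustrative function and inputs):
#
#     @_trace(torch.rand(3, 4))
#     def traced_double(x):
#         return x * 2  # traced_double is now a traced function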
| |
| |
| def enable_cpu_fuser(fn): |
| def wrapper(*args, **kwargs): |
| torch._C._jit_override_can_fuse_on_cpu(True) |
| try: |
| fn(*args, **kwargs) |
| finally: |
| torch._C._jit_override_can_fuse_on_cpu(False) |
| return wrapper |
| |
| |
| def enable_cpu_fuser_if(cond): |
| if cond: |
| return enable_cpu_fuser |
| else: |
| def noop_fuser(fn): |
| def wrapper(*args, **kwargs): |
| return fn(*args, **kwargs) |
| return wrapper |
| return noop_fuser |
| |
| def get_forward(c): |
| return c._get_method('forward') |
| |
| def get_forward_graph(c): |
| return c._get_method('forward').graph |
| |
| def get_module_method(m, module, method): |
| return m._c.getattr(module)._get_method(method) |