| # -*- coding: utf-8 -*- | 
 | # Owner(s): ["oncall: jit"] | 
 |  | 
 | import torch | 
 |  | 
 | # This is how we include tests located in test/jit/... | 
 | # They are included here so that they are invoked when you call `test_jit.py`, | 
 | # do not run these test files directly. | 
 | from jit.test_tracer import TestTracer, TestMixTracingScripting  # noqa: F401 | 
 | from jit.test_recursive_script import TestRecursiveScript  # noqa: F401 | 
 | from jit.test_type_sharing import TestTypeSharing  # noqa: F401 | 
 | from jit.test_logging import TestLogging  # noqa: F401 | 
 | from jit.test_backends import TestBackends, TestBackendsWithCompiler  # noqa: F401 | 
 | from jit.test_backend_nnapi import TestNnapiBackend  # noqa: F401 | 
 | from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList  # noqa: F401 | 
 | from jit.test_async import TestAsync  # noqa: F401 | 
 | from jit.test_data_parallel import TestDataParallel  # noqa: F401 | 
 | from jit.test_models import TestModels  # noqa: F401 | 
 | from jit.test_modules import TestModules  # noqa: F401 | 
 | from jit.test_autodiff import TestAutodiffJit  # noqa: F401 | 
 | from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing  # noqa: F401 | 
 | from jit.test_custom_operators import TestCustomOperators  # noqa: F401 | 
 | from jit.test_export_modes import TestExportModes  # noqa: F401 | 
 | from jit.test_graph_rewrite_passes import TestGraphRewritePasses  # noqa: F401 | 
 | from jit.test_class_type import TestClassType  # noqa: F401 | 
 | from jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401 | 
 | from jit.test_ignore_context_manager import TestIgnoreContextManager  # noqa: F401 | 
 | from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis  # noqa: F401 | 
 | from jit.test_op_decompositions import TestOpDecompositions  # noqa: F401 | 
 | from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401 | 
 | from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing  # noqa: F401 | 
 | from jit.test_peephole import TestPeephole  # noqa: F401 | 
 | from jit.test_alias_analysis import TestAliasAnalysis  # noqa: F401 | 
 | from jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer  # noqa: F401 | 
 | from jit.test_save_load_for_op_version import TestSaveLoadForOpVersion  # noqa: F401 | 
 | from jit.test_module_containers import TestModuleContainers  # noqa: F401 | 
 | from jit.test_python_bindings import TestPythonBindings  # noqa: F401 | 
 | from jit.test_python_ir import TestPythonIr  # noqa: F401 | 
 | from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401 | 
 | from jit.test_remove_mutation import TestRemoveMutation  # noqa: F401 | 
 | from jit.test_torchbind import TestTorchbind  # noqa: F401 | 
 | from jit.test_module_interface import TestModuleInterface  # noqa: F401  # noqa: F401 | 
 | from jit.test_with import TestWith  # noqa: F401 | 
 | from jit.test_enum import TestEnum  # noqa: F401 | 
 | from jit.test_string_formatting import TestStringFormatting  # noqa: F401 | 
 | from jit.test_profiler import TestProfiler  # noqa: F401 | 
 | from jit.test_slice import TestSlice  # noqa: F401 | 
 | from jit.test_ignorable_args import TestIgnorableArgs  # noqa: F401 | 
 | from jit.test_hooks import TestHooks  # noqa: F401 | 
 | from jit.test_warn import TestWarn  # noqa: F401 | 
 | from jit.test_isinstance import TestIsinstance  # noqa: F401 | 
 | from jit.test_cuda import TestCUDA  # noqa: F401 | 
 | from jit.test_python_builtins import TestPythonBuiltinOP  # noqa: F401 | 
 | from jit.test_typing import TestTyping  # noqa: F401 | 
 | from jit.test_hash import TestHash  # noqa: F401 | 
 | from jit.test_complex import TestComplex  # noqa: F401 | 
 | from jit.test_jit_utils import TestJitUtils  # noqa: F401 | 
 | from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation  # noqa: F401 | 
 | from jit.test_types import TestTypesAndAnnotation  # noqa: F401 | 
 | from jit.test_misc import TestMisc  # noqa: F401 | 
 | from jit.test_upgraders import TestUpgraders  # noqa: F401 | 
 | from jit.test_pdt import TestPDT  # noqa: F401 | 
 | from jit.test_tensor_creation_ops import TestTensorCreationOps  # noqa: F401 | 
 | from jit.test_module_apis import TestModuleAPIs  # noqa: F401 | 
 | from jit.test_script_profile import TestScriptProfile  # noqa: F401 | 
 | from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation  # noqa: F401 | 
 | from jit.test_parametrization import TestParametrization  # noqa: F401 | 
 | from jit.test_attr import TestGetDefaultAttr  # noqa: F401 | 
 | from jit.test_aten_pow import TestAtenPow  # noqa: F401 | 
 | from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401 | 
 | from jit.test_union import TestUnion  # noqa: F401 | 
 | from jit.test_batch_mm import TestBatchMM  # noqa: F401 | 
 | from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU  # noqa: F401 | 
 | from jit.test_device_analysis import TestDeviceAnalysis  # noqa: F401 | 
 | from jit.test_dce import TestDCE  # noqa: F401 | 
 | from jit.test_sparse import TestSparse  # noqa: F401 | 
 | from jit.test_tensor_methods import TestTensorMethods  # noqa: F401 | 
 | from jit.test_dataclasses import TestDataclasses  # noqa: F401 | 
 |  | 
 | # Torch | 
 | from torch import Tensor | 
 | from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes | 
 | from torch.autograd import Variable | 
 | from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401 | 
 | from torch.nn.utils.rnn import PackedSequence | 
 | from torch.testing import FileCheck, make_tensor | 
 | import torch.autograd.profiler | 
 | import torch.cuda | 
 | import torch.jit | 
 | import torch.jit._logging | 
 | import torch.jit.frontend | 
 | import torch.nn as nn | 
 | import torch.nn.functional as F | 
 |  | 
 | # Testing utils | 
 | from torch.testing._internal import jit_utils | 
 | from torch.testing._internal.common_jit import check_against_reference | 
 | from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ | 
 |     suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ | 
 |     freeze_rng_state, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ | 
 |     enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ | 
 |     skipIfCrossRef, IS_MACOS, skipIfTorchDynamo | 
 | from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ | 
 |     _trace, do_input_map, get_execution_plan, make_global, \ | 
 |     execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ | 
 |     RUN_CUDA | 
 | from torch.testing._internal.jit_metaprogramming_utils import ( | 
 |     get_script_args, | 
 |     create_input, unpack_variables, | 
 |     additional_module_tests, EXCLUDE_SCRIPT_MODULES, | 
 |     get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template) | 
 |  | 
 | from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests | 
 |  | 
 | # For testing truediv in python 2 | 
 | from torch.testing._internal.test_module.future_div import div_int_future, div_float_future | 
 | from torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture | 
 |  | 
 | # Standard library | 
 | from collections import defaultdict, namedtuple, OrderedDict | 
 | from copy import deepcopy | 
 | from itertools import product | 
 | from textwrap import dedent | 
 | from typing import List, Dict, NamedTuple, Optional, Tuple, Union | 
 | import copy | 
 | import functools | 
 | import inspect | 
 | import io | 
 | import itertools | 
 | import math | 
 | import numpy as np | 
 | import os | 
 | import pickle | 
 | import pickletools | 
 | import random | 
 | import re | 
 | import shutil | 
 | import string | 
 | import sys | 
 | import tempfile | 
 | import types | 
 | import typing | 
 | import unittest | 
 | import warnings | 
 | import zipfile | 
 |  | 
 |  | 
 | def canonical(graph): | 
 |     return torch._C._jit_pass_canonicalize(graph).str(False) | 
 |  | 
 | def LSTMCellF(input, hx, cx, *params): | 
 |     return LSTMCell(input, (hx, cx), *params) | 
 |  | 
 | def doAutodiffCheck(testname): | 
 |     # TODO: setting false on test itself is not working | 
 |     if "test_t_" in testname or testname == "test_t": | 
 |         return False | 
 |  | 
 |     if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: | 
 |         return False | 
 |  | 
 |     if GRAPH_EXECUTOR == ProfilingMode.LEGACY: | 
 |         return True | 
 |  | 
 |  | 
 |     # these tests are disabled because BailOut nodes | 
 |     # inserted by ProfilingExecutor interfere with | 
 |     # subgraph slicing of Differentiable Graphs | 
 |     test_exceptions = [ | 
 |         # functional | 
 |         'test_nn_dropout', | 
 |         'test_nn_log_softmax', | 
 |         'test_nn_relu', | 
 |         'test_nn_softmax', | 
 |         'test_nn_threshold', | 
 |         'test_nn_lp_pool2d', | 
 |         'test_nn_lp_pool1d', | 
 |         'test_nn_gumbel_softmax_hard', | 
 |         'test_nn_gumbel_softmax', | 
 |         'test_nn_multilabel_soft_margin_loss', | 
 |         'test_nn_batch_norm', | 
 |         'test_nn_max_pool2d_with_indices', | 
 |         # AutogradJitGenerated | 
 |         'test___rdiv___constant', | 
 |         'test___rdiv___scalar_constant', | 
 |         'test_split', | 
 |         'test_split_dim', | 
 |         'test_split_dim_neg0', | 
 |         'test_split_size_list', | 
 |         'test_split_size_list_dim', | 
 |         'test_split_size_list_dim_neg0', | 
 |         'test_split_with_sizes', | 
 |         'test_split_with_sizes_dim', | 
 |         'test_split_with_sizes_dim_neg0', | 
 |         'test_split_with_sizes_size_0', | 
 |         'test_nn_max_pool2d_with_indices', | 
 |     ] | 
 |  | 
 |     if testname in test_exceptions: | 
 |         return False | 
 |     return True | 
 |  | 
 |  | 
# TODO: enable TE in PE when all tests are fixed
# Module-level JIT configuration, executed once when this file is imported:
#   * the texpr fuser flag is set to True only when GRAPH_EXECUTOR is
#     ProfilingMode.PROFILING;
#   * the profiling-executor flag is set to True for every mode except
#     ProfilingMode.LEGACY.
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
 |  | 
 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): | 
 |     hx, cx = hidden | 
 |     gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) | 
 |  | 
 |     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) | 
 |     ingate = torch.sigmoid(ingate) | 
 |     forgetgate = torch.sigmoid(forgetgate) | 
 |     cellgate = torch.tanh(cellgate) | 
 |     outgate = torch.sigmoid(outgate) | 
 |  | 
 |     cy = (forgetgate * cx) + (ingate * cellgate) | 
 |     hy = outgate * torch.tanh(cy) | 
 |     return hy, cy | 
 |  | 
 |  | 
 | def LSTMCellC(*args, **kwargs): | 
 |     hy, cy = LSTMCellF(*args, **kwargs) | 
 |     return torch.cat((hy, cy)) | 
 |  | 
 |  | 
 | def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh): | 
 |     gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh | 
 |     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) | 
 |     ingate = torch.sigmoid(ingate) | 
 |     forgetgate = torch.sigmoid(forgetgate) | 
 |     cellgate = torch.tanh(cellgate) | 
 |     outgate = torch.sigmoid(outgate) | 
 |     cy = (forgetgate * cx) + (ingate * cellgate) | 
 |     hy = outgate * torch.tanh(cy) | 
 |     return hy, cy | 
 |  | 
 |  | 
 | # Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44 | 
 | def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): | 
 |     Wx = x.mm(w_ih.t()) | 
 |     Uz = hx.mm(w_hh.t()) | 
 |     # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf | 
 |     gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias | 
 |     # Same as LSTMCell after this point | 
 |     ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) | 
 |     ingate = ingate.sigmoid() | 
 |     forgetgate = forgetgate.sigmoid() | 
 |     cellgate = cellgate.tanh() | 
 |     outgate = outgate.sigmoid() | 
 |     cy = (forgetgate * cx) + (ingate * cellgate) | 
 |     hy = outgate * cy.tanh() | 
 |     return hy, cy | 
 |  | 
 |  | 
 |  | 
 | def get_lstm_inputs(device, training=False, seq_length=None): | 
 |     input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10) | 
 |     input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training) | 
 |     hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) | 
 |     cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) | 
 |     module = nn.LSTMCell(10, 20).to(device, torch.float)  # Just to allocate weights with correct sizes | 
 |     if training: | 
 |         params = tuple(module.parameters()) | 
 |     else: | 
 |         params = tuple(p.requires_grad_(False) for p in module.parameters()) | 
 |     return (input, hx, cx) + params | 
 |  | 
 |  | 
 | def get_milstm_inputs(device, training=False): | 
 |     minibatch = 3 | 
 |     input_size = 10 | 
 |     hidden_size = 20 | 
 |     x = torch.randn(minibatch, input_size, device=device, dtype=torch.float) | 
 |     hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) | 
 |     cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) | 
 |  | 
 |     ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training) | 
 |     hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training) | 
 |     alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) | 
 |     ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) | 
 |     hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) | 
 |     bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) | 
 |     return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias | 
 |  | 
 |  | 
 | def get_fn(file_name, script_path): | 
 |     import importlib.util | 
 |     spec = importlib.util.spec_from_file_location(file_name, script_path) | 
 |     module = importlib.util.module_from_spec(spec) | 
 |     spec.loader.exec_module(module) | 
 |     fn = module.fn | 
 |     return fn | 
 |  | 
 | def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False): | 
 |     if diff_graph_idx is None: | 
 |         nodes = list(plan_state.graph.nodes()) | 
 |  | 
 |         if not skip_check: | 
 |             nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes)) | 
 |             if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"): | 
 |                 pass | 
 |             elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If": | 
 |                 pass | 
 |             else: | 
 |                 raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") | 
 |     grad_executors = list(plan_state.code.grad_executor_states()) | 
 |     return grad_executors[diff_graph_idx or 0] | 
 |  | 
 |  | 
 | def all_backward_graphs(script_module, diff_graph_idx=None): | 
 |     # Note: for Python 2 the order seems to be unstable | 
 |     ge_state = script_module.get_debug_state() | 
 |     fwd_plan = get_execution_plan(ge_state) | 
 |     grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx) | 
 |     bwd_plans = list(grad_executor_state.execution_plans.values()) | 
 |     return [p.graph.copy() for p in bwd_plans] | 
 |  | 
 |  | 
 | def backward_graph(script_module, diff_graph_idx=None, skip_check=False): | 
 |     ge_state = script_module.get_debug_state() | 
 |     fwd_plan = get_execution_plan(ge_state) | 
 |     grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check) | 
 |     bwd_plan = get_execution_plan(grad_executor_state) | 
 |     # Running JIT passes requires that we own the graph (with a shared_ptr). | 
 |     # The debug state struct does not own its graph so we make a copy of it. | 
 |     return bwd_plan.graph.copy() | 
 |  | 
 |  | 
 | # helper function to get sum of List[Tensor] | 
 | def _sum_of_list(tensorlist): | 
 |     s = 0 | 
 |     for t in tensorlist: | 
 |         s += t.sum() | 
 |     return s | 
 |  | 
 |  | 
 | # has to be at top level or Pickle complains | 
 | class FooToPickle(torch.nn.Module): | 
 |     def __init__(self): | 
 |         super(FooToPickle, self).__init__() | 
 |         self.bar = torch.jit.ScriptModule() | 
 |  | 
 | class TestJit(JitTestCase): | 
 |     @unittest.skip("Requires a lot of RAM") | 
 |     def test_big(self): | 
 |         m = torch.jit.ScriptModule() | 
 |         gig = int(1024 * 1024 * 1024 / 4) | 
 |         # a small tensor in the first 4GB | 
 |         m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float)) | 
 |         # a large tensor in the first 4GB that ends outside of it | 
 |         m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float)) | 
 |         # a small tensor in >4GB space | 
 |         m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float)) | 
 |         # s large tensor in the > 4GB space | 
 |         m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float)) | 
 |  | 
 |         m2 = self.getExportImportCopy(m) | 
 |  | 
 |         self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) | 
 |  | 
 |     def test_inferred_as_tensor(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' " | 
 |                                                   "because it was not annotated with an explicit type"): | 
 |             @torch.jit.script | 
 |             def dot(points, query, dim): | 
 |                 return (points * query).sum(dim) | 
 |  | 
 |     def test_constants_pkl(self): | 
 |         # This test asserts that the serialization archive includes a `constants.pkl` | 
 |         # file. This file is used by `torch.load` to determine whether a zip file | 
 |         # is a normal eager-mode serialization zip or a jit serialization zip. If | 
 |         # you are deleting `constants.pkl`, make sure to update `torch.serialization.load` | 
 |         # so it is still able to figure out which is which. | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return x | 
 |  | 
 |         buf = io.BytesIO() | 
 |         torch.jit.save(fn, buf) | 
 |         buf.seek(0) | 
 |  | 
 |         files = zipfile.ZipFile(buf).filelist | 
 |         self.assertTrue(any(['archive/constants.pkl' == f.filename for f in files])) | 
 |  | 
 |     def test_script_fn_pkl(self): | 
 |         with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"): | 
 |  | 
 |             @torch.jit.script | 
 |             def fn(x: torch.Tensor) -> torch.Tensor: | 
 |                 return x | 
 |  | 
 |             pkl_fn = pickle.dumps(fn, protocol=0) | 
 |  | 
 |     def test_restore_device(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self, cpu_device_str): | 
 |                 super(M, self).__init__() | 
 |                 self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float, | 
 |                                                     device=cpu_device_str)) | 
 |                 self.b0 = torch.tensor([0.9], dtype=torch.float, | 
 |                                        device=cpu_device_str) | 
 |  | 
 |         # main purpose is checking map_location works | 
 |         m = M("cpu") | 
 |         m2 = self.getExportImportCopy(m) | 
 |         self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) | 
 |         self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) | 
 |         self.assertFalse(m2.p0.is_cuda) | 
 |         self.assertFalse(m2.b0.is_cuda) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") | 
 |     def test_restore_device_cuda(self): | 
 |         class MyModule(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |                 self.register_buffer('b0', torch.randn(1, 3)) | 
 |                 self.p0 = nn.Parameter(torch.randn(2, 3)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return x + self.b0 + self.p0 | 
 |  | 
 |         m = MyModule() | 
 |         m.cuda(torch.cuda.device_count() - 1) | 
 |         cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1) | 
 |  | 
 |         self.assertTrue(m.p0.is_cuda) | 
 |         self.assertTrue(m.b0.is_cuda) | 
 |  | 
 |         # restore to the saved devices | 
 |         m2 = self.getExportImportCopy(m) | 
 |         self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) | 
 |         self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) | 
 |         self.assertEqual(str(m2.p0.device), cuda_device_str) | 
 |         self.assertEqual(str(m2.b0.device), cuda_device_str) | 
 |  | 
 |         # restore all to cpu using string | 
 |         cpu_device_str = 'cpu' | 
 |         m3 = self.getExportImportCopy(m, map_location=cpu_device_str) | 
 |         self.assertEqual(str(m3.p0.device), cpu_device_str) | 
 |         self.assertEqual(str(m3.b0.device), cpu_device_str) | 
 |  | 
 |         # restore all to first gpu using device | 
 |         m4 = self.getExportImportCopy( | 
 |             m3, map_location=torch.device('cuda:0')) | 
 |         self.assertEqual(str(m4.p0.device), 'cuda:0') | 
 |         self.assertEqual(str(m4.b0.device), 'cuda:0') | 
 |  | 
 |         # compute and compare the results | 
 |         input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1) | 
 |         origin_result = m(input) | 
 |         self.assertEqual(origin_result, m2(input)) | 
 |         self.assertEqual(origin_result, m3(input.cpu())) | 
 |         self.assertEqual(origin_result, m4(input.cuda(0))) | 
 |  | 
 |     def test_trace_retains_train(self): | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return x | 
 |         m = M() | 
 |         m.eval() | 
 |         tm = torch.jit.trace(m, (torch.rand(3))) | 
 |         self.assertEqual(tm.training, m.training) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") | 
 |     def test_restore_shared_storage_on_cuda(self): | 
 |         class Foo(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |                 whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') | 
 |                 self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) | 
 |                 self.register_buffer('b0', whole_tensor.narrow(0, 3, 1)) | 
 |  | 
 |         m = Foo() | 
 |         m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) | 
 |         self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) | 
 |         self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) | 
 |         self.assertTrue(m2.p0.is_cuda) | 
 |         self.assertTrue(m2.b0.is_cuda) | 
 |         self.assertTrue(m2.p0.is_shared()) | 
 |         self.assertTrue(m2.b0.is_shared()) | 
 |         self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr()) | 
 |  | 
 |     def test_add_relu_fusion(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self, relu_op): | 
 |                 super(M, self).__init__() | 
 |                 self.relu_op = relu_op | 
 |  | 
 |             def forward(self, a, b, c): | 
 |                 tmp = torch.add(a, b) | 
 |                 x = self.relu_op(tmp) | 
 |                 d = torch.add(a, c) | 
 |                 return x + d | 
 |         a = torch.rand((7, 11)) | 
 |         a = a * -10 | 
 |         a = a + 5 | 
 |         b = torch.rand((7, 11)) | 
 |         c = torch.rand((7, 11)) | 
 |         m = torch.jit.script(M(torch.relu)) | 
 |         orig_res = m(a, b, c) | 
 |         torch._C._jit_pass_fuse_add_relu(m.graph) | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(m, buffer) | 
 |         buffer.seek(0) | 
 |         m = torch.jit.load(buffer) | 
 |         new_res = m(a, b, c) | 
 |         FileCheck().check_not("aten::relu(") \ | 
 |             .check("aten::_add_relu(") \ | 
 |             .run(m.graph) | 
 |         torch.testing.assert_close(orig_res, new_res) | 
 |  | 
 |         # add, relu_ | 
 |         a = torch.rand((7, 11)) | 
 |         a = a * -10 | 
 |         a = a + 5 | 
 |         b = torch.rand((7, 11)) | 
 |         c = torch.rand((7, 11)) | 
 |         m = torch.jit.script(M(torch.relu_)) | 
 |         orig_res = m(a, b, c) | 
 |         torch._C._jit_pass_fuse_add_relu(m.graph) | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(m, buffer) | 
 |         buffer.seek(0) | 
 |         m = torch.jit.load(buffer) | 
 |         new_res = m(a, b, c) | 
 |         FileCheck().check_not("aten::relu_(") \ | 
 |             .check("aten::_add_relu(") \ | 
 |             .run(m.graph) | 
 |         torch.testing.assert_close(orig_res, new_res) | 
 |  | 
 |         class Madd_(torch.nn.Module): | 
 |             def __init__(self, relu_op): | 
 |                 super(Madd_, self).__init__() | 
 |                 self.relu_op = relu_op | 
 |  | 
 |             def forward(self, a, b): | 
 |                 x = a.add_(b) | 
 |                 x = self.relu_op(x) | 
 |                 return x | 
 |  | 
 |         # add_, relu_ | 
 |         a = torch.rand((7, 11)) | 
 |         a = a * -10 | 
 |         a = a + 5 | 
 |         b = torch.rand((7, 11)) | 
 |         # Because in place add_ will overwrite a | 
 |         a_copy = a.clone() | 
 |         m = torch.jit.script(Madd_(torch.relu_)) | 
 |         orig_res = m(a, b) | 
 |         torch._C._jit_pass_fuse_add_relu(m.graph) | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(m, buffer) | 
 |         buffer.seek(0) | 
 |         m = torch.jit.load(buffer) | 
 |         new_res = m(a_copy, b) | 
 |         FileCheck().check_not("aten::add_(") \ | 
 |             .check_not("aten::relu_(") \ | 
 |             .check("aten::_add_relu_(") \ | 
 |             .run(m.graph) | 
 |         torch.testing.assert_close(orig_res, new_res) | 
 |         # Since _add_relu_ does inplace mutation ensure | 
 |         # a_copy is modified | 
 |         torch.testing.assert_close(orig_res, a_copy) | 
 |  | 
 |         class Madd_out(torch.nn.Module): | 
 |             def __init__(self, relu_op): | 
 |                 super(Madd_out, self).__init__() | 
 |                 self.relu_op = relu_op | 
 |  | 
 |             def forward(self, a, b): | 
 |                 x = torch.add(a, b, out=a) | 
 |                 x = self.relu_op(x) | 
 |                 return x | 
 |         a = torch.rand((7, 11)) | 
 |         a = a * -10 | 
 |         a = a + 5 | 
 |         b = torch.rand((7, 11)) | 
 |  | 
 |         # add_out, relu_ | 
 |         a = torch.rand((7, 11)) | 
 |         a = a * -10 | 
 |         a = a + 5 | 
 |         b = torch.rand((7, 11)) | 
 |         # Because in place add_ will overwrite a | 
 |         a_copy = a.clone() | 
 |         m = torch.jit.script(Madd_out(torch.relu_)) | 
 |         orig_res = m(a, b) | 
 |         torch._C._jit_pass_fuse_add_relu(m.graph) | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(m, buffer) | 
 |         buffer.seek(0) | 
 |         m = torch.jit.load(buffer) | 
 |         new_res = m(a_copy, b) | 
 |         FileCheck().check_not("aten::add(") \ | 
 |             .check_not("aten::relu_(") \ | 
 |             .check("aten::_add_relu(") \ | 
 |             .run(m.graph) | 
 |         torch.testing.assert_close(orig_res, new_res) | 
 |         # Since _add_relu_ with out=a does inplace mutation ensure | 
 |         # a_copy is modified | 
 |         torch.testing.assert_close(orig_res, a_copy) | 
 |  | 
 |     def test_repeat_interleave_script(self): | 
 |         def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: | 
 |             output = input.repeat_interleave(repeats) | 
 |             return output | 
 |         fn_scripted = torch.jit.script(fn) | 
 |  | 
 |         input = torch.tensor([5, 7], dtype=torch.int64) | 
 |         repeats = torch.tensor([3, 6], dtype=torch.int64) | 
 |  | 
 |         output = fn(input, repeats) | 
 |         output_scripted = fn_scripted(input, repeats) | 
 |         self.assertEqual(output_scripted, output) | 
 |  | 
    # NOTE(review): the condition skips every executor except LEGACY, while the
    # message only mentions the simple executor -- confirm which is intended.
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information")
    def test_peephole_optimize_shape_ops(self):
        # Each nested helper scripts a small function that branches on a tensor
        # property (dim, size, dtype, is_floating_point, device), runs it on
        # sample inputs and checks that the graph specialized for that input
        # contains no prim::If.
        def test_input(func, input, result):
            # if result == 2 we will trigger a bailout and
            # the unprofiled graph should return the correct result
            self.assertEqual(func(input, profile_and_replay=True), result)
            gre = func.graph_for(input)
            FileCheck().check_not("prim::If").run(gre)

        # Branch on the number of dimensions.
        def test_dim():
            @torch.jit.script
            def func(x):
                if x.dim() == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor([0.5]), 1)
            test_input(func, torch.tensor([[0.5]]), 2)
        test_dim()

        # Branch on the size of one dimension, with positive and negative index.
        def test_size_index():
            @torch.jit.script
            def func(x):
                if x.size(0) == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.rand([1, 2]), 1)
            test_input(func, torch.rand([1, 3]), 1)

            @torch.jit.script
            def neg_index(x):
                if x.size(-2) == 1:
                    return 1
                else:
                    return 2

            test_input(neg_index, torch.rand([1, 2]), 1)
            test_input(neg_index, torch.rand([1, 3]), 1)

        # NOTE(review): the skipIf on this test only lets it run when
        # GRAPH_EXECUTOR is LEGACY, so this PROFILING-only call looks
        # unreachable -- confirm.
        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            test_size_index()

        # Branch on the dtype.
        def test_dtype():
            @torch.jit.script
            def func(x):
                if x.dtype == torch.float32:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_dtype()

        # Branch on is_floating_point().
        def test_is_floating_poiint():
            @torch.jit.script
            def func(x):
                if x.is_floating_point():
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_is_floating_poiint()

        # Branch on the device, via an explicit comparison and via is_cuda.
        def test_device():
            @torch.jit.script
            def func_1(x):
                if x.device == torch.device('cuda:0'):
                    a = 0
                else:
                    a = 1
                return a

            @torch.jit.script
            def func_2(x):
                if x.is_cuda:
                    a = 0
                else:
                    a = 1
                return a

            test_input(func_1, torch.tensor(0.5), 1)
            test_input(func_2, torch.tensor(0.5), 1)

            # The CUDA inputs are only exercised when CUDA tests are enabled.
            if RUN_CUDA:
                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)

        test_device()
 |  | 
 |     def test_attrs(self): | 
 |         def foo(x): | 
 |             return ( | 
 |                 # x.dtype, TODO: dtype long -> instance conversion | 
 |                 x.device, | 
 |                 x.shape, | 
 |                 x.is_cuda, | 
 |                 x.is_mkldnn, | 
 |                 x.is_quantized, | 
 |                 x.requires_grad, | 
 |                 x.T, | 
 |                 x.mT, | 
 |                 x.H, | 
 |                 x.mH | 
 |                 # x.layout TODO: layout long -> instance conversion | 
 |             ) | 
 |  | 
 |         scripted = torch.jit.script(foo) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(scripted(x), foo(x)) | 
 |  | 
 |     def test_layout(self): | 
 |         @torch.jit.script | 
 |         def check(x, y): | 
 |             return x.layout == y.layout | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         y = torch.rand(3, 4) | 
 |  | 
 |         self.assertTrue(check(x, y)) | 
 |  | 
 |     def test_matrix_transpose(self): | 
 |         @torch.jit.script | 
 |         def check(x): | 
 |             return torch.equal(x.mT, x.transpose(-2, -1)) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |     def test_transpose(self): | 
 |         @torch.jit.script | 
 |         def check(x): | 
 |             return torch.equal(x.T, x.t()) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |     def test_matrix_conj_transpose(self): | 
 |         @torch.jit.script | 
 |         def check(x): | 
 |             return torch.equal(x.mH, x.transpose(-2, -1).conj()) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |         x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |     def test_conj_transpose(self): | 
 |         @torch.jit.script | 
 |         def check(x): | 
 |             return torch.equal(x.H, x.t().conj()) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |         x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) | 
 |         self.assertTrue(check(x)) | 
 |  | 
 |     def test_T_mT_H_mH(self): | 
 |         def T(x): | 
 |             return x.mT | 
 |  | 
 |         def mT(x): | 
 |             return x.mT | 
 |  | 
 |         def H(x): | 
 |             return x.H | 
 |  | 
 |         def mH(x): | 
 |             return x.mH | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         y = make_tensor((3, 4), device="cpu", dtype=torch.complex64) | 
 |  | 
 |         self.checkScript(T, (x, )) | 
 |         self.checkScript(mT, (x, )) | 
 |         self.checkScript(H, (x, )) | 
 |         self.checkScript(mH, (x, )) | 
 |         self.checkScript(T, (y, )) | 
 |         self.checkScript(mT, (y, )) | 
 |         self.checkScript(H, (y, )) | 
 |         self.checkScript(mH, (y, )) | 
 |  | 
 |     def test_nn_conv(self): | 
 |         class Mod(nn.Module): | 
 |             def __init__(self, conv): | 
 |                 super().__init__() | 
 |                 self.conv = conv | 
 |  | 
 |             def forward(self, input): | 
 |                 return self.conv(input) | 
 |  | 
 |         inputs = [ | 
 |             # Conv | 
 |             (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), | 
 |             (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), | 
 |             (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), | 
 |             # ConvTransposed | 
 |             (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), | 
 |             (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), | 
 |             (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), | 
 |         ] | 
 |  | 
 |         for m, inp in inputs: | 
 |             self.checkModule(m, (inp,)) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy') | 
 |     def test_debug_flush_compilation_cache(self): | 
 |         """Flushing the compilation cache must make ``get_debug_state`` fail. | 
 |  | 
 |         The check is done twice: once on the object returned by | 
 |         ``checkScript`` for a plain function, and once on the ``forward`` | 
 |         method of a scripted module. Skipped unless the graph executor is in | 
 |         profiling mode. | 
 |         """ | 
 |         def foo(x): | 
 |             return x + 2 | 
 |  | 
 |         class Mod(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Mod, self).__init__() | 
 |  | 
 |             def forward(self, t): | 
 |                 return t + 2 | 
 |  | 
 |         m = torch.jit.script(Mod()) | 
 |         x = torch.rand(1, 10) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             jitted = self.checkScript(foo, (x,)) | 
 |             # shouldn't throw | 
 |             states = jitted.get_debug_state() | 
 |  | 
 |             # after flushing there should be no opt plan left, so asking | 
 |             # for the debug state is expected to raise | 
 |             jitted._debug_flush_compilation_cache() | 
 |             with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"): | 
 |                 states = jitted.get_debug_state() | 
 |  | 
 |             NUM_RUNS = 1 | 
 |             with num_profiled_runs(NUM_RUNS): | 
 |                 # Call the module twice before querying its debug state. | 
 |                 m(x) | 
 |                 m(x) | 
 |                 fwd = m._c._get_method("forward") | 
 |                 states = m.get_debug_state() | 
 |  | 
 |                 # after flushing there should be no opt plan left, so asking | 
 |                 # for the debug state is expected to raise | 
 |                 fwd._debug_flush_compilation_cache() | 
 |                 with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"): | 
 |                     states = m.get_debug_state() | 
 |  | 
 |     def test_numel(self): | 
 |         @torch.jit.script | 
 |         def get_numel_script(x): | 
 |             return x.numel() | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         numel = get_numel_script(x) | 
 |         self.assertEqual(numel, x.numel()) | 
 |  | 
 |     def test_element_size(self): | 
 |         @torch.jit.script | 
 |         def get_element_size_script(x): | 
 |             return x.element_size() | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         element_size = get_element_size_script(x) | 
 |         self.assertEqual(element_size, x.element_size()) | 
 |  | 
 |     def test_Sequential(self): | 
 |         class Seq(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Seq, self).__init__() | 
 |                 self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 for l in self.seq: | 
 |                     x = l(x) | 
 |                 return x | 
 |  | 
 |         m = torch.jit.script(Seq()) | 
 |         assert m.graph  # ensure jit was able to compile | 
 |  | 
 |     def test_ModuleList(self): | 
 |         class Mod(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Mod, self).__init__() | 
 |                 self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)]) | 
 |                 self.model += (nn.Linear(10, 20),) | 
 |                 self.model.append(nn.Linear(20, 30)) | 
 |                 self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)]) | 
 |  | 
 |             def forward(self, v): | 
 |                 for m in self.model: | 
 |                     v = m(v) | 
 |                 return v | 
 |  | 
 |         m = torch.jit.script(Mod()) | 
 |         assert m.graph  # ensure jit was able to compile | 
 |  | 
 |     def test_disabled(self): | 
 |         torch.jit._state.disable() | 
 |         try: | 
 |             def f(x, y): | 
 |                 return x + y | 
 |  | 
 |             self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f) | 
 |             self.assertIs(torch.jit.script(f), f) | 
 |  | 
 |             class MyModule(torch.jit.ScriptModule): | 
 |                 @torch.jit.script_method | 
 |                 def method(self, x): | 
 |                     return x | 
 |  | 
 |             # XXX: Unfortunately ScriptModule won't simply become Module now, | 
 |             # because that requires disabling the JIT at startup time, which | 
 |             # we can't do in here. | 
 |             # We need to or those two conditions to make it work with all versions of Python | 
 |             self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method)) | 
 |         finally: | 
 |             torch.jit._state.enable() | 
 |  | 
 |     def test_train_eval(self): | 
 |         class Sub(nn.Module): | 
 |             def forward(self, input): | 
 |                 if self.training: | 
 |                     return input | 
 |                 else: | 
 |                     return -input | 
 |  | 
 |         class MyModule(torch.jit.ScriptModule): | 
 |             def __init__(self, module): | 
 |                 super(MyModule, self).__init__() | 
 |                 self.module = module | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 return self.module(input) + 1 | 
 |  | 
 |         m = MyModule(Sub()) | 
 |         input = torch.rand(3, 4) | 
 |         self.assertEqual(input + 1, m(input)) | 
 |         m.eval() | 
 |         self.assertEqual(-input + 1, m(input)) | 
 |  | 
 |         # test batchnorm and dropout train/eval | 
 |         input = torch.randn(6, 10) | 
 |         batchnorm = nn.BatchNorm1d(10) | 
 |         dropout = nn.Dropout(p=0.2) | 
 |  | 
 |         m_batchnorm = MyModule(batchnorm) | 
 |         self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) | 
 |         batchnorm.eval() | 
 |         m_batchnorm.eval() | 
 |         self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) | 
 |  | 
 |         m_dropout = MyModule(dropout) | 
 |         dropout.eval() | 
 |         m_dropout.eval() | 
 |         self.assertEqual(dropout(input) + 1, m_dropout(input)) | 
 |  | 
 |     def test_nn_lp_pool2d(self): | 
 |         class Mod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.l = torch.nn.LPPool2d(2, 3) | 
 |                 self.n = torch.nn.LPPool2d(2, (7, 1)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return (self.l(x), | 
 |                         self.n(x), | 
 |                         torch.nn.functional.lp_pool2d(x, float(2), 3), | 
 |                         torch.nn.functional.lp_pool2d(x, 2, 3), | 
 |                         torch.nn.functional.lp_pool2d(x, float(2), (7, 1))) | 
 |  | 
 |         self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),)) | 
 |  | 
 |     def test_nn_lp_pool1d(self): | 
 |         class Mod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.l = torch.nn.LPPool1d(2, 3) | 
 |                 self.n = torch.nn.LPPool1d(2, 7) | 
 |  | 
 |             def forward(self, x): | 
 |                 return (self.l(x), | 
 |                         self.n(x), | 
 |                         torch.nn.functional.lp_pool1d(x, float(2), 3), | 
 |                         torch.nn.functional.lp_pool1d(x, 2, 3), | 
 |                         torch.nn.functional.lp_pool1d(x, float(2), 7)) | 
 |  | 
 |         self.checkModule(Mod(), (torch.rand(1, 3, 7),)) | 
 |  | 
 |     def test_nn_padding_functional(self): | 
 |         class Mod(nn.Module): | 
 |             def __init__(self, *pad): | 
 |                 super().__init__() | 
 |                 self.pad = pad | 
 |  | 
 |             def forward(self, x): | 
 |                 return F.pad(x, self.pad, mode='constant', value=3.5) | 
 |  | 
 |         inputs = [ | 
 |             (Mod(1, 2), torch.randn(1, 3, 4)),  # 1D | 
 |             (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)),  # 2D | 
 |             (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)),  # 3D | 
 |         ] | 
 |  | 
 |         for m, inp in inputs: | 
 |             self.checkModule(m, (inp,)) | 
 |  | 
 |     def test_nn_padding(self): | 
 |         class Mod(nn.Module): | 
 |             def __init__(self, padding): | 
 |                 super().__init__() | 
 |                 self.padding = padding | 
 |  | 
 |             def forward(self, input): | 
 |                 return self.padding(input) | 
 |  | 
 |         inputs = [ | 
 |             (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)), | 
 |             (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)), | 
 |             (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)), | 
 |             (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), | 
 |             (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), | 
 |             (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)), | 
 |             (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), | 
 |             (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), | 
 |             (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)), | 
 |             (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3)) | 
 |         ] | 
 |  | 
 |         for m, inp in inputs: | 
 |             self.checkModule(m, (inp,)) | 
 |  | 
 |     def test_script_autograd_grad(self): | 
 |         def test_simple_grad(x, y): | 
 |             # type: (Tensor, Tensor) -> List[Optional[Tensor]] | 
 |             z = x + 2 * y + x * y | 
 |             return torch.autograd.grad((z.sum(), ), (x, y)) | 
 |  | 
 |         def test_simple_grad_with_grad_outputs(x, y): | 
 |             # type: (Tensor, Tensor) -> List[Optional[Tensor]] | 
 |             z = x + 2 * y + x * y | 
 |             grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) | 
 |             return torch.autograd.grad((z, ), (x, y), grad_outputs) | 
 |  | 
 |         def test_one_output_not_requires_grad(x, y): | 
 |             # type: (Tensor, Tensor) -> List[Optional[Tensor]] | 
 |             z = 2 * y + y | 
 |             return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True) | 
 |  | 
 |         def test_retain_graph(x, y): | 
 |             # type: (Tensor, Tensor) -> None | 
 |             z = x + 2 * y + x * y | 
 |             torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True) | 
 |             torch.autograd.grad((z.sum(), ), (x, y)) | 
 |  | 
 |         x = torch.randn(2, 2, requires_grad=True) | 
 |         y = torch.randn(2, 2, requires_grad=True) | 
 |         self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True) | 
 |         self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True) | 
 |         self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True) | 
 |         self.checkScript(test_retain_graph, (x, y), inputs_requires_grad=True) | 
 |  | 
 |     def test_script_backward(self): | 
 |         def checkBackwardScript(fn, inputs): | 
 |             scripted_fn = torch.jit.script(fn) | 
 |             FileCheck().check("torch.autograd.backward").run(scripted_fn.code) | 
 |             recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) | 
 |  | 
 |             fn(*inputs) | 
 |             scripted_fn(*recording_inputs) | 
 |  | 
 |             for inp1, inp2 in zip(inputs, recording_inputs): | 
 |                 self.assertEqual(inp1.grad, inp2.grad) | 
 |  | 
 |         def test_tensor_backward(input): | 
 |             # type: (Tensor) -> None | 
 |             output = torch.relu(input) | 
 |             output = output.softmax(0) | 
 |             sum_out = output.sum() | 
 |             sum_out.backward() | 
 |  | 
 |         def test_torch_autograd_backward(input): | 
 |             # type: (Tensor) -> None | 
 |             output = torch.relu(input) | 
 |             output = output.softmax(0) | 
 |             torch.autograd.backward(output.sum()) | 
 |  | 
 |         def test_torch_autograd_backward_with_grad_tensors(input): | 
 |             # type: (Tensor) -> None | 
 |             output = torch.relu(input) | 
 |             output = output.softmax(0) | 
 |             grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) | 
 |             torch.autograd.backward((output,), grad_outputs) | 
 |  | 
 |         inp = torch.randn(2, 2, requires_grad=True) | 
 |         checkBackwardScript(test_tensor_backward, (inp,)) | 
 |         checkBackwardScript(test_torch_autograd_backward, (inp,)) | 
 |         checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,)) | 
 |  | 
 |     def test_script_backward_twice(self): | 
 |         def checkBackwardTwiceScript(fn, inputs, retain_graph_=False): | 
 |             torch._C._jit_set_profiling_executor(False) | 
 |  | 
 |             with torch.jit.optimized_execution(True): | 
 |                 scripted_fn = torch.jit.script(fn, inputs) | 
 |                 FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs)) | 
 |  | 
 |                 result = scripted_fn(*inputs) | 
 |                 result.sum().backward(retain_graph=retain_graph_) | 
 |                 if not retain_graph_: | 
 |                     self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', | 
 |                                            lambda: result.sum().backward()) | 
 |                 else: | 
 |                     result.sum().backward() | 
 |  | 
 |         def test_script_backward_twice_with_saved_values(input1, input2): | 
 |             # type: (Tensor, Tensor) -> Tensor | 
 |             tmp1 = torch.mul(input1, input2) | 
 |             tmp2 = torch.abs(tmp1) | 
 |             if torch.equal(input1, input2): | 
 |                 tmp2 = torch.acos(tmp2) | 
 |             else: | 
 |                 tmp2 = torch.atan(tmp2) | 
 |             result = torch.add(tmp2, input2) | 
 |             return result | 
 |  | 
 |         inp1 = torch.randn(2, 2, requires_grad=True) | 
 |         inp2 = torch.randn(2, 2, requires_grad=True) | 
 |         checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False) | 
 |         checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True) | 
 |  | 
 |     def test_diff_subgraph_clones_constants(self): | 
 |         @torch.jit.script | 
 |         def f(x, y): | 
 |             return x + x + y + x + y + x + y + x + y + x | 
 |  | 
 |         def count_constants(graph): | 
 |             return sum(node.kind() == 'prim::Constant' for node in graph.nodes()) | 
 |  | 
 |         graph = f.graph.copy() | 
 |         self.run_pass('cse', graph) | 
 |         self.run_pass('create_autodiff_subgraphs', graph) | 
 |         nodes = list(graph.nodes()) | 
 |         self.assertEqual(count_constants(graph), 1) | 
 |         self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1) | 
 |  | 
 |     # TODO: adapt this test to check that GraphExecutor treats them differently | 
 |     @unittest.skip("Need to be adjusted to Graph Executor") | 
 |     def test_arg_configurations(self): | 
 |         """Different arg configurations should trigger different traces | 
 |  | 
 |         Currently skipped unconditionally (see the decorator). The body calls | 
 |         ``fn`` once per configuration and checks ``has_trace_for`` before and | 
 |         after each call, then checks that ``fn.hits`` is 0. | 
 |         """ | 
 |         x = Variable(torch.FloatTensor(4, 4).uniform_()) | 
 |         x_double = Variable(x.data.double()) | 
 |         x_grad = Variable(x.data.clone(), requires_grad=True) | 
 |         y = Variable(torch.randn(4)) | 
 |  | 
 |         # Each entry is the argument tuple for one call to ``fn``. | 
 |         configurations = [ | 
 |             (x,), | 
 |             (x_double,), | 
 |             (x_grad,), | 
 |             (y,), | 
 |             ([x, x],), | 
 |             ([x, y],), | 
 |         ] | 
 |         # Add CUDA variants when a GPU is available, and a second-device | 
 |         # variant when more than one GPU is present. | 
 |         if torch.cuda.is_available(): | 
 |             x_cuda = Variable(x.data.cuda()) | 
 |             configurations += [ | 
 |                 (x_cuda,), | 
 |                 ([x, x_cuda],), | 
 |                 ([x_cuda, x],), | 
 |                 ([[x_cuda, x]],), | 
 |             ] | 
 |             if torch.cuda.device_count() > 1: | 
 |                 x_cuda_1 = Variable(x.data.cuda(1)) | 
 |                 configurations += [ | 
 |                     (x_cuda_1,), | 
 |                     ([x_cuda, x_cuda_1],), | 
 |                 ] | 
 |  | 
 |         @torch.jit.compile(nderivs=0) | 
 |         def fn(*args): | 
 |             in_vars, _ = torch._C._jit_flatten(args) | 
 |             return in_vars[0] + 1 | 
 |  | 
 |         for i, config in enumerate(configurations): | 
 |             # No trace before the first call with this configuration ... | 
 |             self.assertFalse(fn.has_trace_for(*config)) | 
 |             fn(*config) | 
 |             # ... one afterwards, and still none for configurations not yet run. | 
 |             self.assertTrue(fn.has_trace_for(*config)) | 
 |             for unk_config in configurations[i + 1:]: | 
 |                 self.assertFalse(fn.has_trace_for(*unk_config)) | 
 |         self.assertEqual(fn.hits, 0) | 
 |  | 
 |     def test_torch_sum(self): | 
 |         def fn(x): | 
 |             return torch.sum(x) | 
 |  | 
 |         def fn1(x, dim: int): | 
 |             return torch.sum(x, dim) | 
 |  | 
 |         x = torch.randn(3, 4) | 
 |         self.checkScript(fn, (x, )) | 
 |         self.checkScript(fn1, (x, 1, )) | 
 |         self.checkScript(fn1, (x, 0, )) | 
 |  | 
 |     def test_cse(self): | 
 |         x = torch.tensor([0.4, 0.3], requires_grad=True) | 
 |         y = torch.tensor([0.7, 0.5], requires_grad=True) | 
 |  | 
 |         def fn(x, y): | 
 |             w = (x + y) * (x + y) * (x + y) | 
 |             t = torch.tanh(w) + torch.tanh(w) | 
 |             z = (x + y) * (x + y) * (x + y) + t | 
 |             return z | 
 |  | 
 |         g, _ = torch.jit._get_trace_graph(fn, (x, y)) | 
 |         self.run_pass('cse', g) | 
 |         do_exactly = True | 
 |         FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \ | 
 |             .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return")  \ | 
 |             .run(str(g)) | 
 |  | 
 |         self.assertExportImport(g, (x, y)) | 
 |  | 
 |     def test_cse_not_introduce_aliasing(self): | 
 |         @torch.jit.script | 
 |         def tensor_alias_outputs(x): | 
 |             return x + x, x + x | 
 |  | 
 |         self.run_pass('cse', tensor_alias_outputs.graph) | 
 |         FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) | 
 |  | 
 |         @torch.jit.script | 
 |         def ints_alias_outputs(x): | 
 |             # type: (int) -> Tuple[int, int] | 
 |             return x + x, x + x | 
 |  | 
 |         # non-aliasing types can be CSEd | 
 |         self.run_pass('cse', ints_alias_outputs.graph) | 
 |         FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) | 
 |  | 
 |     def test_recursive_cse(self): | 
 |         input_str = """ | 
 | graph(%x : Tensor, | 
 |       %y : Tensor, | 
 |       %20 : int): | 
 |   %2 : int = prim::Constant[value=1]() | 
 |   %3 : Tensor = aten::add(%x, %y, %2) | 
 |   %4 : int = aten::add(%2, %20) | 
 |   %5 : bool = aten::Bool(%4) | 
 |   %z : int = prim::If(%5) | 
 |     # CHECK: block | 
 |     block0(): | 
 |       # CHECK-NOT: aten::add | 
 |       %z.1 : int = aten::add(%2, %20) | 
 |       -> (%z.1) | 
 |     block1(): | 
 |       -> (%2) | 
 |   return (%z) | 
 | """ | 
 |         graph = parse_ir(input_str) | 
 |         self.run_pass('cse', graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |     def test_pattern_based_rewrite(self): | 
 |         """Exercise ``_jit_pass_custom_pattern_based_rewrite_graph`` on parsed IR. | 
 |  | 
 |         Each scenario parses an IR string, applies a (pattern, replacement) | 
 |         pair to the resulting graph, and then runs FileCheck with the CHECK | 
 |         directives embedded in that same IR string. | 
 |         """ | 
 |         # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) --> | 
 |         # --> mulmul(mulmul(x,y,z), x, y) | 
 |         input_str = """ | 
 | graph(%x, %y, %z): | 
 |     # CHECK-NOT: aten::mul | 
 |     # CHECK: my::fused_mulmul | 
 |     %t = aten::mul(%x, %y) | 
 |     %p = aten::mul(%t, %z) | 
 |     # CHECK: my::fused_mulmul | 
 |     %u = aten::mul(%p, %x) | 
 |     %o = aten::mul(%u, %y) | 
 |     return (%o)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 | graph(%a, %b, %c): | 
 |   %q = aten::mul(%a, %b) | 
 |   %r = aten::mul(%q, %c) | 
 |   return (%r)""", """ | 
 | graph(%a, %b, %c): | 
 |   %r = my::fused_mulmul(%a, %b, %c) | 
 |   return (%r)""", graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |         # Check that overlapping matches are handled correctly | 
 |         # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x) | 
 |         input_str = """ | 
 | graph(%x, %y, %z): | 
 |     # CHECK-NOT: aten::mul | 
 |     # CHECK: my::fused_mulmul | 
 |     %t = aten::mul(%x, %y) | 
 |     %p = aten::mul(%t, %z) | 
 |     # CHECK-NEXT: aten::mul | 
 |     %u = aten::mul(%p, %x) | 
 |     return (%u)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 | graph(%a, %b, %c): | 
 |   %q = aten::mul(%a, %b) | 
 |   %r = aten::mul(%q, %c) | 
 |   return (%r)""", """ | 
 | graph(%a, %b, %c): | 
 |   %r = my::fused_mulmul(%a, %b, %c) | 
 |   return (%r)""", graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |         # Check add(mul(x,y),z) --> muladd(x,y,z) replacement | 
 |         input_str = """ | 
 | graph(%x, %y, %z): | 
 |     # CHECK-NOT: aten::mul | 
 |     # CHECK-NOT: aten::add | 
 |     %c = prim::Const[value=1]() | 
 |     %t = aten::mul(%x, %y) | 
 |     %p = aten::add(%t, %z, %c) | 
 |     # CHECK: my::muladd | 
 |     # CHECK-NEXT: return | 
 |     return (%p)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 | graph(%a, %b, %c, %d): | 
 |   %q = aten::mul(%a, %b) | 
 |   %r = aten::add(%q, %c, %d) | 
 |   return (%r)""", """ | 
 | graph(%a, %b, %c, %d): | 
 |   %r = my::muladd(%a, %b, %c, %d) | 
 |   return (%r)""", graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |         # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement | 
 |         input_str = """ | 
 | graph(%x, %y, %z): | 
 |     # CHECK-NOT: aten::mul | 
 |     %c = prim::Const[value=1]() | 
 |     # CHECK: aten::add | 
 |     %t = aten::mul(%x, %y) | 
 |     # CHECK-NEXT: aten::sub | 
 |     %p = aten::add(%t, %z, %c) | 
 |     # CHECK-NOT: aten::add | 
 |     # CHECK-NEXT: return | 
 |     return (%p)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 | graph(%a, %b, %c, %d): | 
 |   %q = aten::mul(%a, %b) | 
 |   %r = aten::add(%q, %c, %d) | 
 |   return (%r)""", """ | 
 | graph(%a, %b, %c, %d): | 
 |   %q = aten::add(%a, %b, %d) | 
 |   %r = aten::sub(%q, %c, %d) | 
 |   return (%r)""", graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |         # Check mul(x,y) --> x replacement | 
 |         input_str = """ | 
 | graph(%x, %y, %z): | 
 |     %c = prim::Const[value=1]() | 
 |     # CHECK-NOT: aten::mul | 
 |     %t = aten::mul(%x, %y) | 
 |     # CHECK: aten::add(%x, %z | 
 |     %p = aten::add(%t, %z, %c) | 
 |     # CHECK-NEXT: return | 
 |     return (%p)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 | graph(%Pa, %Pb): | 
 |   %Pq = aten::mul(%Pa, %Pb) | 
 |   return (%Pq)""", """ | 
 | graph(%Ra, %Rb): | 
 |   return (%Ra)""", graph) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_pattern_based_module_rewrite(self): | 
 |         # Check match::module behavior | 
 |         class Test(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Test, self).__init__() | 
 |                 self.conv = torch.nn.Conv2d(1, 20, 5, 1) | 
 |                 self.bn = torch.nn.BatchNorm2d(num_features=20) | 
 |  | 
 |             def forward(self, x): | 
 |                 x = self.conv(x) | 
 |                 x = self.bn(x) | 
 |                 return x | 
 |         m = torch.jit.script(Test()) | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" | 
 |         graph(%self, %x): | 
 |                 %conv = match::module[name="Conv2d"](%self) | 
 |                 %y = prim::CallMethod[name="forward"](%conv, %x) | 
 |                 %bn = match::module[name="BatchNorm2d"](%self) | 
 |                 %z = prim::CallMethod[name="forward"](%bn, %y) | 
 |                 return (%z)""", """ | 
 |         graph(%self, %x): | 
 |           %z = my::matched_conv_bn(%self, %x) | 
 |           return (%z)""", m._c._get_method("forward").graph) | 
 |  | 
 |         FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) | 
 |  | 
 |     def test_pattern_based_rewrite_with_source_range_preserved(self): | 
 |         class TestModule1(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TestModule1, self).__init__() | 
 |  | 
 |             def forward(self, x, y, z, w): | 
 |                 x = x + y | 
 |                 x = x * z | 
 |                 return w - x | 
 |  | 
 |         input_pattern = """ | 
 |         graph(%x, %y, %z, %const): | 
 |             %t = aten::add(%x, %y, %const) | 
 |             %o = aten::mul(%t, %z) | 
 |             return (%o)""" | 
 |         replacement_pattern = """ | 
 |         graph(%x, %y, %z, %const): | 
 |             %o = my::add_mul(%x, %y, %z, %const) | 
 |             return (%o)""" | 
 |         scripted_model = torch.jit.script(TestModule1()) | 
 |         graph = scripted_model.graph | 
 |         value_mappings = [("o", "t")] | 
 |         for node in graph.nodes(): | 
 |             if node.kind() == "aten::add": | 
 |                 source_range_1 = node.sourceRange() | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph( | 
 |             input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings) | 
 |         graph = scripted_model.graph | 
 |         for node in graph.nodes(): | 
 |             if node.kind() == "my::add_mul": | 
 |                 source_range_2 = node.sourceRange() | 
 |         self.assertTrue(source_range_1 == source_range_2) | 
 |  | 
 |         class TestModule2(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TestModule2, self).__init__() | 
 |  | 
 |             def forward(self, x, y, z, w): | 
 |                 x = x + y | 
 |                 x = x + z | 
 |                 x = x * z | 
 |                 x = x * w | 
 |                 return x - 2 | 
 |  | 
 |         # Check source range preservation for two node transforms add -> my_add | 
 |         input_pattern = """ | 
 |         graph(%x, %y, %const): | 
 |             %o = aten::add(%x, %y, %const) | 
 |             return (%o)""" | 
 |         replacement_pattern = """ | 
 |         graph(%x, %y, %const): | 
 |             %o = my::add(%x, %y, %const) | 
 |             return (%o)""" | 
 |         scripted_model = copy.deepcopy(torch.jit.script(TestModule2())) | 
 |         graph_copy = scripted_model.graph.copy() | 
 |         value_mappings = [("o", "o")] | 
 |         source_range_add_1 = None | 
 |         for node in graph_copy.nodes(): | 
 |             if source_range_add_1 is None and node.kind() == "aten::add": | 
 |                 source_range_add_1 = node.sourceRange() | 
 |             if source_range_add_1 is not None and node.kind() == "aten::add": | 
 |                 source_range_add_2 = node.sourceRange() | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph( | 
 |             input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) | 
 |         source_range_my_add_1 = None | 
 |         for node in graph_copy.nodes(): | 
 |             if source_range_my_add_1 is None and node.kind() == "my::add": | 
 |                 source_range_my_add_1 = node.sourceRange() | 
 |             if source_range_my_add_1 is not None and node.kind() == "my::add": | 
 |                 source_range_my_add_2 = node.sourceRange() | 
 |         self.assertTrue(source_range_add_1 == source_range_my_add_1) | 
 |         self.assertTrue(source_range_add_2 == source_range_my_add_2) | 
 |  | 
 |         # Check source range preservation for add-add -> double_add transform | 
 |         # fuse nodes | 
 |         input_pattern = """ | 
 |         graph(%x, %y, %z, %const): | 
 |             %t = aten::add(%x, %y, %const) | 
 |             %o = aten::add(%t, %z, %const) | 
 |             return (%o)""" | 
 |         replacement_pattern = """ | 
 |         graph(%x, %y, %z, %const): | 
 |             %o = my::double_add(%x, %y, %z, %const) | 
 |             return (%o)""" | 
 |         scripted_model = torch.jit.script(TestModule2()) | 
 |         graph_copy = scripted_model.graph.copy() | 
 |         value_mappings = [("o", "t")] | 
 |         source_range_1 = None | 
 |         source_range_2 = None | 
 |         for node in graph_copy.nodes(): | 
 |             if node.kind() == "aten::add": | 
 |                 source_range_1 = node.sourceRange() | 
 |                 break | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph( | 
 |             input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) | 
 |         for node in graph_copy.nodes(): | 
 |             if node.kind() == "my::double_add": | 
 |                 source_range_2 = node.sourceRange() | 
 |         self.assertTrue(source_range_1 == source_range_2) | 
 |  | 
 |         # Check source range preservation for mul -> add + add transform | 
 |         # split node | 
 |         input_pattern = """ | 
 |         graph(%x, %y): | 
 |             %t = aten::mul(%x, %y) | 
 |             return (%t)""" | 
 |         replacement_pattern = """ | 
 |         graph(%x, %y): | 
 |             %t = my::add(%x, %y) | 
 |             %o = my::add(%t, %y) | 
 |             return (%o)""" | 
 |         scripted_model = torch.jit.script(TestModule2()) | 
 |         graph_copy = scripted_model.graph.copy() | 
 |         value_mappings = [("t", "t"), ("o", "t")] | 
 |         source_range_mul_1 = None | 
 |         for node in graph_copy.nodes(): | 
 |             if source_range_mul_1 is None and node.kind() == "aten::mul": | 
 |                 source_range_mul_1 = node.sourceRange() | 
 |             if source_range_mul_1 is not None and node.kind() == "aten::mul": | 
 |                 source_range_mul_2 = node.sourceRange() | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph( | 
 |             input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) | 
 |         source_range_add_1 = None | 
 |         for node in graph_copy.nodes(): | 
 |             if source_range_add_1 is None and node.kind() == "my::add": | 
 |                 source_range_add_1 = node.sourceRange() | 
 |             if source_range_add_1 is not None and node.kind() == "my::add": | 
 |                 source_range_add_2 = node.sourceRange() | 
 |         self.assertTrue(source_range_mul_1 == source_range_add_1) | 
 |         self.assertTrue(source_range_mul_2 == source_range_add_2) | 
 |  | 
 |         # Check lack of source range preservation for mul-mul-> double_mul transform | 
 |         input_pattern = """ | 
 |         graph(%x, %y, %z): | 
 |             %t = aten::mul(%x, %y) | 
 |             %o = aten::mul(%t, %z) | 
 |             return (%o)""" | 
 |         replacement_pattern = """ | 
 |         graph(%x, %y, %z): | 
 |             %o = my::double_mul(%x, %y, %z) | 
 |             return (%o)""" | 
 |         scripted_model = torch.jit.script(TestModule2()) | 
 |         graph_copy = scripted_model.graph.copy() | 
 |         for node in graph_copy.nodes(): | 
 |             if node.kind() == "aten::mul": | 
 |                 source_range_1 = node.sourceRange() | 
 |         torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy) | 
 |         for node in graph_copy.nodes(): | 
 |             if node.kind() == "my::double_mul": | 
 |                 source_range_2 = node.sourceRange() | 
 |         self.assertFalse(source_range_1 == source_range_2) | 
 |  | 
 |     def test_expand_quantlint(self): | 
 |         # Intentionally empty: the body is a no-op, so this test always passes. | 
 |         # NOTE(review): presumably a placeholder for a removed or unwritten check - confirm. | 
 |         pass | 
 |  | 
 |     def test_expand_fold_quant_inputs(self): | 
 |         # Intentionally empty: the body is a no-op, so this test always passes. | 
 |         # NOTE(review): presumably a placeholder for a removed or unwritten check - confirm. | 
 |         pass | 
 |  | 
 |     def test_shape_analysis_broadcast(self): | 
 |         def broadcast(a, b): | 
 |             return a + b | 
 |  | 
 |         x = torch.randn(3, 1, 5, requires_grad=True) | 
 |         y = torch.randn(4, 1, 8, 5, requires_grad=True) | 
 |  | 
 |         graph = torch.jit.script(broadcast).graph | 
 |         torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False) | 
 |         FileCheck().check("Double(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph)) | 
 |  | 
 |     def test_shape_analysis_unsqueeze_in_loop(self): | 
 |         input_str = """graph(%x.1 : Tensor): | 
 |           %4 : bool = prim::Constant[value=1]() | 
 |           %1 : int = prim::Constant[value=2]() | 
 |           %7 : int = prim::Constant[value=0]() | 
 |           # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop | 
 |           %x : Tensor = prim::Loop(%1, %4, %x.1) | 
 |             # CHECK: : FloatTensor(requires_grad=0, device=cpu)): | 
 |             block0(%i : int, %x.6 : Tensor): | 
 |               # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze | 
 |               %x.3 : Tensor = aten::unsqueeze(%x.6, %7) | 
 |               -> (%4, %x.3) | 
 |           return (%x)""" | 
 |         graph = parse_ir(input_str) | 
 |         torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |     def test_script_tensor_type(self): | 
 |         def foo(x, t: torch.dtype): | 
 |             return x.type(t) | 
 |         scr = torch.jit.script(foo) | 
 |         x = torch.rand(3, 4) | 
 |         for t in [torch.int8, torch.float64, torch.float32, | 
 |                   torch.bfloat16, torch.complex64, torch.complex128, torch.bool]: | 
 |             self.assertEqual(scr(x, t), foo(x, t)) | 
 |  | 
 |     def test_shape_analysis_masked_select(self): | 
 |         input_str = """graph(%0 : Float(), | 
 |           %1 : Bool()): | 
 |           # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select | 
 |           %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0 | 
 |           return (%2)""" | 
 |         graph = parse_ir(input_str) | 
 |         x = torch.ones(1, dtype=torch.float32)[0] | 
 |         mask = x.ge(0.5) | 
 |         torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False) | 
 |         FileCheck().run(input_str, graph) | 
 |  | 
 |     # TODO: update verify to work with GraphExecutors | 
 |     @unittest.skip("verify needs to be updated to work with GraphExecutors") | 
 |     def test_verify(self): | 
 |         x = torch.tensor([0.4], requires_grad=True) | 
 |         y = torch.tensor([0.7], requires_grad=True) | 
 |  | 
 |         @torch.jit.compile | 
 |         def f(x, y): | 
 |             z = torch.sigmoid(x * (x + y)) | 
 |             w = torch.abs(x * x * x + y) + Variable(torch.ones(1)) | 
 |             return z, w | 
 |  | 
 |         torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[]) | 
 |  | 
 |     # TODO: adapt to a GraphExecutor test | 
 |     @unittest.skip("Need to instrument GraphExecutors a bit more") | 
 |     def test_flags(self): | 
 |         x, y = torch.randn(2, 2) | 
 |         y = Variable(torch.randn(2, 2)) | 
 |  | 
 |         @torch.jit.compile | 
 |         def fn(x, y): | 
 |             return (x * x + y * y + x * y).sum() | 
 |  | 
 |         grads = {} | 
 |         for rx, ry in product((True, False), repeat=2): | 
 |             x.requires_grad = rx | 
 |             y.requires_grad = ry | 
 |  | 
 |             self.assertFalse(fn.has_trace_for(x, y)) | 
 |             out = fn(x, y) | 
 |  | 
 |             self.assertFalse(fn.has_trace_for(x, y)) | 
 |             for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]: | 
 |                 if not compute: | 
 |                     continue | 
 |                 grad_v, = torch.autograd.grad(out, v, retain_graph=True) | 
 |                 expected_grad = grads.setdefault(name, grad_v) | 
 |                 self.assertEqual(grad_v, expected_grad) | 
 |             self.assertEqual(fn.has_trace_for(x, y), rx or ry) | 
 |  | 
 |     def test_python_ir(self): | 
 |         x = torch.tensor([0.4], requires_grad=True) | 
 |         y = torch.tensor([0.7], requires_grad=True) | 
 |  | 
 |         def doit(x, y): | 
 |             return torch.sigmoid(torch.tanh(x * (x + y))) | 
 |  | 
 |         g, _ = torch.jit._get_trace_graph(doit, (x, y)) | 
 |         self.run_pass('dce', g) | 
 |         self.run_pass('canonicalize', g) | 
 |         g2 = torch._C.Graph() | 
 |         g_to_g2 = {} | 
 |         for node in g.inputs(): | 
 |             g_to_g2[node] = g2.addInput() | 
 |         for node in g.nodes(): | 
 |             n_ = g2.createClone(node, lambda x: g_to_g2[x]) | 
 |             g2.appendNode(n_) | 
 |             for o, no in zip(node.outputs(), n_.outputs()): | 
 |                 g_to_g2[o] = no | 
 |  | 
 |         for node in g.outputs(): | 
 |             g2.registerOutput(g_to_g2[node]) | 
 |  | 
 |         t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2])) | 
 |         self.assertEqual(t_node.attributeNames(), ["a"]) | 
 |         g2.appendNode(t_node) | 
 |         self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a"))) | 
 |         for node in g.nodes(): | 
 |             self.assertTrue(g2.findNode(node.kind()) is not None) | 
 |  | 
 |     def test_permute_inputs_binding(self): | 
 |         @torch.jit.script | 
 |         def foo(i, j, k): | 
 |             pass | 
 |  | 
 |         g = foo.graph | 
 |  | 
 |         idxs = [] | 
 |         for i, inp in enumerate(g.inputs()): | 
 |             inp.setDebugName(f"inp{i}") | 
 |             idxs.append(i) | 
 |  | 
 |         permuted_idxs = list(np.random.permutation(idxs)) | 
 |         g.permuteInputs(permuted_idxs) | 
 |         for i, inp in enumerate(g.inputs()): | 
 |             self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName()) | 
 |  | 
 |     @unittest.skipIf(IS_MACOS, "Failing on MacOS only") | 
 |     def test_python_ir_utils(self): | 
 |         @torch.jit.script | 
 |         def foo(inp): | 
 |             x = inp + 1 | 
 |             y = x / 2 | 
 |             z = y * y | 
 |             return z | 
 |  | 
 |         add_node = foo.graph.findNode("aten::add") | 
 |         div_node = foo.graph.findNode("aten::div") | 
 |  | 
 |         with foo.graph.insert_point_guard(add_node): | 
 |             with foo.graph.insert_point_guard(div_node): | 
 |                 foo.graph.insertConstant("goodbye") | 
 |             foo.graph.insertConstant("hello") | 
 |         with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")): | 
 |             foo.graph.insertConstant("hello") | 
 |         FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph) | 
 |  | 
 |         self.assertTrue(add_node.matches(add_node.schema())) | 
 |         self.assertFalse(add_node.matches(div_node.schema())) | 
 |  | 
 |     def test_python_ir_utils_graph(self): | 
 |         @torch.jit.script | 
 |         def unrolled_mul(x: torch.Tensor, y: int): | 
 |             out = x | 
 |             for _ in range(y - 1): | 
 |                 out = out + x | 
 |             return out | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return x * 4 | 
 |  | 
 |         g = foo.graph | 
 |         muls = g.findAllNodes("aten::mul") | 
 |         scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls) | 
 |         mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls) | 
 |         for mul in mul_constant_int: | 
 |             with g.insert_point_guard(mul): | 
 |                 outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs())) | 
 |                 assert len(outputs) == len(list(mul.outputs())) | 
 |                 for new_out, old_out in zip(outputs, g.outputs()): | 
 |                     old_out.replaceAllUsesWith(new_out) | 
 |                 mul.destroy() | 
 |  | 
 |         FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph) | 
 |         self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4) | 
 |  | 
 |     @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle") | 
 |     @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda") | 
 |     @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1") | 
 |     def test_cpp(self): | 
 |         from cpp.jit import tests_setup | 
 |         tests_setup.setup() | 
 |         torch._C._jit_run_cpp_tests() | 
 |         tests_setup.shutdown() | 
 |  | 
 |     def test_batchnorm(self): | 
 |         x = torch.ones(2, 2, 2, 2) | 
 |         g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x, | 
 |                                                         _force_outplace=True, return_inputs=True) | 
 |         m = self.createFunctionFromGraph(g) | 
 |         self.assertEqual(outputs, m(*inputs)) | 
 |  | 
 |     def test_dropout(self): | 
 |         x = torch.ones(2, 2) | 
 |         with torch.random.fork_rng(devices=[]): | 
 |             g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True) | 
 |         with torch.random.fork_rng(devices=[]): | 
 |             m = self.createFunctionFromGraph(g) | 
 |             self.assertEqual(outputs, m(*inputs)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "test requires CUDA") | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") | 
 |     def test_native_dropout_corner_case(self): | 
 |         with disable_autodiff_subgraph_inlining(): | 
 |             def t(x, p: float, t: bool): | 
 |                 o = torch.dropout(x, p, t) | 
 |                 return o | 
 |  | 
 |             jit_t = torch.jit.script(t) | 
 |             x = torch.randn(5).requires_grad_() | 
 |             FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True)) | 
 |  | 
 |             for train in [True, False]: | 
 |                 for p in [0.0, 1.0]: | 
 |                     for device in ["cuda", "cpu"]: | 
 |                         x = torch.randn(5).to(device=device).requires_grad_() | 
 |                         x_ref = x.detach().requires_grad_() | 
 |                         o = jit_t(x, p, train) | 
 |                         o_ref = t(x_ref, p, train) | 
 |                         o.sum().backward() | 
 |                         o_ref.sum().backward() | 
 |                         assert(o.equal(o_ref)) | 
 |                         assert(x.grad.equal(x_ref.grad)) | 
 |  | 
 |     @slowTest | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph') | 
 |     def test_dropout_module_requires_grad(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             class MyModule(torch.nn.Module): | 
 |                 def __init__(self, M): | 
 |                     super(MyModule, self).__init__() | 
 |                     self.dropout = torch.nn.Dropout(0.5) | 
 |                     self.linear = torch.nn.Linear(M, M) | 
 |  | 
 |                 def forward(self, input): | 
 |                     input = self.dropout(input) | 
 |                     output = self.linear(input) | 
 |                     return output | 
 |  | 
 |             def profile(func, X): | 
 |                 with torch.autograd.profiler.profile() as prof: | 
 |                     func(X) | 
 |                 return [e.name for e in prof.function_events] | 
 |  | 
 |             M = 1000 | 
 |             scripted = torch.jit.script(MyModule(M)) | 
 |             # To reduce confusion about expected behaviors: | 
 |             #   requires_grad controls whether dropout is symbolically differentiated. | 
 |             #   training controls whether bernoulli_ is called inside symbolic differentiation of dropout. | 
 |             # * When requires_grad == training, the expected behaviors are obvious. | 
 |             # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph. | 
 |             #   But it's in a branch that's not called. That's why we have separate checks for autograd | 
 |             #   profiler to make sure it's not run. | 
 |             # * When requires_grad=False and training=True, bernoulli_ must be run since it's the expected | 
 |             #   behavior for the dropout layer in training mode. It's independent of whether graph requires | 
 |             #   gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case. | 
 |             for training in (True, False): | 
 |                 if training: | 
 |                     scripted.train() | 
 |                 else: | 
 |                     scripted.eval() | 
 |                 for requires_grad in (True, False): | 
 |                     X = torch.randn(M, M, requires_grad=requires_grad) | 
 |                     if requires_grad: | 
 |                         FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True)) | 
 |                     self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X)) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph') | 
 |     @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") | 
 |     def test_dropout_func_requires_grad(self): | 
 |         def dropout_training(input): | 
 |             return F.dropout(input, 0.5, training=True) | 
 |  | 
 |         def dropout_eval(input): | 
 |             return F.dropout(input, 0.5, training=False) | 
 |  | 
 |         def profile(func, X): | 
 |             with torch.autograd.profiler.profile() as prof: | 
 |                 func(X) | 
 |             return [e.name for e in prof.function_events] | 
 |  | 
 |         M = 1000 | 
 |         scripted_training = torch.jit.script(dropout_training) | 
 |         scripted_eval = torch.jit.script(dropout_eval) | 
 |         # See comments in test_dropout_module_requires_grad. | 
 |         with disable_autodiff_subgraph_inlining(): | 
 |             for requires_grad in (True, False): | 
 |                 X = torch.randn(M, M, requires_grad=requires_grad) | 
 |                 if requires_grad: | 
 |                     FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True)) | 
 |                 self.assertIn('aten::bernoulli_', profile(scripted_training, X)) | 
 |                 self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA") | 
 |     def test_dropout_cuda(self): | 
 |         # Dropout AD is dispatched to _fused_dropout in CUDA case, | 
 |         # which is not included in TestJitGeneratedFunctional | 
 |         def _zero_rate(t): | 
 |             return torch.true_divide((t == 0).sum(), t.numel()) | 
 |  | 
 |         x = torch.ones(1000, 1000).cuda().requires_grad_() | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def func(x): | 
 |                 return torch.nn.functional.dropout(x) | 
 |  | 
 |             with freeze_rng_state(): | 
 |                 out_ref = torch.nn.functional.dropout(x) | 
 |                 grad_ref = torch.autograd.grad(out_ref.sum(), x) | 
 |  | 
 |             with freeze_rng_state(): | 
 |                 out = func(x) | 
 |                 grad = torch.autograd.grad(out.sum(), x) | 
 |  | 
 |             # TODO(#40882): previously we assert exact matches between eager and JIT result: | 
 |             #  self.assertEqual(out, out_ref) | 
 |             #  self.assertEqual(grad, grad_ref) | 
 |             # This test was disabled during legacy -> profiling executor transition. | 
 |             # Currently JIT fused results doesn't match eager result exactly due to some changes merged in between. | 
 |             # We temporarily only check statstical difference but it should be reverted once the issue is fixed. | 
 |             self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4) | 
 |             self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4) | 
 |  | 
 |     def test_torch_ops_overloaded(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "failed to many any schema"): | 
 |             torch.ops.aten.add("a", 1) | 
 |         self.assertEqual("ab", torch.ops.aten.add("a", "b")) | 
 |         a, b = torch.rand(3, 4), torch.rand(3, 4) | 
 |         self.assertEqual(a + b, torch.ops.aten.add(a, b)) | 
 |         self.assertEqual(a + 1, torch.ops.aten.add(a, 1)) | 
 |  | 
 |     def test_torch_ops_kwonly(self): | 
 |         a, b = torch.rand(3, 4), torch.rand(3, 4) | 
 |         with self.assertRaisesRegex(RuntimeError, "positional argument"): | 
 |             torch.ops.aten.add(a, b, 2) | 
 |         # h/t Chillee for this ambiguous case | 
 |         self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1)) | 
 |  | 
 |     def test_torch_complex(self): | 
 |         def fn(real, img): | 
 |             return torch.complex(real, img) | 
 |  | 
 |         def fn_out(real, img, out): | 
 |             return torch.complex(real, img, out=out) | 
 |         self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), )) | 
 |         self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), )) | 
 |         self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), )) | 
 |         self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), )) | 
 |         self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), )) | 
 |  | 
 |         real = torch.tensor([1, 2], dtype=torch.float32) | 
 |         img = torch.tensor([3, 4], dtype=torch.float32) | 
 |         out = torch.empty([3, 4], dtype=torch.complex64) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.tensor([5, 2], dtype=torch.float64) | 
 |         img = torch.tensor([3, 4], dtype=torch.float64) | 
 |         out = torch.empty([5, 2], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.ones([1, 2]) | 
 |         img = torch.ones([1, 2]) | 
 |         out = torch.empty([1, 2], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.ones([3, 8, 7]) | 
 |         img = torch.ones([3, 8, 7]) | 
 |         out = torch.empty([3, 8, 7], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.empty([3, 2, 6]) | 
 |         img = torch.empty([3, 2, 6]) | 
 |         out = torch.empty([3, 2, 6], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.zeros([1, 3]) | 
 |         img = torch.empty([3, 1]) | 
 |         out = torch.empty([3, 3], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.ones([2, 5]) | 
 |         img = torch.empty([2, 1]) | 
 |         out = torch.empty([2, 5], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |         real = torch.ones([2, 5]) | 
 |         img = torch.zeros([2, 1]) | 
 |         out = torch.empty([2, 5], dtype=torch.complex128) | 
 |         self.checkScript(fn_out, (real, img, out, )) | 
 |  | 
 |     def test_einsum(self): | 
 |         def check(fn, jitted, *args): | 
 |             self.assertGraphContains(jitted.graph, kind='aten::einsum') | 
 |             self.assertEqual(fn(*args), jitted(*args)) | 
 |  | 
 |         def equation_format(x, y): | 
 |             return torch.einsum('i,j->ij', (x, y)) | 
 |  | 
 |         def equation_format_varargs(x, y): | 
 |             return torch.einsum('i,j->ij', x, y) | 
 |  | 
 |         def sublist_format(x, y): | 
 |             return torch.einsum(x, [0], y, [1], [0, 1]) | 
 |  | 
 |         x = make_tensor((5,), dtype=torch.float32, device="cpu") | 
 |         y = make_tensor((10,), dtype=torch.float32, device="cpu") | 
 |  | 
 |         for fn in [equation_format, equation_format_varargs, sublist_format]: | 
 |             check(fn, torch.jit.script(fn), x, y) | 
 |             check(fn, torch.jit.trace(fn, (x, y)), x, y) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_python_ivalue(self): | 
 |         # Checks that a pure Python object can be held as an IValue and that the | 
 |         # conversions between IValue and PyObject are correct. | 
 |         # Case 1: a numpy object. | 
 |         py_array = np.arange(15) | 
 |         ret_py_obj = torch._C._ivalue_debug_python_object(py_array) | 
 |         self.assertEqual(py_array, ret_py_obj) | 
 |  | 
 |         # Case 2: a function object. | 
 |         ret_py_obj = torch._C._ivalue_debug_python_object(F.relu) | 
 |         self.assertEqual(F.relu, ret_py_obj) | 
 |  | 
 |         # Case 3: memory management. | 
 |         # IValue must call incref/decref correctly to avoid dangling | 
 |         # references and potential memory leaks during conversions. | 
 |         # NOTE: the refcount assertions below depend on exactly which names | 
 |         # reference the object, so do not add or remove locals here. | 
 |         def test_func_scope_helper(inp): | 
 |             # Create a scope and do the conversion pyobject -> ivalue -> pyobject. | 
 |             # The round-tripped object is expected to show one more reference | 
 |             # than `inp` had before the conversion. | 
 |             inp_refcount = sys.getrefcount(inp) | 
 |             ivalue_holder = torch._C._ivalue_debug_python_object(inp) | 
 |             self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder)) | 
 |             return ivalue_holder + 1 | 
 |  | 
 |         test_input = 2200 | 
 |         before_count = sys.getrefcount(test_input) | 
 |         test_func_scope_helper(test_input) | 
 |         after_count = sys.getrefcount(test_input) | 
 |  | 
 |         # After the test_func_scope_helper call, the refcount of test_input | 
 |         # must be back at its original value; otherwise there is either a | 
 |         # dangling pointer or a memory leak. | 
 |         self.assertEqual(before_count, after_count) | 
 |  | 
 |     def test_decompose_addmm(self): | 
 |         def does_decompose(): | 
 |             @torch.jit.script | 
 |             def addmm(mat, mat1, mat2): | 
 |                 a = mat.addmm(mat1, mat2) | 
 |                 b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0) | 
 |                 return a + b | 
 |  | 
 |             mat = torch.randn(2, 2) | 
 |             mat1 = torch.randn(2, 4) | 
 |             mat2 = torch.randn(4, 2) | 
 |  | 
 |             out_ref = addmm(mat, mat1, mat2) | 
 |             self.run_pass('decompose_ops', addmm.graph) | 
 |             out_test = addmm(mat, mat1, mat2) | 
 |             self.assertEqual(out_ref, out_test) | 
 |             FileCheck().check_not("addmm").run(str(addmm.graph)) | 
 |  | 
 |         def doesnt_decompose(): | 
 |             @torch.jit.script | 
 |             def addmm(mat, mat1, mat2, alpha, beta): | 
 |                 a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0) | 
 |                 b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta)) | 
 |  | 
 |                 return a + b | 
 |  | 
 |             orig = str(addmm.graph) | 
 |             self.run_pass('decompose_ops', addmm.graph) | 
 |             self.assertTrue(orig == str(addmm.graph)) | 
 |  | 
 |         does_decompose() | 
 |         doesnt_decompose() | 
 |  | 
 |     @suppress_warnings | 
 |     def test_sparse_tensors(self): | 
 |         @torch.jit.ignore | 
 |         def get_sparse(): | 
 |             return torch.sparse.FloatTensor(2, 3) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_is_sparse(input): | 
 |             # type: (Tensor) -> bool | 
 |             return input.is_sparse | 
 |  | 
 |         script_out_is_sparse = test_is_sparse(get_sparse()) | 
 |         script_out_is_dense = test_is_sparse(torch.randn(2, 3)) | 
 |         self.assertEqual(script_out_is_sparse, True) | 
 |         self.assertEqual(script_out_is_dense, False) | 
 |  | 
 |         def test_basic_sparse(input): | 
 |             output = get_sparse() | 
 |             return output, input | 
 |  | 
 |         self.checkScript(test_basic_sparse, (get_sparse(),)) | 
 |         self.checkScript(test_basic_sparse, (torch.tensor([1]),)) | 
 |  | 
 |         def test_sparse_sum(input): | 
 |             return torch.sparse.sum(input) | 
 |  | 
 |         self.checkScript(test_sparse_sum, (get_sparse(),)) | 
 |  | 
 |         def test_sparse_mm(input1, input2): | 
 |             return torch.sparse.mm(input1, input2) | 
 |  | 
 |         self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4))) | 
 |  | 
 |         def test_sparse_addmm(input, input1, input2): | 
 |             return torch.sparse.addmm(input, input1, input2) | 
 |  | 
 |         def test_sparse_addmm_alpha_beta(input, input1, input2): | 
 |             return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5) | 
 |  | 
 |         self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) | 
 |         self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) | 
 |  | 
 |     @suppress_warnings | 
 |     def test_sparse_csr_tensors(self): | 
 |         @torch.jit.ignore | 
 |         def get_sparse_csr(): | 
 |             return torch.randn(3, 3).to_sparse_csr() | 
 |  | 
 |         @torch.jit.script | 
 |         def test_is_sparse_csr(input): | 
 |             # type: (Tensor) -> bool | 
 |             return input.is_sparse_csr | 
 |  | 
 |         script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr()) | 
 |         script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3)) | 
 |  | 
 |         self.assertEqual(script_out_is_sparse_csr, True) | 
 |         self.assertEqual(script_out_is_dense_csr, False) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "requires CUDA") | 
 |     def test_device_not_equal(self): | 
 |  | 
 |         def compare_device(x: torch.device): | 
 |             return x != torch.device("cuda:0") | 
 |  | 
 |         def compare_two_device(x: torch.device, y: torch.device): | 
 |             return x != y | 
 |  | 
 |         self.checkScript(compare_device, (torch.device("cuda:0"),)) | 
 |         self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), )) | 
 |  | 
 |     def test_constant_prop_simple(self): | 
 |         @torch.jit.script | 
 |         def constant_prop(input_int): | 
 |             # type: (int) -> int | 
 |             a = 2 * 3 | 
 |             b = a + 2 | 
 |             return b - input_int | 
 |  | 
 |         out_ref = constant_prop(2) | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         out_test = constant_prop(2) | 
 |         self.assertEqual(out_ref, out_test) | 
 |         graph_str = str(constant_prop.graph) | 
 |         self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str) | 
 |         const = constant_prop.graph.findNode("prim::Constant").output().toIValue() | 
 |         self.assertEqual(const, 8) | 
 |  | 
 |     def test_constant_prop_nested(self): | 
 |         @torch.jit.script | 
 |         def constant_prop(a): | 
 |             b = 2 + 1 | 
 |             if bool(a < 2): | 
 |                 c = b + 2 | 
 |             else: | 
 |                 c = b - 2 | 
 |             return c | 
 |         out_ref = constant_prop(torch.tensor(2)) | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         out_test = constant_prop(torch.tensor(2)) | 
 |         self.assertEqual(out_ref, out_test) | 
 |         if_node = constant_prop.graph.findNode("prim::If") | 
 |         for block in if_node.blocks(): | 
 |             for node in block.nodes(): | 
 |                 self.assertTrue(node.kind() == "prim::Constant") | 
 |  | 
 |     def test_constant_prop_print(self): | 
 |         # After constant propagation, the value fed to prim::Print must be | 
 |         # the folded constant 6 (the result of `2 * 3`). | 
 |         @torch.jit.script | 
 |         def constant_prop(input_tensor): | 
 |             a = 2 * 3 | 
 |             print(a) | 
 |             b = a + 2 | 
 |             return b + input_tensor | 
 |  | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         graph = constant_prop.graph | 
 |         print_node = graph.findNode("prim::Print") | 
 |         self.assertTrue(print_node.input().toIValue() == 6) | 
 |  | 
 |     def test_constant_prop_rand(self): | 
 |         # aten::randn must still be in the graph after constant propagation | 
 |         # (presumably because its result is random and must not be folded). | 
 |         @torch.jit.script | 
 |         def constant_prop(): | 
 |             a = torch.randn([3]) | 
 |             b = a + 2 | 
 |             return b | 
 |  | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         self.assertTrue("aten::randn" in str(constant_prop.graph)) | 
 |  | 
 |     def test_constant_prop_none(self): | 
 |         # Runs constant propagation over a function that branches on | 
 |         # `is None` checks of Optional[int] values returned by a scripted | 
 |         # helper, then checks that the graph contains a prim::Constant. | 
 |         @torch.jit.script | 
 |         def typed_none(): | 
 |             # type: () -> Optional[int] | 
 |             return None | 
 |  | 
 |         @torch.jit.script | 
 |         def constant_prop(): | 
 |             a = typed_none() | 
 |             b = typed_none() | 
 |             if (a is None and b is None): | 
 |                 a = 2 | 
 |             else: | 
 |                 a = 1 | 
 |             return a | 
 |  | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         FileCheck().check("prim::Constant").run(constant_prop.graph) | 
 |  | 
 |     def test_constant_prop_if_inline(self): | 
 |         # The `else` branch divides by zero but is guarded by a condition | 
 |         # that is the literal True; running the pass must not raise. | 
 |         @torch.jit.script | 
 |         def constant_prop(): | 
 |             cond = True | 
 |             a = 1 | 
 |             if cond: | 
 |                 a = 1 * 2 | 
 |             else: | 
 |                 a = 1 // 0 | 
 |             return a | 
 |  | 
 |         # testing that the 1 // 0 error is not thrown | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |  | 
 |     def test_constant_prop_exception(self): | 
 |         # checking y = a[4] does not error in constant propagation | 
 |         def bad_index(x): | 
 |             # type: (bool) | 
 |             y = 0 | 
 |             if x: | 
 |                 a = [1, 2, 3] | 
 |                 y = a[4] | 
 |             return y | 
 |  | 
 |         # Called with x=False, so the out-of-range index is never executed. | 
 |         self.checkScript(bad_index, (False,)) | 
 |  | 
 |     def test_constant_prop_aliasing_type(self): | 
 |         # No pass is run explicitly here; both checks inspect the graphs as | 
 |         # produced by torch.jit.script. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             return len([1]), len(torch.tensor([2])) | 
 |  | 
 |         # The tensor construction and a len call must still be in the graph. | 
 |         FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph) | 
 |  | 
 |         @torch.jit.script | 
 |         def fn(): | 
 |             if 1 == 1: | 
 |                 return 1 | 
 |             else: | 
 |                 return 2 | 
 |  | 
 |         # The branch on the literal comparison must be gone from the graph. | 
 |         FileCheck().check_not("prim::If").run(fn.graph) | 
 |  | 
 |     def test_unchecked_cast(self): | 
 |         def test(cond): | 
 |             # type: (bool) | 
 |             a = torch.tensor([10]) | 
 |             if cond: | 
 |                 b = None | 
 |             else: | 
 |                 b = a | 
 |             if b is not None: | 
 |                 b[0] = 5 | 
 |             return a.int() | 
 |  | 
 |         self.checkScript(test, (True,)) | 
 |         self.checkScript(test, (False,)) | 
 |  | 
 |     def test_constant_prop_if_constant(self): | 
 |         # Checks the shape of the If nodes left after constant propagation: | 
 |         # exactly one top-level prim::If with two outputs, containing one | 
 |         # nested prim::If with a single output and no further nested If. | 
 |         @torch.jit.script | 
 |         def constant_prop(a, b): | 
 |             c0 = 1 | 
 |             c1 = 1 | 
 |             c2 = 1 | 
 |             if bool(a):  # -> c0, c1 | 
 |                 if bool(b):  # -> c0 | 
 |                     if 1 == 1:  # -> c0 | 
 |                         c0 = c0 + 1 | 
 |                         if 1 == 2: | 
 |                             c1 = c1 + 1 | 
 |                             c2 = c2 + 1 | 
 |             else:  # -> c0, c1 | 
 |                 c1 = c1 + 1 | 
 |  | 
 |             if 1 == 1:  # inlined | 
 |                 c0 = c0 + 1  # dynamic | 
 |                 c2 = c2 + 4  # set to 5 | 
 |             return a + c0 + c1 + c2 | 
 |  | 
 |         graph = constant_prop.graph | 
 |         self.run_pass('constant_propagation', graph) | 
 |         # recurse=False: only count If nodes at the top level of the graph. | 
 |         ifs = graph.findAllNodes("prim::If", recurse=False) | 
 |         snd_if_inlined = len(ifs) == 1 | 
 |         self.assertTrue(snd_if_inlined) | 
 |         first_if = ifs[0] | 
 |         self.assertTrue(first_if.outputsSize() == 2) | 
 |         second_if = first_if.findNode("prim::If", recurse=False) | 
 |         self.assertTrue(second_if.outputsSize() == 1) | 
 |         self.assertTrue(second_if.findNode("prim::If") is None) | 
 |  | 
 |     def test_constant_prop_loop_constant(self): | 
 |         # The function is only compiled and inspected, never called. Loops | 
 |         # printing "removed" must disappear from the graph after the pass, | 
 |         # while the four loops printing "stays" keep their prim::Print nodes. | 
 |         @torch.jit.script | 
 |         def constant_prop(cond, iter): | 
 |             # type: (bool, int) -> int | 
 |             b = 0 | 
 |             while True: | 
 |                 print("stays") | 
 |             for _ in range(2): | 
 |                 print("stays") | 
 |             for _ in range(iter): | 
 |                 print("stays") | 
 |             while cond: | 
 |                 print("stays") | 
 |             while False: | 
 |                 print("removed") | 
 |             for _i in range(0): | 
 |                 print("removed") | 
 |             for _i in range(-4): | 
 |                 print("removed") | 
 |             return b | 
 |  | 
 |         self.run_pass('constant_propagation', constant_prop.graph) | 
 |         graph = canonical(constant_prop.graph) | 
 |         self.assertTrue(graph.count("removed") == 0) | 
 |         self.assertTrue(graph.count("stays") == 1)  # constant gets pooled | 
 |         self.assertTrue(graph.count("prim::Print") == 4) | 
 |  | 
 |     def test_constant_prop_remove_output(self): | 
 |         # After the pass the prim::Loop must have exactly two outputs. | 
 |         # Presumably these are `b` and `c`, since `a` is only reassigned | 
 |         # under the always-false `1 == 2` -- confirm against the pass. | 
 |         @torch.jit.script | 
 |         def constant_prop(iter): | 
 |             # type: (int) -> None | 
 |             a = 1 | 
 |             b = 1 | 
 |             c = 1 | 
 |             for i in range(iter): | 
 |                 if 1 == 2: | 
 |                     a = 10 | 
 |                 if i == 5: | 
 |                     b = 2 | 
 |                     c = 3 | 
 |             print(a, b, c) | 
 |  | 
 |         graph = constant_prop.graph | 
 |         self.run_pass('constant_propagation', graph) | 
 |         self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2) | 
 |  | 
 |     # TODO(gmagogsfm): Refactor this test to reduce complexity. | 
 |     def test_constant_insertion(self): | 
 |         funcs_template = dedent(''' | 
 |         def func(): | 
 |             return {constant_constructor} | 
 |         ''') | 
 |  | 
 |         # constants: primitives: int, double, bool, str, lists of primitives, | 
 |         # and tuples | 
 |         def check_constant(constant_constructor): | 
 |             scope = {} | 
 |             funcs_str = funcs_template.format(constant_constructor=constant_constructor) | 
 |             execWrapper(funcs_str, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(funcs_str) | 
 |             f_script = cu.func | 
 |             self.run_pass('constant_propagation', f_script.graph) | 
 |             FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph) | 
 |             self.assertEqual(scope['func'](), f_script()) | 
 |             imported = self.getExportImportCopy(f_script) | 
 |             self.assertEqual(imported(), f_script()) | 
 |  | 
 |         constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)", | 
 |                      "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']", | 
 |                      "[True, None]", "[.5, None, .2]"] | 
 |  | 
 |         for type in ["Tensor", "str", "int", "float", "bool"]: | 
 |             constants.append("torch.jit.annotate(List[ " + type + "], [])") | 
 |  | 
 |         for constant in constants: | 
 |             check_constant(constant) | 
 |  | 
 |         for key_type in ["str", "int", "float"]: | 
 |             for value_type in ["Tensor", "bool", "str", "int", "float"]: | 
 |                 check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})") | 
 |                 check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})") | 
 |  | 
 |         for i in range(len(constants)): | 
 |             for j in range(i + 1, len(constants)): | 
 |                 tup_constant = constants[i] + ", " + constants[j] | 
 |                 check_constant(tup_constant) | 
 |  | 
 |         dict_constants = [] | 
 |         for i in range(len(constants)): | 
 |             # check_constant constructs the second dict with another Tensor | 
 |             # which fails the comparison | 
 |             if not isinstance(eval(constants[i]), (str, int, float)): | 
 |                 continue | 
 |             for j in range(len(constants)): | 
 |                 dict_constant = "{ " + constants[i] + ": " + constants[j] + "}" | 
 |                 check_constant(dict_constant) | 
 |                 dict_constants.append(dict_constant) | 
 |         constants = constants + dict_constants | 
 |  | 
 |         # testing node hashing | 
 |         funcs_template = dedent(''' | 
 |         def func(): | 
 |             print({constant_constructor}) | 
 |         ''') | 
 |         single_elem_tuples = ("(" + x + ",)" for x in constants) | 
 |         input_arg = ", ".join(single_elem_tuples) | 
 |         scope = {} | 
 |         funcs_str = funcs_template.format(constant_constructor=input_arg) | 
 |         execWrapper(funcs_str, globals(), scope) | 
 |         cu = torch.jit.CompilationUnit(funcs_str) | 
 |         f_script = cu.func | 
 |         self.run_pass('constant_propagation', f_script.graph) | 
 |         # prim::None return adds one constant | 
 |         self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant")) | 
 |         self.run_pass('cse', f_script.graph) | 
 |         # node hashing correctly working, no CSE occurs | 
 |         self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant")) | 
 |  | 
 |         funcs_template = dedent(''' | 
 |         def func(): | 
 |             a = {constant_constructor} | 
 |             print(a) | 
 |             b = {constant_constructor} | 
 |             print(b) | 
 |         ''') | 
 |  | 
 |         # generate dicts with built-in types (excluding torch.Tensor) | 
 |         xprod = itertools.product(constants, constants) | 
 |  | 
 |         # test that equal tuples and dicts correctly work with node hashing | 
 |         for tup in ("(" + x + ",)" for x in constants): | 
 |             funcs_str = funcs_template.format(constant_constructor=tup) | 
 |             scope = {} | 
 |             execWrapper(funcs_str, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(funcs_str) | 
 |             f_script = cu.func | 
 |             self.run_pass('constant_propagation_immutable_types', f_script.graph) | 
 |             num_constants = str(f_script.graph).count("prim::Constant") | 
 |             self.run_pass('cse', f_script.graph) | 
 |             FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "requires CUDA") | 
 |     def test_cuda_export_restore(self): | 
 |         # A ScriptModule with a nested submodule is moved to CUDA, copied via | 
 |         # getExportImportCopy, and both copies must give the same output on | 
 |         # the same CUDA input. | 
 |         class Sub(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(3, 4)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.mod = Sub() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 return self.mod(v) | 
 |         m = M() | 
 |         m.cuda() | 
 |         m2 = self.getExportImportCopy(m) | 
 |         m2.cuda() | 
 |         input = torch.rand(3, 4).cuda() | 
 |         self.assertEqual(m(input), m2(input)) | 
 |  | 
 |     @slowTest | 
 |     def test_export_batchnorm(self): | 
 |         for mode in ['eval', 'train']: | 
 |             for clazz in [ | 
 |                     torch.nn.BatchNorm1d(100), | 
 |                     torch.nn.BatchNorm1d(100, affine=False), | 
 |                     torch.nn.BatchNorm2d(100), | 
 |                     torch.nn.BatchNorm2d(100, affine=False)]: | 
 |                 getattr(clazz, mode)() | 
 |                 input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ | 
 |                     torch.randn(20, 100, 35, 45) | 
 |                 traced = torch.jit.trace(clazz, (input,)) | 
 |                 imported = self.getExportImportCopy(traced) | 
 |                 x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ | 
 |                     torch.randn(20, 100, 35, 45) | 
 |                 self.assertEqual(traced(x), imported(x)) | 
 |  | 
 |     def test_export_rnn(self): | 
 |         # For an RNN and a GRU: trace a module that packs, runs and re-pads a | 
 |         # sequence, then check the export/import copy matches the traced one. | 
 |         for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]: | 
 |             class RNNTest(torch.nn.Module): | 
 |                 def __init__(self): | 
 |                     super(RNNTest, self).__init__() | 
 |                     # `clazz` is the module instance of the current loop iteration. | 
 |                     self.rnn = clazz | 
 |  | 
 |                 def forward(self, x, lengths, h0): | 
 |                     packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths) | 
 |                     out, h = self.rnn(packed, h0) | 
 |                     padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out) | 
 |                     return padded_outs | 
 |  | 
 |             test = RNNTest() | 
 |  | 
 |             traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20))) | 
 |             imported = self.getExportImportCopy(traced) | 
 |             # NB: We make sure to pass in a batch with a different max sequence | 
 |             # length to ensure that the argument stashing for pad_packed works | 
 |             # properly. | 
 |             x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20) | 
 |             self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0)) | 
 |  | 
 |     def test_export_lstm(self): | 
 |         # Same scheme as the RNN/GRU export test, but for an LSTM whose | 
 |         # hidden state is passed as an (h0, c0) tuple. | 
 |         class LSTMTest(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(LSTMTest, self).__init__() | 
 |                 self.rnn = nn.LSTM(10, 20, 2) | 
 |  | 
 |             def forward(self, x, lengths, hiddens): | 
 |                 h0, c0 = hiddens | 
 |                 packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths) | 
 |                 out, (h, c) = self.rnn(packed, (h0, c0)) | 
 |                 padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out) | 
 |                 return padded_outs | 
 |  | 
 |         test = LSTMTest() | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.randn(5, 3, 10), | 
 |                                         torch.LongTensor([3, 2, 1]), | 
 |                                         (torch.randn(2, 3, 20), torch.randn(2, 3, 20)))) | 
 |         imported = self.getExportImportCopy(traced) | 
 |         # Evaluate with a longer sequence (7 steps) than the one used for tracing (5). | 
 |         x, lengths, h0, c0 = \ | 
 |             torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20) | 
 |         self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0))) | 
 |  | 
 |     def test_unique_state_dict(self): | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |                 shared_param = torch.nn.Parameter(torch.ones(1)) | 
 |                 self.register_parameter('w1', shared_param) | 
 |                 self.register_parameter('w2', shared_param) | 
 |  | 
 |             def forward(self, input): | 
 |                 return input + self.w1 + self.w2 | 
 |  | 
 |         model = MyModule() | 
 |         unittest.TestCase.assertEqual( | 
 |             self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1) | 
 |         unittest.TestCase.assertEqual( | 
 |             self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1) | 
 |  | 
 |     def test_export_dropout(self): | 
 |         test = torch.nn.Dropout() | 
 |         test.eval() | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False) | 
 |         imported = self.getExportImportCopy(traced) | 
 |         x = torch.randn(3, 4) | 
 |         self.assertEqual(traced(x), imported(x)) | 
 |  | 
 |     def test_pretty_printer(self): | 
 |         # Scripts a set of small functions and compares the `.code` of each | 
 |         # against a stored expected output via assertExpected. The function | 
 |         # bodies are the test input, so they must not be reformatted. | 
 |         @torch.jit.script | 
 |         def if_test(a, b): | 
 |             # FIXME: use 0 instead of a. | 
 |             # c = 0 | 
 |             c = a | 
 |             if bool(a < b): | 
 |                 c = b | 
 |             else: | 
 |                 c = a | 
 |             return c | 
 |  | 
 |         @torch.jit.script | 
 |         def if_one(a, b): | 
 |             c = b | 
 |             if bool(a < b): | 
 |                 c = a | 
 |             return c | 
 |  | 
 |         @torch.jit.script | 
 |         def while_test(a, i): | 
 |             while bool(i < 3): | 
 |                 a *= a | 
 |                 i += 1 | 
 |             return a | 
 |  | 
 |         @torch.jit.script | 
 |         def while_if_test(a, b): | 
 |             c = 0 | 
 |             while bool(a < 10): | 
 |                 a = a + 1 | 
 |                 b = b + 1 | 
 |                 if bool(a > b): | 
 |                     c = 2 | 
 |                 else: | 
 |                     c = 3 | 
 |             return a + 1 + c | 
 |  | 
 |         @torch.jit.script | 
 |         def loop_use_test(y): | 
 |             x = y + 1 | 
 |             z = x + 5 | 
 |             while bool(y < 8): | 
 |                 y += 1 | 
 |                 z = x | 
 |             return x, z | 
 |  | 
 |         @torch.jit.ignore | 
 |         def python_fn(x): | 
 |             return x + 10 | 
 |  | 
 |         @torch.jit.script | 
 |         def python_op_name_test(y): | 
 |             return python_fn(y) | 
 |  | 
 |         @torch.jit.script | 
 |         def empty_int_list_test(y): | 
 |             x = torch.jit.annotate(List[int], []) | 
 |             return x[0] | 
 |  | 
 |         @torch.jit.script | 
 |         def empty_float_list_test(y): | 
 |             return [1.0, 2.0, 3.0] | 
 |  | 
 |         @torch.jit.script | 
 |         def print_weird_test(y): | 
 |             print("hi\016") | 
 |  | 
 |         # The second argument is the sub-name passed to assertExpected for each case. | 
 |         self.assertExpected(if_test.code, "if_test") | 
 |         self.assertExpected(if_one.code, "if_one") | 
 |         self.assertExpected(while_test.code, "while_test") | 
 |         self.assertExpected(while_if_test.code, "while_if_test") | 
 |         self.assertExpected(loop_use_test.code, "loop_use_test") | 
 |         self.assertExpected(python_op_name_test.code, "python_op_name_test") | 
 |         self.assertExpected(empty_int_list_test.code, "empty_int_list_test") | 
 |         self.assertExpected(empty_float_list_test.code, "empty_float_list_test") | 
 |         self.assertExpected(print_weird_test.code, "print_weird_test") | 
 |  | 
 |     def test_cu_escaped_number(self): | 
 |         # Compiles source containing the octal escape "\016" through | 
 |         # CompilationUnit and compares the printed code with the stored | 
 |         # expected output. | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(a): | 
 |                 print("hi\016") | 
 |         ''') | 
 |         self.assertExpected(cu.foo.code) | 
 |  | 
 |     def test_import_method(self): | 
 |         # Saves a ScriptModule to an in-memory buffer, loads it back and | 
 |         # compares the code of the loaded `forward` with the stored expected | 
 |         # output. Everything runs with the JIT emit hooks disabled. | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             class Foo(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(Foo, self).__init__() | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x, y): | 
 |                     return 2 * x + y | 
 |  | 
 |             foo = Foo() | 
 |             buffer = io.BytesIO() | 
 |             torch.jit.save(foo, buffer) | 
 |  | 
 |             # Rewind so that load reads from the start of the buffer. | 
 |             buffer.seek(0) | 
 |             foo_loaded = torch.jit.load(buffer) | 
 |             self.assertExpected(foo_loaded.forward.code) | 
 |  | 
 |     @unittest.skip("temporarily disable the test for fwd compatibility") | 
 |     def test_non_ascii_string(self): | 
 |         # Round-trips a ScriptModule holding non-ASCII strings (an attribute | 
 |         # and a literal in `forward`) through save/load and compares the | 
 |         # loaded code with the stored expected output. | 
 |         class Foo(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |                 self.a = "Over \u0e55\u0e57 57" | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x, y): | 
 |                 return self.a + "hi\xA1" | 
 |  | 
 |         foo = Foo() | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(foo, buffer) | 
 |  | 
 |         # Rewind so that load reads from the start of the buffer. | 
 |         buffer.seek(0) | 
 |         foo_loaded = torch.jit.load(buffer) | 
 |         self.assertExpected(foo_loaded.forward.code) | 
 |  | 
 |     def test_function_default_values(self): | 
 |         # Covers default argument values on scripted functions: tensor | 
 |         # defaults, a None default for an Optional argument, float/int | 
 |         # defaults with type comments, and defaults rejected at compile time. | 
 |         outer_var = torch.tensor(20) | 
 |         outer_var2 = torch.tensor(30) | 
 |         a = torch.tensor(0.5) | 
 |         b = torch.tensor(10) | 
 |  | 
 |         # Defaults may be arbitrary Python expressions over enclosing variables. | 
 |         @torch.jit.script | 
 |         def simple_fn(x, a=a, b=b, c=outer_var + outer_var2): | 
 |             return x + a + b + c | 
 |  | 
 |         self.assertEqual( | 
 |             simple_fn(torch.ones(1)), | 
 |             torch.ones(1) + 0.5 + 10 + (20 + 30)) | 
 |         self.assertEqual( | 
 |             simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)), | 
 |             torch.ones(1) + 1 + 3 + 4) | 
 |  | 
 |         outer_c = torch.tensor(9) | 
 |         outer_flag = torch.tensor(False) | 
 |  | 
 |         @torch.jit.script | 
 |         def bool_fn(x, a=outer_c, flag=outer_flag): | 
 |             if bool(flag): | 
 |                 result = x | 
 |             else: | 
 |                 result = x + a | 
 |             return result | 
 |  | 
 |         self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9) | 
 |         self.assertEqual( | 
 |             bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)), | 
 |             torch.ones(1)) | 
 |  | 
 |         @torch.jit.script | 
 |         def none_fn(x=None): | 
 |             # type: (Optional[int]) -> Optional[int] | 
 |             return x | 
 |  | 
 |         self.assertEqual(none_fn(), None) | 
 |         self.assertEqual(none_fn(1), 1) | 
 |  | 
 |         @torch.jit.script | 
 |         def hints(x, a=0.5, b=10): | 
 |             # type: (Tensor, float, int) -> Tensor | 
 |             return x + a + b | 
 |  | 
 |         self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10) | 
 |  | 
 |         # Defaults swapped relative to the declared (float, int) types. | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected a default value"): | 
 |  | 
 |             @torch.jit.script | 
 |             def hints_bad_types(x, a=10, b=0.5):  # noqa: T484 | 
 |                 # type: (Tensor, float, int) -> Tensor | 
 |                 return x + a + b | 
 |         # A None default on an argument that is not declared Optional. | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected a default value"): | 
 |             @torch.jit.script | 
 |             def bad_no_optional(x=None): | 
 |                 # type: (Dict[str, int]) -> Dict[str, int] | 
 |                 return x | 
 |  | 
 |  | 
 |     def test_module_default_values(self): | 
 |         # A script_method may take its default argument value from a tensor | 
 |         # defined in the enclosing scope. | 
 |         four = torch.tensor(4) | 
 |  | 
 |         class Test(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Test, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input, other=four): | 
 |                 return input + other | 
 |  | 
 |         t = Test() | 
 |         self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4) | 
 |  | 
 |     def test_mutable_default_values(self): | 
 |         # Mutable defaults must be rejected with a "Mutable default | 
 |         # parameters" error, both for a scripted function (a list nested in | 
 |         # a tuple) and for a scripted nn.Module (a plain list). | 
 |         with self.assertRaisesRegex(Exception, "Mutable default parameters"): | 
 |             @torch.jit.script | 
 |             def foo(x=(1, [])): | 
 |                 # type: (Tuple[int, List[Tensor]]) | 
 |                 return x | 
 |  | 
 |         class Test(torch.nn.Module): | 
 |             def forward(self, input=[]):  # noqa: B006 | 
 |                 return input | 
 |  | 
 |         with self.assertRaisesRegex(Exception, "Mutable default parameters"): | 
 |             torch.jit.script(Test()) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_warnings(self): | 
 |         # warnings.warn inside scripted code must surface as a Python | 
 |         # warning, compared here against the warning from the eager function. | 
 |         import warnings | 
 |  | 
 |         def fn(x): | 
 |             if bool(x < 2): | 
 |                 warnings.warn("x is less than 2") | 
 |             return x | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 if bool(x < 2): | 
 |                     warnings.warn("x is less than 2") | 
 |                 return x | 
 |  | 
 |  | 
 |         scripted_mod = torch.jit.script(M()) | 
 |         scripted_fn = torch.jit.script(fn) | 
 |  | 
 |         # Record the warnings of the eager function, the scripted function | 
 |         # and the scripted module separately. | 
 |         with warnings.catch_warnings(record=True) as warns: | 
 |             fn(torch.ones(1)) | 
 |  | 
 |         with warnings.catch_warnings(record=True) as script_warns: | 
 |             scripted_fn(torch.ones(1)) | 
 |  | 
 |         with warnings.catch_warnings(record=True) as script_mod_warns: | 
 |             scripted_mod(torch.ones(1)) | 
 |  | 
 |         self.assertEqual(str(warns[0]), str(script_warns[0])) | 
 |         self.assertEqual(len(script_mod_warns), 1) | 
 |         self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message)) | 
 |  | 
 |     def test_no_erroneous_warnings(self): | 
 |         import warnings | 
 |  | 
 |         def fn(x): | 
 |             if bool(x > 0): | 
 |                 warnings.warn('This should NOT be printed') | 
 |                 x += 1 | 
 |             return x | 
 |  | 
 |         with warnings.catch_warnings(record=True) as warns: | 
 |             fn_script = torch.jit.script(fn) | 
 |             fn_script(torch.tensor(0)) | 
 |         warns = [str(w.message) for w in warns] | 
 |         self.assertEqual(len(warns), 0) | 
 |  | 
 |     @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339") | 
 |     def test_torch_load_error(self): | 
 |         # Currently always skipped (see the decorator). When enabled: calling | 
 |         # torch.load on a file written by ScriptModule.save must raise a | 
 |         # RuntimeError whose message matches "is a zip". | 
 |         class J(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(J, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 return input + 100 | 
 |  | 
 |         j = J() | 
 |         with TemporaryFileName() as fname: | 
 |             j.save(fname) | 
 |             with self.assertRaisesRegex(RuntimeError, "is a zip"): | 
 |                 torch.load(fname) | 
 |  | 
 |     def test_torch_load_zipfile_check(self): | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return x + 10 | 
 |  | 
 |         with TemporaryFileName() as fname: | 
 |             fn.save(fname) | 
 |             with io.open(fname, 'rb') as f: | 
 |                 self.assertTrue(torch.serialization._is_zipfile(f)) | 
 |  | 
 |     def test_python_bindings(self): | 
 |         lstm_cell = torch.jit.script(LSTMCellS) | 
 |  | 
 |         def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): | 
 |             for i in range(x.size(0)): | 
 |                 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) | 
 |             return hx | 
 |  | 
 |         slstm = torch.jit.script(lstm) | 
 |  | 
 |         inputs = get_lstm_inputs('cpu', training=True, seq_length=10) | 
 |         slstm(*inputs).sum().backward() | 
 |         global fw_graph | 
 |         fw_graph = slstm.graph_for(*inputs) | 
 |         nodes = list(fw_graph.nodes()) | 
 |         tested_blocks = False | 
 |         for node in nodes: | 
 |             for output in node.outputs(): | 
 |                 self.assertTrue(hasattr(output, 'type')) | 
 |                 self.assertTrue(output.type() is not None) | 
 |             for input in node.inputs(): | 
 |                 self.assertTrue(hasattr(input, 'type')) | 
 |                 self.assertTrue(input.type() is not None) | 
 |             for block in node.blocks(): | 
 |                 tested_blocks = True | 
 |                 self.assertTrue(hasattr(block, 'inputs')) | 
 |                 self.assertTrue(hasattr(block, 'outputs')) | 
 |                 for output in block.outputs(): | 
 |                     self.assertTrue(hasattr(output, 'type')) | 
 |                     self.assertTrue(output.type() is not None) | 
 |                 for input in block.inputs(): | 
 |                     self.assertTrue(hasattr(input, 'type')) | 
 |                     self.assertTrue(input.type() is not None) | 
 |                 self.assertTrue(hasattr(block, 'returnNode')) | 
 |                 self.assertTrue(type(block.returnNode()) == torch._C.Node) | 
 |                 self.assertTrue(hasattr(block, 'paramNode')) | 
 |                 self.assertTrue(type(block.paramNode()) == torch._C.Node) | 
 |         self.assertTrue(tested_blocks) | 
 |  | 
 |     def test_export_opnames(self): | 
 |         # torch.jit.export_opnames on a module with a nested ScriptModule | 
 |         # must report at least the add and mul operators used by the | 
 |         # submodule's helper methods. | 
 |         class Foo(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |  | 
 |             def one(self, x, y): | 
 |                 # type: (Tensor, Tensor) -> Tensor | 
 |                 return x + y | 
 |  | 
 |             def two(self, x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return 2 * x | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return self.one(self.two(x), x) | 
 |  | 
 |         class Bar(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Bar, self).__init__() | 
 |                 self.sub = Foo() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return self.sub.forward(x) | 
 |  | 
 |         bar = Bar() | 
 |         ops = torch.jit.export_opnames(bar) | 
 |         expected = ['aten::add.Tensor', 'aten::mul.Scalar'] | 
 |         # Subset check only: other operator names may be reported as well. | 
 |         self.assertTrue(set(expected).issubset(set(ops))) | 
 |  | 
 |     def test_pytorch_jit_env_off(self): | 
 |         import subprocess | 
 |         env = os.environ.copy() | 
 |         env['PYTORCH_JIT'] = '0' | 
 |         try: | 
 |             subprocess.check_output([sys.executable, '-c', 'import torch'], env=env) | 
 |         except subprocess.CalledProcessError as e: | 
 |             raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e | 
 |  | 
 |     def test_print_op_module(self): | 
 |         # Issue #19351: python2 and python3 go through different paths. | 
 |         # python2 returns '<module 'torch.ops' (built-in)>' | 
 |         # python3 uses __file__ and return | 
 |         # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>' | 
 |         s = str(torch.ops) | 
 |         self.assertRegex(s, r'ops') | 
 |  | 
 |     def test_print_classes_module(self): | 
 |         s = str(torch.classes) | 
 |         self.assertRegex(s, r'classes') | 
 |  | 
 |     def test_print_torch_ops_modules(self): | 
 |         s = str(torch._ops.ops.quantized) | 
 |         self.assertRegex(s, r'torch.ops') | 
 |         s = str(torch._ops.ops.atan) | 
 |         self.assertRegex(s, r'torch.ops') | 
 |  | 
 |     @unittest.skipIf(IS_WINDOWS, 'TODO: fix occasional windows failure') | 
 |     def test_profiler(self): | 
 |         prev_opt = torch._C._get_graph_executor_optimize() | 
 |         torch._C._set_graph_executor_optimize(False) | 
 |  | 
 |         def other_fn(x): | 
 |             return x * 2 | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         traced_other_fn = torch.jit.trace(other_fn, x) | 
 |  | 
 |         def fn(x): | 
 |             y = traced_other_fn(x) | 
 |             fut = torch.jit._fork(traced_other_fn, x) | 
 |             y = torch.jit._wait(fut) | 
 |             return y | 
 |  | 
 |         traced_fn = torch.jit.trace(fn, x) | 
 |         with torch.autograd.profiler.profile() as prof: | 
 |             traced_fn(x) | 
 |  | 
 |         # expecting to see other_fn TS function call | 
 |         # with cpu time >= mul cpu time and | 
 |         # a forked other_fn | 
 |  | 
 |         mul_events = defaultdict(int) | 
 |         other_fn_events = defaultdict(int) | 
 |         for e in prof.function_events: | 
 |             if e.name == "aten::mul": | 
 |                 self.assertTrue(e.thread not in mul_events) | 
 |                 mul_events[e.thread] = e.time_range.elapsed_us() | 
 |             elif e.name == "other_fn": | 
 |                 self.assertTrue(e.thread not in other_fn_events) | 
 |                 other_fn_events[e.thread] = e.time_range.elapsed_us() | 
 |  | 
 |         self.assertTrue(len(mul_events) == 2) | 
 |         self.assertTrue(len(other_fn_events) == 2) | 
 |  | 
 |         for thread, mul_time in mul_events.items(): | 
 |             self.assertTrue(thread in other_fn_events) | 
 |             self.assertTrue(other_fn_events[thread] >= mul_time) | 
 |  | 
 |         torch._C._set_graph_executor_optimize(prev_opt) | 
 |  | 
 |     def test_hide_source_ranges_context_manager(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return torch.add(x, x) | 
 |  | 
 |         graph = foo.graph | 
 |         source_range_regex = "# .*\\.py" | 
 |         self.assertRegex(graph.__repr__(), source_range_regex) | 
 |         with torch.jit._hide_source_ranges(): | 
 |             self.assertNotRegex(graph.__repr__(), source_range_regex) | 
 |             self.assertRegex(graph.str(print_source_ranges=True), source_range_regex) | 
 |         self.assertRegex(graph.__repr__(), source_range_regex) | 
 |  | 
 |  | 
 | class TestFrontend(JitTestCase): | 
 |  | 
 |     def test_instancing_error(self): | 
 |         @torch.jit.ignore | 
 |         class MyScriptClass(object): | 
 |             def unscriptable(self): | 
 |                 return "a" + 200 | 
 |  | 
 |  | 
 |         class TestModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TestModule, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return MyScriptClass() | 
 |  | 
 |         with self.assertRaises(torch.jit.frontend.FrontendError) as cm: | 
 |             torch.jit.script(TestModule()) | 
 |  | 
 |         checker = FileCheck() | 
 |         checker.check("Cannot instantiate class") | 
 |         checker.check("def forward") | 
 |         checker.run(str(cm.exception)) | 
 |  | 
 |  | 
 | class TestScript(JitTestCase): | 
 |  | 
 |     # Tests that calling torch.jit.script repeated on function is allowed. | 
 |     def test_repeated_script_on_function(self): | 
 |         @torch.jit.script | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return x | 
 |  | 
 |         torch.jit.script(torch.jit.script(fn)) | 
 |  | 
 |     def test_pretty_print_function(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return torch.nn.functional.interpolate(x) | 
 |  | 
 |         FileCheck().check("interpolate").run(foo.code) | 
 |  | 
 |     def test_inlined_graph(self): | 
 |         """ | 
 |         Check that the `inlined_graph` property correctly returns an inlined | 
 |         graph, both through function calls and method calls. | 
 |         """ | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return torch.add(x, x) | 
 |  | 
 |         class MyNestedMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyNestedMod, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.sub(x, x) | 
 |  | 
 |  | 
 |         class MyMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyMod, self).__init__() | 
 |                 self.nested = MyNestedMod() | 
 |  | 
 |             def forward(self, x): | 
 |                 x = self.nested(x)  # sub | 
 |                 x = foo(x)  # add | 
 |                 return torch.mul(x, x) | 
 |  | 
 |         m = torch.jit.script(MyMod()) | 
 |         FileCheck().check("aten::sub") \ | 
 |             .check("aten::add") \ | 
 |             .check("aten::mul") \ | 
 |             .run(m.inlined_graph) | 
 |  | 
 |     def test_static_method_on_module(self): | 
 |         """ | 
 |         Check that the `@staticmethod` annotation on a function on a module works. | 
 |         """ | 
 |         class MyCell(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyCell, self).__init__() | 
 |  | 
 |             @staticmethod | 
 |             def do_it(x, h): | 
 |                 new_h = torch.tanh(x + h) | 
 |                 return new_h, new_h | 
 |  | 
 |             def forward(self, x, h): | 
 |                 return self.do_it(x, h) | 
 |  | 
 |         my_cell = torch.jit.script(MyCell()) | 
 |         x = torch.rand(3, 4) | 
 |         h = torch.rand(3, 4) | 
 |         jitted_cell = my_cell(x, h) | 
 |         non_jitted_cell = MyCell().do_it(x, h) | 
 |  | 
 |         self.assertEqual(jitted_cell, non_jitted_cell) | 
 |  | 
 |     def test_code_with_constants(self): | 
 |         """ | 
 |         Check that the `code_with_constants` property correctly returns graph CONSTANTS in the | 
 |         CONSTANTS.cN format used in the output of the `code` property. | 
 |         """ | 
 |         @torch.jit.script | 
 |         def foo(x=torch.ones(1)): | 
 |             return x | 
 |  | 
 |         class Moddy(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Moddy, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return foo() | 
 |  | 
 |         m = torch.jit.script(Moddy()) | 
 |         src, CONSTANTS = m.code_with_constants | 
 |  | 
 |         self.assertEqual(CONSTANTS.c0, torch.ones(1)) | 
 |         self.assertEqual(src, m.code) | 
 |  | 
 |     def test_code_with_constants_restore(self): | 
 |         """ | 
 |         Check that the `code_with_constants` property correctly works on restoration after save() + load() | 
 |         """ | 
 |         @torch.jit.script | 
 |         def foo(x=torch.ones(1)): | 
 |             return x | 
 |  | 
 |         class Moddy(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Moddy, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return foo() | 
 |  | 
 |         m = torch.jit.script(Moddy()) | 
 |         src, CONSTANTS = m.code_with_constants | 
 |         eic = self.getExportImportCopy(m) | 
 |  | 
 |         src_eic, CONSTANTS_eic = eic.code_with_constants | 
 |  | 
 |         self.assertEqual(src, src_eic) | 
 |         self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0) | 
 |  | 
 |  | 
 |     def test_oneline_func(self): | 
 |         def fn(x): return x  # noqa: E704 | 
 |  | 
 |         self.checkScript(fn, (torch.ones(2, 2), )) | 
 |  | 
 |     def test_request_bailout(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |  | 
 |             def fct_loop(x): | 
 |                 for i in range(3): | 
 |                     x = torch.cat((x, x), 0) | 
 |                 return x | 
 |  | 
 |             x = torch.ones(2, 3, 4, dtype=torch.float32) | 
 |             expected = fct_loop(x) | 
 |             jitted = torch.jit.script(fct_loop) | 
 |             # profile | 
 |             jitted(x) | 
 |             # optimize | 
 |             jitted(x) | 
 |             dstate = jitted.get_debug_state() | 
 |             eplan = get_execution_plan(dstate) | 
 |             num_bailouts = eplan.code.num_bailouts() | 
 |  | 
 |             for i in range(0, num_bailouts): | 
 |                 eplan.code.request_bailout(i) | 
 |                 self.assertEqual(jitted(x), expected) | 
 |  | 
 |     @unittest.skip("bailouts are being deprecated") | 
 |     def test_dominated_bailout(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             # functional dominated guard | 
 |             @torch.jit.script | 
 |             def foo(x): | 
 |                 dim = x.dim() | 
 |                 if dim == 0: | 
 |                     y = int(x) | 
 |                 else: | 
 |                     y = x.size()[dim - 1] | 
 |                 return y | 
 |  | 
 |             x = torch.zeros(2) | 
 |             self.assertEqual(foo(x), 2) | 
 |             self.assertEqual(foo(x), 2) | 
 |             g = torch.jit.last_executed_optimized_graph() | 
 |             g_s = str(g) | 
 |             g_s = g_s[0:g_s.find("return")] | 
 |             FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s) | 
 |  | 
 |             # dominated guard of non-functional value | 
 |             @torch.jit.script | 
 |             def foo(x): | 
 |                 dim = x.dim() | 
 |                 x.add_(3) | 
 |                 if dim == 0: | 
 |                     return 0 | 
 |                 else: | 
 |                     return x.size()[dim - 1] | 
 |  | 
 |             x = torch.zeros(2) | 
 |             self.assertEqual(foo(x), 2) | 
 |             self.assertEqual(foo(x), 2) | 
 |             g = torch.jit.last_executed_optimized_graph() | 
 |             FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g) | 
 |  | 
 |             with torch.enable_grad(): | 
 |                 @torch.jit.ignore | 
 |                 def disable_grad(): | 
 |                     torch.set_grad_enabled(False) | 
 |  | 
 |                 @torch.jit.ignore | 
 |                 def enable_grad(): | 
 |                     torch.set_grad_enabled(True) | 
 |  | 
 |                 @torch.jit.script | 
 |                 def foo(x): | 
 |                     x = x + 1 | 
 |                     dim = x.dim() | 
 |                     disable_grad() | 
 |                     if dim == 0: | 
 |                         y = int(x) | 
 |                     else: | 
 |                         y = x.size()[dim - 1] | 
 |                     enable_grad() | 
 |                     return y | 
 |  | 
 |                 x = torch.zeros(2, requires_grad=True) | 
 |                 self.assertEqual(foo(x), 2) | 
 |                 self.assertEqual(foo(x), 2) | 
 |                 g = torch.jit.last_executed_optimized_graph() | 
 |                 # there should still be a Bailout after disable_grad call | 
 |                 FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") | 
 |     def test_profiling_merge(self): | 
 |         @torch.jit.script | 
 |         def test_not_const(x): | 
 |             if x.size(0) == 1: | 
 |                 return 1 | 
 |             else: | 
 |                 return 2 | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             with num_profiled_runs(2): | 
 |                 test_not_const(torch.rand([1, 2])) | 
 |                 test_not_const(torch.rand([2, 2])) | 
 |  | 
 |                 graph_str = torch.jit.last_executed_optimized_graph() | 
 |                 FileCheck().check("profiled_type=Double(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) | 
 |                 FileCheck().check_not("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) | 
 |  | 
 |  | 
 |     def test_nested_bailouts(self): | 
 |         @torch.jit.script | 
 |         def fct_loop(x): | 
 |             for i in range(3): | 
 |                 x = torch.cat((x, x), 0) | 
 |             return x | 
 |  | 
 |         x = torch.ones(2, 3, 4, dtype=torch.float32) | 
 |         out = fct_loop(x) | 
 |         jit_trace = torch.jit.trace(fct_loop, x) | 
 |         out_trace = jit_trace(x) | 
 |  | 
 |     def test_no_self_arg_ignore_function(self): | 
 |         class MyModule(nn.Module): | 
 |             @torch.jit.ignore  # noqa: B902 | 
 |             def call_np():  # noqa: B902 | 
 |                 # type: () -> int | 
 |                 return np.random.choice(2, p=[.95, .05]) | 
 |  | 
 |             def forward(self): | 
 |                 return self.call_np() | 
 |  | 
 |         with self.assertRaisesRegex(Exception, "does not have a self argument"): | 
 |             torch.jit.script(MyModule()) | 
 |  | 
 |     def test_loop_liveness(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def f(i): | 
 |                 # type: (int) -> Tensor | 
 |                 l = [] | 
 |                 for n in [2, 1]: | 
 |                     l.append(torch.zeros(n, i)) | 
 |  | 
 |                 return l[0] | 
 |  | 
 |             f(2) | 
 |             f(1) | 
 |  | 
 |     def test_bailout_loop_carried_deps_name_clash(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             NUM_ITERATIONS = 10 | 
 |  | 
 |             @torch.jit.script | 
 |             def fct_loop(z, size): | 
 |                 # type: (int, int) -> Tuple[Tensor, List[int]] | 
 |                 counters = torch.jit.annotate(List[int], []) | 
 |                 j = 0 | 
 |                 y = torch.ones(2) | 
 |                 for i in range(size): | 
 |                     counters.append(i + j) | 
 |                     y = torch.cat((y, torch.ones(z)), 0) | 
 |                     j = j + 1 | 
 |                 return y, counters | 
 |  | 
 |             inputs = [1, 2, 3, 4] | 
 |             expected = [x * 2 for x in range(NUM_ITERATIONS)] | 
 |             for inp in inputs: | 
 |                 results = fct_loop(inp, NUM_ITERATIONS) | 
 |                 self.assertEqual(results[1], expected) | 
 |  | 
 |     def test_bailout_loop_counter_transition(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             NUM_ITERATIONS = 10 | 
 |  | 
 |             @torch.jit.script | 
 |             def fct_loop(z, size): | 
 |                 # type: (int, int) -> Tuple[Tensor, List[int]] | 
 |                 counters = torch.jit.annotate(List[int], []) | 
 |                 y = torch.ones(2) | 
 |                 for i in range(size): | 
 |                     counters.append(i) | 
 |                     y = torch.cat((y, torch.ones(z)), 0) | 
 |                 return y, counters | 
 |  | 
 |             inputs = [1, 2, 3, 4] | 
 |             expected = list(range(NUM_ITERATIONS)) | 
 |             for inp in inputs: | 
 |                 results = fct_loop(inp, NUM_ITERATIONS) | 
 |                 self.assertEqual(results[1], expected) | 
 |  | 
 |     def test_ignored_method_binding(self): | 
 |         class Bar(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Bar, self).__init__() | 
 |                 self.x : int = 0 | 
 |  | 
 |             @torch.jit.export | 
 |             def setx(self, x : int): | 
 |                 self.x = x | 
 |  | 
 |             @torch.jit.export | 
 |             def getx(self): | 
 |                 return self.x | 
 |  | 
 |             @torch.jit.ignore | 
 |             def ignored_getx(self): | 
 |                 return self.x | 
 |  | 
 |         b = Bar() | 
 |         b.setx(123) | 
 |         sb = torch.jit.script(b) | 
 |         self.assertEqual(sb.getx(), 123) | 
 |         self.assertEqual(sb.ignored_getx(), 123) | 
 |  | 
 |         sb.setx(456) | 
 |         self.assertEqual(sb.getx(), 456) | 
 |         self.assertEqual(sb.ignored_getx(), 456) | 
 |  | 
 |     def test_set_attribute_through_optional(self): | 
 |         class A(torch.nn.Module): | 
 |             __annotations__ = {"x": Optional[torch.Tensor]} | 
 |  | 
 |             def __init__(self): | 
 |                 super(A, self).__init__() | 
 |                 self.x = None | 
 |  | 
 |             @torch.jit.ignore | 
 |             def foo(self): | 
 |                 if self.x is None: | 
 |                     self.x = torch.tensor([3]) | 
 |                 return self.x | 
 |  | 
 |             def forward(self, x): | 
 |                 a = self.foo() | 
 |                 return x + 1 | 
 |  | 
 |         m = torch.jit.script(A()) | 
 |         self.assertEqual(m.x, None) | 
 |         m(torch.rand(1)) | 
 |         self.assertEqual(m.x, torch.tensor([3])) | 
 |  | 
 |     def test_mutate_constant(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             __constants__ = ["foo"] | 
 |  | 
 |             def __init__(self, foo): | 
 |                 super(M, self).__init__() | 
 |                 self.foo = foo | 
 |  | 
 |         m = M(5) | 
 |         # m has a constant attribute, but we can't | 
 |         # assign to it | 
 |         with self.assertRaises(RuntimeError): | 
 |             m.foo = 6 | 
 |  | 
 |     def test_class_attribute(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             FOO = 0 | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.foo = self.FOO | 
 |         m = M() | 
 |         self.assertEqual(m.foo, M.FOO) | 
 |  | 
 |     def test_class_attribute_in_script(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             FOO = 0 | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return self.FOO | 
 |         with self.assertRaises(RuntimeError): | 
 |             M() | 
 |  | 
 |     def test_not_initialized_err(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 self.foo = torch.rand(2, 3) | 
 |         with self.assertRaises(RuntimeError): | 
 |             M() | 
 |  | 
 |     def test_attribute_in_init(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.foo = torch.jit.Attribute(0.1, float) | 
 |                 # we should be able to use self.foo as a float here | 
 |                 assert 0.0 < self.foo | 
 |         M() | 
 |  | 
 |     def test_scriptable_fn_as_attr(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self, fn): | 
 |                 super(M, self).__init__() | 
 |                 self.fn = fn | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.fn(x) | 
 |  | 
 |         m = M(torch.sigmoid) | 
 |         inp = torch.rand(2, 3) | 
 |         self.checkModule(m, (inp, )) | 
 |  | 
 |     def test_sequence_parsing(self): | 
 |         tests = [ | 
 |             ("return [x, x,]", True), | 
 |             ("return [x x]", "expected ]"), | 
 |             ("return x, x,", True), | 
 |             ("return bar(x, x,)", True), | 
 |             ("return bar()", "Argument x not provided"), | 
 |             ("for a, b, in x, x,:\n        pass", "List of iterables"), | 
 |             ("a, b, = x, x,\n    return a + b", True) | 
 |         ] | 
 |         for exp, result in tests: | 
 |             cu = torch.jit.CompilationUnit() | 
 |             full = """ | 
 | def bar(x, y): | 
 |     return x + y | 
 | def foo(x): | 
 |     {} | 
 |             """.format(exp) | 
 |             if isinstance(result, str): | 
 |                 with self.assertRaisesRegex(RuntimeError, result): | 
 |                     cu.define(full) | 
 |             else: | 
 |                 cu.define(full) | 
 |  | 
 |     def test_namedtuple_python(self): | 
 |         global MyTuple, MyMod  # see [local resolution in python] | 
 |         MyTuple = namedtuple('MyTuple', ['a']) | 
 |  | 
 |         @torch.jit.unused | 
 |         def fn(): | 
 |             # type: () -> MyTuple | 
 |             return MyTuple(1) | 
 |  | 
 |         # Only check compilation | 
 |         @torch.jit.script | 
 |         def fn2(): | 
 |             # type: () -> MyTuple | 
 |             return fn() | 
 |  | 
 |         FileCheck().check("NamedTuple").run(fn2.graph) | 
 |  | 
 |         class MyMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyMod, self).__init__() | 
 |  | 
 |             @torch.jit.unused | 
 |             def fn(self): | 
 |                 # type: () -> MyTuple | 
 |                 return MyTuple(1) | 
 |  | 
 |             def forward(self, x): | 
 |                 if 1 == 1: | 
 |                     return MyTuple(torch.rand(2, 3)) | 
 |                 else: | 
 |                     return self.fn() | 
 |  | 
 |         # shouldn't throw a type error | 
 |         torch.jit.script(MyMod()) | 
 |  | 
 |     def test_unused_decorator(self): | 
 |         class MyMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyMod, self).__init__() | 
 |  | 
 |             @torch.jit.unused | 
 |             @torch.no_grad() | 
 |             def fn(self, x): | 
 |                 # type: (Tensor) -> int | 
 |                 return next(x)  # invalid, but should be ignored | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.fn(x) | 
 |  | 
 |         torch.jit.script(MyMod()) | 
 |  | 
 |     @_inline_everything | 
 |     def test_lazy_script(self): | 
 |         def untraceable(x): | 
 |             if x.ndim > 2: | 
 |                 print("hello") | 
 |             else: | 
 |                 print("goodbye") | 
 |             return x + 2 | 
 |  | 
 |         # Non-working example | 
 |         def fn(x): | 
 |             return untraceable(x) | 
 |  | 
 |         with self.capture_stdout(): | 
 |             traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)]) | 
 |  | 
 |         FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph) | 
 |  | 
 |         # Working example | 
 |         untraceable = torch.jit.script_if_tracing(untraceable) | 
 |  | 
 |         def fn2(x): | 
 |             return untraceable(x) | 
 |  | 
 |         with self.capture_stdout(): | 
 |             traced = torch.jit.trace(fn, [torch.ones(2, 2)]) | 
 |  | 
 |         FileCheck().check("goodbye").run(traced.graph) | 
 |  | 
 |         def foo(x: int): | 
 |             return x + 1 | 
 |  | 
 |         @torch.jit.script_if_tracing | 
 |         def fee(x: int = 2): | 
 |             return foo(1) + x | 
 |  | 
 |         # test directly compiling function | 
 |         fee_compiled = torch.jit.script(fee) | 
 |         self.assertEqual(fee_compiled(), fee()) | 
 |  | 
 |         # test compiling it within another function | 
 |         @torch.jit.script | 
 |         def hum(): | 
 |             return fee(x=3) | 
 |  | 
 |         self.assertEqual(hum(), 5) | 
 |  | 
 |     def test_big_int_literals(self): | 
 |         def ok(): | 
 |             # signed 64 bit max | 
 |             a = 9223372036854775807 | 
 |             return a | 
 |  | 
 |         def toobig(): | 
 |             a = 9223372036854775808 | 
 |             return a | 
 |  | 
 |         def waytoobig(): | 
 |             a = 99999999999999999999 | 
 |             return a | 
 |  | 
 |         self.checkScript(ok, []) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "out of range"): | 
 |             torch.jit.script(toobig) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "out of range"): | 
 |             torch.jit.script(waytoobig) | 
 |  | 
 |     def test_hex_literals(self): | 
 |         def test1(): | 
 |             return 0xaaaaaa | 
 |  | 
 |         def test2(): | 
 |             return 0xaaaaaa | 
 |  | 
 |         def test3(): | 
 |             return -0xaaaaaa | 
 |  | 
 |         self.checkScript(test1, []) | 
 |         self.checkScript(test2, []) | 
 |         self.checkScript(test3, []) | 
 |  | 
 |         def ok(): | 
 |             a = 0x7FFFFFFFFFFFFFFF | 
 |             return a | 
 |  | 
 |         def toobig(): | 
 |             a = 0xFFFFFFFFFFFFFFFF | 
 |             return a | 
 |  | 
 |         def waytoobig(): | 
 |             a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF | 
 |             return a | 
 |  | 
 |         self.checkScript(ok, []) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "out of range"): | 
 |             torch.jit.script(toobig) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "out of range"): | 
 |             torch.jit.script(waytoobig) | 
 |  | 
 |     def test_big_float_literals(self): | 
 |         def ok(): | 
 |             # Python interprets this as inf | 
 |             a = 1.2E400 | 
 |             return a | 
 |  | 
 |         def check(fn): | 
 |             self.assertTrue(fn() == ok()) | 
 |  | 
 |         # checkScript doesn't work since assertEqual doesn't consider | 
 |         # `inf` == `inf` | 
 |         check(torch.jit.script(ok)) | 
 |  | 
 |         cu = torch.jit.CompilationUnit() | 
 |         cu.define(dedent(inspect.getsource(ok))) | 
 |         check(cu.ok) | 
 |  | 
 |     def _test_device_type(self, dest): | 
 |         def fn(x): | 
 |             # type: (Device) -> Tuple[str, Optional[int]] | 
 |             return x.type, x.index | 
 |  | 
 |         device = torch.ones(2).to(dest).device | 
 |         self.checkScript(fn, [device]) | 
 |  | 
 |     def test_device_type(self): | 
 |         self._test_device_type('cpu') | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "Requires CUDA") | 
 |     def test_device_type_cuda(self): | 
 |         self._test_device_type('cuda') | 
 |  | 
 |     def test_string_device_implicit_conversion(self): | 
 |         @torch.jit.script | 
 |         def fn(x: torch.device): | 
 |             return x | 
 |  | 
 |         self.assertEqual(fn("cpu"), torch.device("cpu")) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected one of"): | 
 |             fn("invalid_device") | 
 |  | 
 |     def test_eval_python(self): | 
 |         def _test(m): | 
 |             self.assertTrue(m(torch.ones(2, 2))) | 
 |             self.assertTrue(m.training) | 
 |             self.assertTrue(m._c.getattr('training')) | 
 |  | 
 |             m.eval() | 
 |  | 
 |             self.assertFalse(m.training) | 
 |             self.assertFalse(m._c.getattr('training')) | 
 |             self.assertFalse(m(torch.ones(2, 2))) | 
 |  | 
 |             buffer = io.BytesIO() | 
 |             torch.jit.save(m, buffer) | 
 |             buffer.seek(0) | 
 |  | 
 |             loaded = torch.jit.load(buffer) | 
 |  | 
 |             self.assertFalse(loaded.training) | 
 |             self.assertFalse(loaded._c.getattr('training')) | 
 |  | 
 |         class M(nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.training | 
 |  | 
 |         class OldM(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(OldM, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.training | 
 |  | 
 |         _test(torch.jit.script(M())) | 
 |         _test(OldM()) | 
 |  | 
 |     def test_inherit_method(self): | 
 |         class A(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(A, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return x + self.bar(x) | 
 |  | 
 |         class B(A): | 
 |             def __init__(self): | 
 |                 super(B, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def bar(self, x): | 
 |                 return x * x | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'attribute'): | 
 |             A()  # cannot use because bar is not defined | 
 |  | 
 |         v = torch.rand(3, 4) | 
 |         b = B() | 
 |         self.assertEqual(b(v), v + v * v) | 
 |  | 
 |         class C(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(C, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def bar(self, x): | 
 |                 return x | 
 |  | 
 |         class D(C, B): | 
 |             def __init__(self): | 
 |                 super(D, self).__init__() | 
 |  | 
 |         self.assertEqual(D()(v), v + v) | 
 |  | 
 |     def test_tensor_subclasses(self): | 
 |         def check_subclass(x, tensor): | 
 |             template = dedent(""" | 
 |                 def func(input: {}) -> {}: | 
 |                     return torch.zeros((input.shape[0], 1), dtype=input.dtype) | 
 |                 """) | 
 |  | 
 |             self._check_code(template.format(x, x), "func", [tensor]) | 
 |  | 
 |         check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]])) | 
 |         check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]])) | 
 |         check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]])) | 
 |         check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]])) | 
 |  | 
 |         def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor: | 
 |             return torch.zeros((input.shape[0], 1), dtype=input.dtype) | 
 |  | 
 |         with warnings.catch_warnings(record=True) as warns: | 
 |             scripted = torch.jit.script(check_subclass_warn) | 
 |         FileCheck().check("TorchScript will treat type annotations of Tensor").run(str(warns[0])) | 
 |  | 
 |     def test_first_class_module(self): | 
 |         class Foo(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |                 self.foo = nn.Parameter(torch.rand(3, 4)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 self.foo = input | 
 |                 return self.foo | 
 |         foo = Foo() | 
 |         input = torch.rand(3, 4) | 
 |         foo.forward(input) | 
 |         self.assertEqual(input, foo.foo) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_first_class_calls(self): | 
 |         @torch.jit.script | 
 |         class Foo(object): | 
 |             def __init__(self, x): | 
 |                 self.bar = x | 
 |  | 
 |             def stuff(self, x): | 
 |                 return self.bar + x | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return x * x + Foo(x).stuff(2 * x) | 
 |  | 
 |         @torch.jit.script | 
 |         def bar(x): | 
 |             return foo(x) * foo(x) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x)) | 
 |  | 
 |     def test_static_methods(self): | 
 |         class M(nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             @staticmethod | 
 |             def my_method(x): | 
 |                 return x + 100 | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + M.my_method(x) | 
 |  | 
 |         class N(nn.Module): | 
 |             def __init__(self): | 
 |                 super(N, self).__init__() | 
 |  | 
 |             @staticmethod | 
 |             def my_method(x): | 
 |                 return x * 100 | 
 |  | 
 |             def forward(self, x): | 
 |                 return x - M.my_method(x) + N.my_method(x) | 
 |  | 
 |         self.checkModule(M(), (torch.ones(2, 2),)) | 
 |  | 
 |         self.checkModule(N(), (torch.ones(2, 2),)) | 
 |  | 
 |     def test_invalid_prefix_annotation(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): | 
 |             with self.capture_stdout() as captured: | 
 |                 @torch.jit.script | 
 |                 def invalid_prefix_annotation1(a): | 
 |                     #type: (Int) -> Int # noqa: E265 | 
 |                     return a + 2 | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): | 
 |             with self.capture_stdout() as captured: | 
 |                 @torch.jit.script | 
 |                 def invalid_prefix_annotation2(a): | 
 |                     #type   : (Int) -> Int # noqa: E265 | 
 |                     return a + 2 | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): | 
 |             with self.capture_stdout() as captured: | 
 |                 @torch.jit.script | 
 |                 def invalid_prefix_annotation3(a): | 
 |                     #     type: (Int) -> Int | 
 |                     return a + 2 | 
 |  | 
 |     def test_builtin_function_attributes(self): | 
 |         class Add(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Add, self).__init__() | 
 |                 self.add = torch.add | 
 |  | 
 |             def forward(self, input): | 
 |                 return self.add(input, input) | 
 |  | 
 |         self.checkModule(Add(), [torch.randn(2, 2)]) | 
 |  | 
 |     def test_pybind_type_comparisons(self): | 
 |         @torch.jit.script | 
 |         def f(): | 
 |             return None | 
 |  | 
 |         node = list(f.graph.nodes())[0] | 
 |         t = node.outputsAt(0).type() | 
 |         self.assertIsNotNone(t) | 
 |  | 
    @unittest.skipIf(IS_WINDOWS and sys.version_info >= (3, 8), 'TODO: need to fix the test case')
    def test_unmatched_type_annotation(self):
        # A type comment listing two argument types on a one-parameter function
        # must be rejected, both with the decorator form and with an explicit
        # torch.jit.script(...) call.
        # message1 checks the error summary; message2/message3 check that the
        # error also highlights the offending source. Those regexes match the
        # text of invalid2/invalid4 below, so the bodies of these functions
        # must not be edited.
        message1 = re.escape("Number of type annotations (2) did not match the number of function parameters (1):")
        message2 = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
        message3 = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
        # Decorator form, summary message.
        with self.assertRaisesRegex(RuntimeError, message1):
            @torch.jit.script
            def invalid1(a):
                # type: (Int, Int) -> Int
                return a + 2

        # Decorator form, highlighted source.
        with self.assertRaisesRegex(RuntimeError, message2):
            @torch.jit.script
            def invalid2(a):
                # type: (Int, Int) -> Int
                return a + 2

        # Explicit call form, summary message.
        with self.assertRaisesRegex(RuntimeError, message1):
            def invalid3(a):
                # type: (Int, Int) -> Int
                return a + 2
            torch.jit.script(invalid3)

        # Explicit call form, highlighted source.
        with self.assertRaisesRegex(RuntimeError, message3):
            def invalid4(a):
                # type: (Int, Int) -> Int
                return a + 2
            torch.jit.script(invalid4)
 |  | 
 |     def test_is_optional(self): | 
 |         ann = Union[List[int], List[float]] | 
 |         torch._jit_internal.is_optional(ann) | 
 |  | 
 |     def test_interpreter_fuzz(self): | 
 |         import builtins | 
 |         # This test generates random tree-like programs to fuzz test | 
 |         # that the interpreter does not have a bug in its stack manipulation | 
 |         # code. An assert in that code ensures individual operators are | 
 |         # not reordered. | 
 |         templates = [ | 
 |             "torch.rand(3, 4)", | 
 |             "({} + {})", | 
 |             "-{}", | 
 |             "({} * {})", | 
 |             "torch.tanh({})", | 
 |             "VAR {}", | 
 |         ] | 
 |  | 
 |         def gen_code(): | 
 |             src_lines = ['def f():'] | 
 |             exprs = [] | 
 |             n_variables = 0 | 
 |  | 
 |             def get_expr(idx): | 
 |                 elem = exprs[idx] | 
 |                 exprs[idx] = exprs[-1] | 
 |                 exprs.pop() | 
 |                 return elem | 
 |  | 
 |             def select_expr_or_var(): | 
 |                 idx = random.randrange(0, len(exprs) + n_variables) | 
 |                 if idx < len(exprs): | 
 |                     return get_expr(idx) | 
 |                 else: | 
 |                     return 'v{}'.format(idx - len(exprs)) | 
 |  | 
 |             for i in range(50): | 
 |                 n = None | 
 |                 while n is None or n > len(exprs) + n_variables: | 
 |                     template = random.choice(templates) | 
 |                     n = template.count('{}') | 
 |  | 
 |                 if 'VAR' in template: | 
 |                     src_lines.append('  v{} = {}'.format(n_variables, select_expr_or_var())) | 
 |                     n_variables += 1 | 
 |                 else: | 
 |                     exprs.append(template.format(*(select_expr_or_var() for _ in range(n)))) | 
 |  | 
 |             src_lines.append('  return ({})\n'.format(''.join('v{},'.format(i) for i in range(n_variables)))) | 
 |             return '\n'.join(src_lines) | 
 |  | 
 |         for i in range(100): | 
 |             g = {'torch': torch} | 
 |             code = gen_code() | 
 |             builtins.exec(code, g, None) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |             with freeze_rng_state(): | 
 |                 o1 = g['f']() | 
 |             with freeze_rng_state(): | 
 |                 o2 = cu.f() | 
 |             self.assertEqual(o1, o2) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_cpp_module_iterator(self): | 
 |         a = nn.Module() | 
 |         a.name = 'a' | 
 |         a.p = nn.Parameter(torch.rand(3, 4)) | 
 |         a.foo = nn.Module() | 
 |         a.foo.name = 'foo' | 
 |         a.foo.register_buffer('b', torch.rand(1, 1)) | 
 |         a.foo.bar = nn.Module() | 
 |         a.foo.bar.name = 'bar' | 
 |         a.foo.bar.an_int = 4 | 
 |         a.another = nn.Module() | 
 |         a.another.name = 'another' | 
 |         sa = torch.jit.script(a) | 
 |         result = torch._C._jit_debug_module_iterators(sa._c) | 
 |  | 
 |         def replace(e): | 
 |             if e is a.p: | 
 |                 return 'P' | 
 |             elif e is a.foo.b: | 
 |                 return 'B' | 
 |             elif isinstance(e, torch._C.ScriptModule): | 
 |                 return e.getattr('name') | 
 |  | 
 |             return e | 
 |         for k, v in result.items(): | 
 |             for i in range(len(v)): | 
 |                 if isinstance(v[i], tuple): | 
 |                     n, v2 = v[i] | 
 |                     v[i] = (n, replace(v2)) | 
 |                 else: | 
 |                     v[i] = replace(v[i]) | 
 |             # module type creation is not deterministic, so we have to sort | 
 |             # the result | 
 |             v.sort() | 
 |         expected = {'buffers': [], | 
 |                     'buffers_r': ['B'], | 
 |                     'children': ['another', 'foo'], | 
 |                     'modules': ['a', 'another', 'bar', 'foo'], | 
 |                     'named_attributes': [('_is_full_backward_hook', None), | 
 |                                          ('another', 'another'), | 
 |                                          ('foo', 'foo'), | 
 |                                          ('name', 'a'), | 
 |                                          ('p', 'P'), | 
 |                                          ('training', True)], | 
 |                     'named_attributes_r': [('_is_full_backward_hook', None), | 
 |                                            ('another', 'another'), | 
 |                                            ('another._is_full_backward_hook', None), | 
 |                                            ('another.name', 'another'), | 
 |                                            ('another.training', True), | 
 |                                            ('foo', 'foo'), | 
 |                                            ('foo._is_full_backward_hook', None), | 
 |                                            ('foo.b', 'B'), | 
 |                                            ('foo.bar', 'bar'), | 
 |                                            ('foo.bar._is_full_backward_hook', None), | 
 |                                            ('foo.bar.an_int', 4), | 
 |                                            ('foo.bar.name', 'bar'), | 
 |                                            ('foo.bar.training', True), | 
 |                                            ('foo.name', 'foo'), | 
 |                                            ('foo.training', True), | 
 |                                            ('name', 'a'), | 
 |                                            ('p', 'P'), | 
 |                                            ('training', True)], | 
 |                     'named_buffers': [], | 
 |                     'named_buffers_r': [('foo.b', 'B')], | 
 |                     'named_children': [('another', 'another'), ('foo', 'foo')], | 
 |                     'named_modules': [('', 'a'), | 
 |                                       ('another', 'another'), | 
 |                                       ('foo', 'foo'), | 
 |                                       ('foo.bar', 'bar')], | 
 |                     'named_parameters': [('p', 'P')], | 
 |                     'named_parameters_r': [('p', 'P')], | 
 |                     'parameters': ['P'], | 
 |                     'parameters_r': ['P']} | 
 |         self.assertEqual(expected, result) | 
 |  | 
 |     def test_parameter_order(self): | 
 |         m = nn.Module() | 
 |         for i, name in enumerate(string.ascii_letters): | 
 |             setattr(m, name, nn.Parameter(torch.tensor([float(i)]))) | 
 |         ms = torch.jit.script(m) | 
 |         print(torch.cat(list(m.parameters()))) | 
 |         print(torch.cat(list(ms.parameters()))) | 
 |         self.assertEqual(list(m.parameters()), list(ms.parameters())) | 
 |  | 
    def test_python_op_builtins(self):
        # Compile-only check: a script function must be able to call a
        # @torch.jit.unused Python function whose body uses the Python builtin
        # `sum`. Nothing is executed; the test passes if scripting succeeds.
        @torch.jit.unused
        def fn(x):
            # type: (List[int]) -> int
            return sum(x)

        @torch.jit.script
        def script_fn(x):
            # type: (List[int]) -> int
            return fn(x)
 |  | 
 |     def test_submodule_twice(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return x * x | 
 |  | 
 |         class What(torch.jit.ScriptModule): | 
 |             def __init__(self, x): | 
 |                 super(What, self).__init__() | 
 |                 self.foo = x | 
 |         a = What(foo) | 
 |         c = What(foo) | 
 |  | 
 |     def test_training_param(self): | 
 |         class What(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(What, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 # type: (int) -> int | 
 |                 if self.training: | 
 |                     r = x | 
 |                 else: | 
 |                     r = x + 4 | 
 |                 # check double use of training | 
 |                 if self.training: | 
 |                     r = r + 1 | 
 |                 return r | 
 |  | 
 |         w = What() | 
 |         self.assertEqual(4, w(3)) | 
 |         w.train(False) | 
 |         self.assertEqual(7, w(3)) | 
 |         self.assertFalse("training" in w.state_dict()) | 
 |  | 
 |     def test_class_as_attribute(self): | 
 |         @torch.jit.script | 
 |         class Foo321(object): | 
 |             def __init__(self): | 
 |                 self.x = 3 | 
 |  | 
 |         class FooBar1234(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(FooBar1234, self).__init__() | 
 |                 self.f = Foo321() | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + self.f.x | 
 |  | 
 |         scripted = torch.jit.script(FooBar1234()) | 
 |         eic = self.getExportImportCopy(scripted) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(scripted(x), eic(x)) | 
 |  | 
 |     def test_module_str(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.relu(x) | 
 |  | 
 |         f = torch.jit.script(Foo()) | 
 |         self.assertEqual('ScriptObject', str(f._c)) | 
 |  | 
    def test_jitter_bug(self):
        # Compile-only regression check: `_stride` is assigned a fresh list in
        # one branch and the `kernel_size` argument in the other, and `fn`
        # calls `fn2` with a list literal. The test passes if both functions
        # script without error; neither is executed.
        @torch.jit.script
        def fn2(input, kernel_size):
            # type: (Tensor, List[int]) -> Tensor
            if kernel_size[0] > 1:
                _stride = [2]
            else:
                _stride = kernel_size
            print(_stride, kernel_size)
            return input

        @torch.jit.script
        def fn(input):
            # type: (Tensor) -> Tensor
            return fn2(input, [1])
 |  | 
 |     def test_parser_kwargonly(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(x, *, y) -> Tuple[Tensor, Tensor]: | 
 |                 return x, x | 
 |             def bar(x): | 
 |                 return foo(x, y=x) | 
 |         ''') | 
 |         self.assertTrue('*' in str(cu.foo.schema)) | 
 |         with self.assertRaisesRegex(RuntimeError, "not provided"): | 
 |             torch.jit.CompilationUnit(''' | 
 |                 def foo(x, *, y) -> Tuple[Tensor, Tensor]: | 
 |                     return x, x | 
 |                 def bar(x): | 
 |                     return foo(x, x) | 
 |             ''') | 
 |  | 
 |     def test_annoying_doubles(self): | 
 |         mod = types.ModuleType("temp") | 
 |         mod.inf = float("inf") | 
 |         mod.ninf = float("-inf") | 
 |         mod.nan = float("nan") | 
 |  | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             class Foo(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(Foo, self).__init__() | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self): | 
 |                     return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan | 
 |  | 
 |             foo = Foo() | 
 |             buffer = io.BytesIO() | 
 |             torch.jit.save(foo, buffer) | 
 |  | 
 |             buffer.seek(0) | 
 |             foo_loaded = torch.jit.load(buffer) | 
 |  | 
 |             r = foo() | 
 |             r2 = foo_loaded() | 
 |             # use precise assert, we are checking floating point details | 
 |             self.assertTrue(r[:-1] == r2[:-1]) | 
 |             self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1])) | 
 |  | 
    def test_type_annotate(self):
        # Exercises torch.jit.annotate with several target types. Each inner
        # function is handed to checkScript, so its body is the test input.

        # Tensor annotation on a tensor argument.
        def foo(a):
            return torch.jit.annotate(torch.Tensor, a)

        self.checkScript(foo, (torch.rand(3),))

        # Annotating an empty list literal with its element type.
        def bar():
            a = torch.jit.annotate(List[int], [])
            for _ in range(10):
                a.append(4)
            return a

        self.checkScript(bar, ())

        # float annotation applied to a zero-dim tensor argument.
        def baz(a):
            return torch.jit.annotate(float, a)
        self.checkScript(baz, (torch.rand(()),))

        # test annotate none types
        def annotate_none():
            return torch.jit.annotate(Optional[torch.Tensor], None)

        self.checkScript(annotate_none, ())
 |  | 
 |  | 
 |     def test_robust_op_resolution(self): | 
 |         neg = torch.add  # misleading name to make sure we resolve by function | 
 |  | 
 |         def stuff(x): | 
 |             return neg(x, x) | 
 |  | 
 |         a = (torch.rand(3),) | 
 |         self.checkScript(stuff, a) | 
 |  | 
    def test_nested_aug_assign(self):
        # Augmented assignment (+=, *=) on attributes of a nested submodule:
        # ints, lists, a script class defining __iadd__, and a script class
        # defining only __add__. Eager and scripted results are compared, and
        # a class with neither operator must fail to compile.
        # The bodies of the script classes and of forward() are compiled by
        # TorchScript and are left untouched.
        @torch.jit.script
        class SomeClass(object):
            def __init__(self):
                self.num = 99

            def __iadd__(self, x):
                # type: (int)
                self.num += x
                return self

            def __eq__(self, other):
                # type: (SomeClass) -> bool
                return self.num == other.num

        @torch.jit.script
        class SomeOutOfPlaceClass(object):
            def __init__(self):
                self.num = 99

            def __add__(self, x):
                # type: (int)
                self.num = x
                return self

            def __eq__(self, other):
                # type: (SomeClass) -> bool
                return self.num == other.num

        class Child(nn.Module):
            def __init__(self):
                super().__init__()
                self.x = 2
                self.o = SomeClass()
                self.oop = SomeOutOfPlaceClass()
                self.list = [1, 2, 3]

        class A(nn.Module):
            def __init__(self):
                super().__init__()
                self.child = Child()

            def forward(self):
                self.child.x += 1
                self.child.o += 5
                self.child.oop += 5
                some_list = [1, 2]
                self.child.list += some_list
                self.child.list *= 2
                return self.child.x, self.child.o, self.child.list, self.child.oop

        # The eager and scripted modules are separate instances, so each one
        # mutates its own child state exactly once.
        a = A()
        sa = torch.jit.script(A())
        eager_result = a()
        script_result = sa()
        self.assertEqual(eager_result, script_result)
        self.assertEqual(a.child.x, sa.child.x)
        self.assertEqual(a.child.o, sa.child.o)
        self.assertEqual(a.child.list, sa.child.list)

        @torch.jit.script
        class SomeNonAddableClass(object):
            def __init__(self):
                self.num = 99

            def __eq__(self, other):
                # type: (SomeClass) -> bool
                return self.num == other.num

        # This second `A` shadows the one above; it applies += to a class that
        # defines neither __iadd__ nor __add__.
        class A(nn.Module):
            def __init__(self):
                super().__init__()
                self.x = SomeNonAddableClass()

            def forward(self):
                self.x += SomeNonAddableClass()
                return self.x

        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
            torch.jit.script(A())
 |  | 
    def test_var_aug_assign(self):
        # Augmented assignment on local variables: a class without __iadd__ or
        # __add__ must be rejected, while classes defining one of them, lists
        # and tensors must behave the same in eager mode and in script.
        # The scripted classes and functions are the test input and are left
        # untouched.
        @torch.jit.script
        class SomeNonAddableClass(object):
            def __init__(self):
                self.num = 99

            def __eq__(self, other):
                # type: (SomeNonAddableClass) -> bool
                return self.num == other.num

        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
            @torch.jit.script
            def fn():
                a = SomeNonAddableClass()
                a += SomeNonAddableClass()
                return a

        @torch.jit.script
        class SomeClass(object):
            def __init__(self):
                self.num = 99

            def __iadd__(self, x):
                # type: (int)
                self.num += x
                return self

            def __eq__(self, other):
                # type: (SomeClass) -> bool
                return self.num == other.num

        @torch.jit.script
        class SomeOutOfPlaceClass(object):
            def __init__(self):
                self.num = 99

            def __add__(self, x):
                # type: (int)
                self.num = x
                return self

            def __eq__(self, other):
                # type: (SomeClass) -> bool
                return self.num == other.num

        # The asserts inside fn2 check that each augmented assignment keeps
        # the variable bound to the same object; checkScript runs fn2 both in
        # Python and as script.
        def fn2():
            a = SomeClass()
            a_copy = a
            a += 20
            assert a is a_copy
            b = SomeOutOfPlaceClass()
            b_copy = b
            b += 99
            assert b is b_copy
            c = [1, 2, 3]
            c_copy = c
            c *= 2
            assert c is c_copy
            c += [4, 5, 6]
            d = torch.ones(2, 2)
            d_copy = d
            d += torch.ones(2, 2)
            assert d is d_copy
            return a, b, c, d

        self.checkScript(fn2, [])
 |  | 
    def test_nested_list_construct(self):
        # Concatenating two nested list literals must give the same result in
        # script as in Python.
        def foo():
            return [[4]] + [[4, 5]]
        self.checkScript(foo, ())
 |  | 
 |     def test_file_line_error(self): | 
 |         def foobar(xyz): | 
 |             return torch.blargh(xyz) | 
 |  | 
 |         _, lineno = inspect.getsourcelines(foobar) | 
 |         with self.assertRaisesRegex(RuntimeError, "test_jit.py\", line {}".format(lineno + 1)): | 
 |             scripted = torch.jit.script(foobar) | 
 |  | 
 |     def test_file_line_error_class_defn(self): | 
 |         class FooBar(object): | 
 |             def baz(self, xyz): | 
 |                 return torch.blargh(xyz) | 
 |  | 
 |         _, lineno = inspect.getsourcelines(FooBar) | 
 |         with self.assertRaisesRegex(RuntimeError, "test_jit.py\", line {}".format(lineno + 2)): | 
 |             torch.jit.script(FooBar) | 
 |  | 
 |     def test_file_line_graph(self): | 
 |         def foobar(xyz): | 
 |             return torch.neg(xyz) | 
 |  | 
 |         scripted = torch.jit.script(foobar) | 
 |  | 
 |         _, lineno = inspect.getsourcelines(foobar) | 
 |         fc = FileCheck().check('test_jit.py:{}:19'.format(lineno + 1)) | 
 |         fc.run(scripted.graph) | 
 |         fc.run(str(scripted.graph)) | 
 |  | 
 |     def test_file_line_save_load(self): | 
 |         class Scripted(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, xyz): | 
 |                 return torch.neg(xyz) | 
 |  | 
 |         scripted = Scripted() | 
 |  | 
 |         # NB: not using getExportImportCopy because that takes a different | 
 |         # code path that calls CompilationUnit._import rather than | 
 |         # going through the full save/load pathway | 
 |         buffer = scripted.save_to_buffer() | 
 |         bytesio = io.BytesIO(buffer) | 
 |         scripted = torch.jit.load(bytesio) | 
 |  | 
 |         _, lineno = inspect.getsourcelines(Scripted) | 
 |         fc = FileCheck().check(':{}'.format(lineno + 3)) | 
 |         fc.run(scripted.graph) | 
 |         fc.run(str(scripted.graph)) | 
 |  | 
 |     def test_file_line_string(self): | 
 |         scripted = torch.jit.CompilationUnit(''' | 
 | def foo(xyz): | 
 |     return torch.neg(xyz) | 
 |         ''') | 
 |  | 
 |         fc = FileCheck().check('<string>:3:11') | 
 |         fc.run(scripted.foo.graph) | 
 |         fc.run(str(scripted.foo.graph)) | 
 |  | 
 |     @skipIfCrossRef | 
 |     def test_file_line_trace(self): | 
 |         def foobar(xyz): | 
 |             return torch.neg(xyz) | 
 |  | 
 |         scripted = torch.jit.trace(foobar, (torch.rand(3, 4))) | 
 |  | 
 |         _, lineno = inspect.getsourcelines(foobar) | 
 |         fc = FileCheck().check('test_jit.py:{}:0'.format(lineno + 1)) | 
 |         fc.run(scripted.graph) | 
 |         fc.run(str(scripted.graph)) | 
 |  | 
 |     def test_serialized_source_ranges(self): | 
 |  | 
 |         class FooTest(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, x, w): | 
 |                 return torch.mm(x, w.t()) | 
 |  | 
 |         ft = FooTest() | 
 |         loaded = self.getExportImportCopy(ft) | 
 |         _, lineno = inspect.getsourcelines(FooTest) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'test_jit.py\", line {}'.format(lineno + 3)): | 
 |             loaded(torch.rand(3, 4), torch.rand(30, 40)) | 
 |  | 
 |     def test_serialized_source_ranges_graph(self): | 
 |  | 
 |         class FooTest3(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, x, w): | 
 |                 return torch.mm(x, w.t()) | 
 |  | 
 |         ft = FooTest3() | 
 |         loaded = self.getExportImportCopy(ft) | 
 |         _, lineno = inspect.getsourcelines(FooTest3) | 
 |  | 
 |         fc = FileCheck().check('test_jit.py:{}'.format(lineno + 3)) | 
 |         fc.run(loaded.graph) | 
 |  | 
 |     def test_serialized_source_ranges2(self): | 
 |  | 
 |         class FooTest2(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 raise RuntimeError('foo') | 
 |  | 
 |         _, lineno = inspect.getsourcelines(FooTest2) | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.Error, 'test_jit.py\", line {}'.format(lineno + 3)): | 
 |             ft = FooTest2() | 
 |             loaded = self.getExportImportCopy(ft) | 
 |             loaded() | 
 |  | 
 |     def test_serialized_source_ranges_dont_jitter(self): | 
 |         class FooTest3(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, lim): | 
 |                 first = 1 | 
 |                 second = 1 | 
 |                 i = 1 | 
 |                 somenum = 5 | 
 |                 dontmutateme = 3 | 
 |                 third = 0 | 
 |                 while bool(i < lim): | 
 |                     third = first + second | 
 |                     first = second | 
 |                     second = third | 
 |                     j = 0 | 
 |                     while j < 10: | 
 |                         somenum = somenum * 2 | 
 |                         j = j + 1 | 
 |                     i = i + j | 
 |                     i = i + dontmutateme | 
 |  | 
 |                 st = second + third | 
 |                 fs = first + second | 
 |                 return third, st, fs | 
 |  | 
 |         ft3 = FooTest3() | 
 |  | 
 |         def debug_records_from_mod(self, mod): | 
 |             buffer = io.BytesIO() | 
 |             torch.jit.save(ft3, buffer) | 
 |             buffer.seek(0) | 
 |             archive = zipfile.ZipFile(buffer) | 
 |             files = filter(lambda x: x.startswith('archive/code/'), archive.namelist()) | 
 |             debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files)) | 
 |             self.assertEqual(len(debug_files), 1) | 
 |             debug_file = archive.open(debug_files[0]) | 
 |             return pickle.load(debug_file), buffer | 
 |  | 
 |         records1, buffer = debug_records_from_mod(self, ft3) | 
 |  | 
 |         buffer.seek(0) | 
 |         loaded = torch.jit.load(buffer) | 
 |         records2, buffer = debug_records_from_mod(self, loaded) | 
 |  | 
 |         buffer.seek(0) | 
 |         loaded2 = torch.jit.load(buffer) | 
 |         records3, _ = debug_records_from_mod(self, loaded2) | 
 |  | 
 |         self.assertEqual(records1, records2) | 
 |         self.assertEqual(records2, records3) | 
 |  | 
 |     def test_serialized_source_ranges_no_dups(self): | 
 |         class FooTest3(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, lim): | 
 |                 first = 1 | 
 |                 second = 1 | 
 |                 i = 1 | 
 |                 somenum = 5 | 
 |                 dontmutateme = 3 | 
 |                 third = 0 | 
 |                 while bool(i < lim): | 
 |                     third = first + second | 
 |                     first = second | 
 |                     second = third | 
 |                     j = 0 | 
 |                     while j < 10: | 
 |                         somenum = somenum * 2 | 
 |                         j = j + 1 | 
 |                     i = i + j | 
 |                     i = i + dontmutateme | 
 |  | 
 |                 st = second + third | 
 |                 fs = first + second | 
 |                 return third, st, fs | 
 |  | 
 |         ft3 = FooTest3() | 
 |  | 
 |         def debug_records_from_mod(mod): | 
 |             buffer = io.BytesIO() | 
 |             torch.jit.save(ft3, buffer) | 
 |             buffer.seek(0) | 
 |             archive = zipfile.ZipFile(buffer) | 
 |             files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) | 
 |             debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) | 
 |             debug_files = (archive.open(f) for f in debug_files) | 
 |             debug_files = (pickle.load(f) for f in debug_files) | 
 |             debug_files = (f[2] for f in debug_files) | 
 |             return list(debug_files) | 
 |  | 
 |         debug_files = debug_records_from_mod(ft3) | 
 |         for debug_file in debug_files: | 
 |             for i in range(len(debug_file) - 1): | 
 |                 offset, source_range_tag, source_range = debug_file[i] | 
 |                 offset2, source_range_tag2, source_range2 = debug_file[i + 1] | 
 |                 self.assertNotEqual(source_range, source_range2) | 
 |  | 
    def test_circular_dependency(self):
        """
        https://github.com/pytorch/pytorch/issues/25871
        """
        # Regression test for the issue linked above: C holds a Sequential of
        # B, which holds a ModuleList of A, and both containers are iterated
        # inside script methods. The test passes if the export/import round
        # trip of C completes without raising.
        class A(torch.jit.ScriptModule):
            def __init__(self):
                super(A, self).__init__()

            @torch.jit.script_method
            def forward(self, x):
                return x

        class B(torch.jit.ScriptModule):
            def __init__(self):
                super(B, self).__init__()
                self.foo = torch.nn.ModuleList([A()])

            @torch.jit.script_method
            def forward(self, x):
                for f in self.foo:
                    x = f(x)
                return x

        class C(torch.jit.ScriptModule):
            def __init__(self):
                super(C, self).__init__()
                self.foo = torch.nn.Sequential(B())

            @torch.jit.script_method
            def forward(self, x):
                for f in self.foo:
                    x = f(x)
                return x
        self.getExportImportCopy(C())
 |  | 
 |     def test_serialize_long_lines(self): | 
 |         class OrderModuleLong(torch.nn.Module): | 
 |             def forward(self, long_arg_name: List[torch.Tensor]): | 
 |                 return [(long_arg_name[1],), (long_arg_name[0].argmax(),)] | 
 |         src = str(torch.jit.script(OrderModuleLong()).code) | 
 |         # make long_arg_name[1] does not get reordered after the argmax | 
 |         FileCheck().check("long_arg_name[1]").check("argmax").run(src) | 
 |  | 
 |     def test_tensor_shape(self): | 
 |         x = torch.empty(34, 56, 78) | 
 |  | 
 |         def f(x): | 
 |             return x.shape | 
 |  | 
 |         self.checkScript(f, (x,)) | 
 |  | 
 |  | 
 |     def test_block_input_grad_in_loop(self): | 
 |  | 
 |         x = torch.randn(3, 3, requires_grad=False) | 
 |         y = torch.randn(3, 3, requires_grad=True) | 
 |  | 
 |         def grad_in_loop(x, y): | 
 |             for i in range(100): | 
 |                 x = y @ x | 
 |             return x | 
 |  | 
 |         scripted = torch.jit.script(grad_in_loop) | 
 |         outer = scripted.graph_for(x, y) | 
 |         loop = outer.findNode("prim::Loop") | 
 |         loop_block = next(loop.blocks()) | 
 |         param_node = loop_block.paramNode() | 
 |         x_value = list(param_node.outputs())[1] | 
 |         self.assertTrue(x_value.requires_grad()) | 
 |  | 
 |     def test_tensor_grad(self): | 
 |         x = torch.randn(3, 4, requires_grad=True) | 
 |         y = torch.randn(3, 4, requires_grad=False) | 
 |  | 
 |         def f_requires_grad(x): | 
 |             return x.requires_grad | 
 |  | 
 |         self.checkScript(f_requires_grad, (x,)) | 
 |         self.checkScript(f_requires_grad, (y,)) | 
 |  | 
 |         def f_grad(x): | 
 |             return x.grad | 
 |  | 
 |         x.sum().backward() | 
 |         self.checkScript(f_grad, (x,)) | 
 |         self.checkScript(f_grad, (y,)) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy") | 
 |     def test_prim_grad_undefined(self): | 
 |  | 
 |         x = torch.ones(2) | 
 |  | 
 |         def f_grad(x): | 
 |             return x.grad | 
 |  | 
 |         scripted = self.checkScript(f_grad, (x,)) | 
 |         g = scripted.graph_for(x) | 
 |  | 
 |         prim_grad_node = g.findNode("prim::grad") | 
 |         self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None) | 
 |  | 
 |     def test_tensor_data(self): | 
 |         x = torch.randn(3, 4, requires_grad=True) | 
 |         y = torch.randn(4, 5) | 
 |  | 
 |         def f_data(x): | 
 |             return x.data | 
 |  | 
 |         scripted_f_data = torch.jit.script(f_data) | 
 |  | 
 |         scripted_x = scripted_f_data(x) | 
 |         self.assertEqual(scripted_x, f_data(x)) | 
 |         self.assertEqual(scripted_x.requires_grad, False) | 
 |  | 
 |         scripted_y = scripted_f_data(y) | 
 |         self.assertEqual(scripted_y, f_data(y)) | 
 |         self.assertEqual(scripted_x.requires_grad, False) | 
 |  | 
 |     def test_tensor_dtype(self): | 
 |         x_byte = torch.empty(34, 56, 78, dtype=torch.uint8) | 
 |         x_long = torch.empty(34, 56, 78, dtype=torch.long) | 
 |         x_float32 = torch.empty(34, 56, 78, dtype=torch.float32) | 
 |  | 
 |         @torch.jit.script | 
 |         def byte(x): | 
 |             return x.dtype == torch.uint8 | 
 |  | 
 |         @torch.jit.script | 
 |         def long(x): | 
 |             return x.dtype == torch.long | 
 |  | 
 |         @torch.jit.script | 
 |         def float32(x): | 
 |             return x.dtype == torch.float32 | 
 |  | 
 |         self.assertTrue(byte(x_byte)) | 
 |         self.assertFalse(byte(x_long)) | 
 |         self.assertFalse(byte(x_float32)) | 
 |         self.assertFalse(long(x_byte)) | 
 |         self.assertTrue(long(x_long)) | 
 |         self.assertFalse(long(x_float32)) | 
 |         self.assertFalse(float32(x_byte)) | 
 |         self.assertFalse(float32(x_long)) | 
 |         self.assertTrue(float32(x_float32)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") | 
 |     def test_tensor_device(self): | 
 |         cpu = torch.empty(34, 56, 78, device='cpu') | 
 |         gpu = torch.empty(34, 56, 78, device='cuda') | 
 |  | 
 |         @torch.jit.script | 
 |         def same_device(x, y): | 
 |             return x.device == y.device | 
 |  | 
 |         self.assertTrue(same_device(cpu, cpu)) | 
 |         self.assertTrue(same_device(gpu, gpu)) | 
 |         self.assertFalse(same_device(cpu, gpu)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") | 
 |     def test_tensor_to_device(self): | 
 |         def to_device(x): | 
 |             return x.to(device="cuda").to(device=torch.device("cpu")) | 
 |  | 
 |         self.checkScript(to_device, (torch.ones(3, 4),)) | 
 |  | 
 |     def test_tensor_to_cpu(self): | 
 |         def to_cpu(x): | 
 |             return x.cpu() | 
 |  | 
 |         x = torch.ones(3, 4) | 
 |         script_fn = torch.jit.script(to_cpu) | 
 |         self.assertEqual(to_cpu(x).device, script_fn(x).device) | 
 |         self.checkScript(to_cpu, (x,)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") | 
 |     def test_tensor_to_cuda(self): | 
 |         def to_cuda(x): | 
 |             return x.cuda() | 
 |  | 
 |         x = torch.ones(3, 4) | 
 |         script_fn = torch.jit.script(to_cuda) | 
 |         self.assertEqual(to_cuda(x).device, script_fn(x).device) | 
 |         self.checkScript(to_cuda, (x,)) | 
 |  | 
 |     def test_generic_list_errors(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "previously matched to type"): | 
 |             @torch.jit.script | 
 |             def foo(x): | 
 |                 return [[x]] + [[1]] | 
 |  | 
 |     def test_script_cu(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(a): | 
 |                 b = a | 
 |                 return b | 
 |         ''') | 
 |         a = Variable(torch.rand(1)) | 
 |         self.assertEqual(a, cu.foo(a)) | 
 |  | 
    # Because the compilation unit ingests Python strings, an escape sequence
    # meant for TorchScript must have its backslash escaped (\\n = \n).
    def test_string_cu(self):
        """
        Check how string literals in TorchScript source show up in the graph.

        The source below is a regular (non-raw) Python string, so Python
        processes its escapes before TorchScript sees the text: each doubled
        backslash becomes a single backslash, the tab escape becomes a real
        tab character, and the backslash at the end of a line joins that line
        with the next one.
        """
        cu = torch.jit.CompilationUnit('''
            def foo(a):
                print(a, """a\\n\tb\\n""", 2, "a\
a")
                return a
        ''')
        # The joined literal must appear as "aa", followed by the first
        # literal printed with backslash-escaped newline and tab characters.
        FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
 |  | 
 |     def test_function_compilation_caching(self): | 
 |         def fun(): | 
 |             return 1 + 2 | 
 |  | 
 |         fun_compiled = torch.jit.script(fun) | 
 |         # python wrapper around the script function is a different pointer, | 
 |         # but the underlying script function graph is the same | 
 |         self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph) | 
 |  | 
 |         def fun(): | 
 |             return 3 + 4 | 
 |  | 
 |         num_ref_counts = sys.getrefcount(fun) | 
 |  | 
 |         # caching doesn't get tripped up by same qualname | 
 |         fun_compiled_2 = torch.jit.script(fun) | 
 |         self.assertIsNot(fun_compiled, fun_compiled_2) | 
 |         self.assertEqual(fun_compiled_2(), 7) | 
 |  | 
 |         # caching doesnt increase refcounts to function (holds weak reference) | 
 |         self.assertTrue(sys.getrefcount(fun), num_ref_counts) | 
 |  | 
 |     def test_string_ops(self): | 
 |         def foo(): | 
 |             a = "a" + "b" | 
 |             return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab" | 
 |  | 
 |         self.checkScript(foo, ()) | 
 |  | 
 |     def test_string_sorted(self): | 
 |         def foo(strs: List[str]): | 
 |             return sorted(strs) | 
 |  | 
 |         FileCheck() \ | 
 |             .check("graph") \ | 
 |             .check_next("str[] = aten::sorted") \ | 
 |             .check_next("return") \ | 
 |             .run(str(torch.jit.script(foo).graph)) | 
 |  | 
 |         inputs = ["str3", "str2", "str1"] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_string_sort(self): | 
 |         def foo(strs: List[str]): | 
 |             strs.sort() | 
 |             return strs | 
 |  | 
 |         inputs = ["str3", "str2", "str1"] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_tuple_sorted(self): | 
 |         def foo(tups: List[Tuple[int, int]]): | 
 |             return sorted(tups) | 
 |  | 
 |         inputs = [(1, 2), (0, 2), (1, 3)] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_tuple_sort(self): | 
 |         def foo(tups: List[Tuple[int, int]]): | 
 |             tups.sort() | 
 |             return tups | 
 |  | 
 |         inputs = [(1, 2), (0, 2), (1, 3)] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_tuple_sort_reverse(self): | 
 |         def foo(tups: List[Tuple[int, int]]): | 
 |             tups.sort(reverse=True) | 
 |             return tups | 
 |  | 
 |         inputs = [(1, 2), (0, 2), (1, 3)] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_tuple_unsortable_element_type(self): | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             tups = [({1: 2}, {2: 3})] | 
 |             tups.sort() | 
 |             return tups | 
 |  | 
 |         with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"): | 
 |             foo() | 
 |  | 
 |     def test_tuple_unsortable_diff_type(self): | 
 |         @torch.jit.script | 
 |         def foo(inputs: List[Any]): | 
 |             inputs.sort() | 
 |             return inputs | 
 |  | 
 |         inputs = [(1, 2), ("foo", "bar")] | 
 |         with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): | 
 |             foo(inputs) | 
 |  | 
 |     def test_tuple_nested_sort(self): | 
 |         def foo(inputs: List[Tuple[int, Tuple[int, str]]]): | 
 |             inputs.sort() | 
 |             return inputs | 
 |  | 
 |         inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))] | 
 |         self.checkScript(foo, (inputs,)) | 
 |  | 
 |     def test_tuple_unsortable_nested_diff_type(self): | 
 |         @torch.jit.script | 
 |         def foo(inputs: List[Any]): | 
 |             inputs.sort() | 
 |             return inputs | 
 |  | 
 |         inputs = [(1, (2, 3)), (2, ("foo", "bar"))] | 
 |         with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): | 
 |             foo(inputs) | 
 |  | 
    def test_string_new_line(self):
        """
        A double-quoted string literal that spans two source lines must be
        rejected when the source is compiled as a CompilationUnit.
        """
        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
            # The line break inside the print argument is the input under test.
            torch.jit.CompilationUnit('''
            def test_while(a):
                print("
                    a")
                return a
            ''')
 |  | 
    def test_string_single_escape(self):
        """
        A string literal consisting of a lone backslash must be rejected when
        the source is compiled as a CompilationUnit.
        """
        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
            # This is a non-raw Python string, so the doubled backslash below
            # reaches TorchScript as a single backslash between the quotes.
            torch.jit.CompilationUnit('''
            def test_while(a):
                print("\\")
                return a
            ''')
 |  | 
 |     def test_script_annotation(self): | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             return a + a + a | 
 |         s = Variable(torch.rand(2)) | 
 |         self.assertEqual(s + s + s, foo(s)) | 
 |  | 
 |     def test_torch_pow(self): | 
 |         def func(a, b): | 
 |             return pow(a, b) | 
 |  | 
 |         def func2(a, b, c, d): | 
 |             return pow(pow(c + a, b), d) | 
 |  | 
 |         def func3(a : int, b : float): | 
 |             # type: (int, float) -> float | 
 |             return pow(a, b) | 
 |  | 
 |         def func4(): | 
 |             # type: () -> float | 
 |             return pow(2, -2) | 
 |  | 
 |         def func5(x, y): | 
 |             return pow(x.item(), y.item()) | 
 |  | 
 |         def func6(a : int, b : int): | 
 |             # type: (int, int) -> float | 
 |             return pow(a, b) | 
 |  | 
 |         a = torch.rand(1) | 
 |         b = torch.rand(1) | 
 |         c = torch.rand(1) | 
 |         d = torch.rand(1) | 
 |         self.checkScript(func, (a, b)) | 
 |         self.checkScript(func2, (a, b, c, d)) | 
 |         self.checkScript(func3, (4, -0.5)) | 
 |         self.checkScript(func4, ()) | 
 |         self.checkScript(func6, (2, 4)) | 
 |  | 
 |         inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)] | 
 |         for x in inputs: | 
 |             for y in inputs: | 
 |                 if x < 0: | 
 |                     continue | 
 |                 else: | 
 |                     self.checkScript(func5, (x, y)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") | 
 |     def test_pow_scalar_backward_cuda(self): | 
 |         # see that scalar exponent works with cuda base (#19253) | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             for dtype in [torch.float, torch.double]: | 
 |                 @torch.jit.script | 
 |                 def func(a, b): | 
 |                     # type: (Tensor, float) -> Tensor | 
 |                     return (a * 2) ** b | 
 |  | 
 |                 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) | 
 |                 func(a, 1, profile_and_replay=True).backward() | 
 |  | 
 |                 @torch.jit.script | 
 |                 def func(a, b): | 
 |                     # type: (float, Tensor) -> Tensor | 
 |                     return a ** (b * 2 + 1) | 
 |  | 
 |                 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) | 
 |                 func(2, a, profile_and_replay=True).backward() | 
 |  | 
 |     def _check_code(self, code_str, fn_name, inputs): | 
 |         scope = {} | 
 |         exec(code_str, globals(), scope) | 
 |         cu = torch.jit.CompilationUnit(code_str) | 
 |         self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, 'no CUDA') | 
 |     def test_scriptmodule_releases_tensors_cuda(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def fn(x, y): | 
 |                 return x.sigmoid() * y.tanh() | 
 |  | 
 |             def test(backward=False): | 
 |                 x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) | 
 |                 y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) | 
 |                 out = fn(x, y, profile_and_replay=True) | 
 |                 if backward: | 
 |                     out.sum().backward() | 
 |  | 
 |             with self.assertLeaksNoCudaTensors(): | 
 |                 test() | 
 |                 test() | 
 |                 test() | 
 |  | 
 |             if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |                 with self.assertLeaksNoCudaTensors(): | 
 |                     test(backward=True) | 
 |                     test(backward=True) | 
 |                     test(backward=True) | 
 |  | 
 |     def test_index(self): | 
 |         def consec(size, start=0): | 
 |             numel = torch.tensor(size).prod().item() | 
 |             return torch.arange(numel).view(size) | 
 |  | 
 |         def consec_list(size): | 
 |             return list(range(size)) | 
 |  | 
 |         def random_string(size): | 
 |             letters = string.ascii_lowercase | 
 |             return "".join(random.choice(letters) for i in range(size)) | 
 |  | 
 |         def check_indexing(indexing, tensor): | 
 |             template = dedent(""" | 
 |             def func(x): | 
 |                 return x{} | 
 |             """) | 
 |  | 
 |             self._check_code(template.format(indexing), "func", [tensor]) | 
 |  | 
 |         def check_dynamic_indexing(indexing, tensor, value1, value2): | 
 |             value1 = torch.tensor(value1) | 
 |             value2 = torch.tensor(value2) | 
 |  | 
 |             template = dedent(""" | 
 |             def func(x, value1, value2): | 
 |                 i = int(value1) | 
 |                 j = int(value2) | 
 |                 return x{} | 
 |             """) | 
 |  | 
 |             self._check_code(template.format(indexing), "func", [tensor, value1, value2]) | 
 |  | 
 |         # Torchscript assumes type Tensor by default, so we need this explicit | 
 |         # declaration. | 
 |         def check_indexing_list_int(indexing, list): | 
 |             template = dedent(""" | 
 |             def func(x): | 
 |                 # type: (List[int]) -> Any | 
 |                 return x{} | 
 |             """) | 
 |  | 
 |             self._check_code(template.format(indexing), "func", [list]) | 
 |  | 
 |         def check_indexing_str(indexing, str): | 
 |             template = dedent(""" | 
 |             def func(x): | 
 |                 # type: (str) -> Any | 
 |                 return x{} | 
 |             """) | 
 |  | 
 |             self._check_code(template.format(indexing), "func", [str]) | 
 |  | 
 |         # basic slices | 
 |         check_indexing('[0]', consec((3, 3))) | 
 |         check_indexing('[1]', consec((3, 3), 10)) | 
 |         check_indexing('[2]', consec((3, 3), 19)) | 
 |         check_indexing('[2]', consec((3,))) | 
 |         check_indexing('[-1]', consec((3, 3), 19)) | 
 |         check_indexing('[0:2]', consec((3, 3, 3))) | 
 |         check_indexing('[1:-1]', consec((3, 3, 3))) | 
 |         check_indexing('[-3:-1]', consec((6, 3))) | 
 |         check_indexing('[1:]', consec((3, 3))) | 
 |         check_indexing('[:1]', consec((3, 3))) | 
 |         check_indexing('[:]', consec((3, 2))) | 
 |  | 
 |         # multi-dim: indexes | 
 |         check_indexing('[0, 1]', consec((3, 3))) | 
 |         check_indexing('[0, 1]', consec((3, 3, 2))) | 
 |         check_indexing('[1, 0, 2]', consec((3, 3, 3))) | 
 |         check_indexing('[2, -1]', consec((3, 3))) | 
 |  | 
 |         # multi-dim: mixed slicing and indexing | 
 |         check_indexing('[0, 1:2]', consec((3, 3))) | 
 |         check_indexing('[0, :1]', consec((3, 3, 2))) | 
 |         check_indexing('[1, 2:]', consec((3, 3, 3))) | 
 |         check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) | 
 |         check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3))) | 
 |         check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3))) | 
 |         check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) | 
 |         check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3))) | 
 |  | 
 |         # zero-sized slices | 
 |         check_indexing('[0:0]', consec((2, 2))) | 
 |         check_indexing('[0:0, 1]', consec((3, 3))) | 
 |  | 
 |         # trivial expression usage | 
 |         check_indexing('[1+1]', consec((3, 3))) | 
 |         check_indexing('[1:(0 + 2)]', consec((3, 3, 3))) | 
 |  | 
 |         # None for new dimensions | 
 |         check_indexing('[None, 0]', consec((3, 3))) | 
 |         check_indexing('[1, None]', consec((3, 3), 10)) | 
 |         check_indexing('[None, None, 2]', consec((3, 3), 19)) | 
 |         check_indexing('[None, 2, None]', consec((3,))) | 
 |         check_indexing('[0:2, None]', consec((3, 3, 3))) | 
 |         check_indexing('[None, 1:-1]', consec((3, 3, 3))) | 
 |         check_indexing('[None, -3:-1, None]', consec((6, 3))) | 
 |         check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3))) | 
 |         check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3))) | 
 |  | 
 |         # dynamic expression usage | 
 |         check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1) | 
 |         check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2) | 
 |  | 
 |         # positive striding | 
 |         check_indexing_list_int('[0]', consec_list(6)) | 
 |         check_indexing_list_int('[1]', consec_list(7)) | 
 |         check_indexing_list_int('[2]', consec_list(8)) | 
 |         check_indexing_list_int('[2]', consec_list(9)) | 
 |         check_indexing_list_int('[-1]', consec_list(10)) | 
 |         check_indexing_list_int('[0:2]', consec_list(11)) | 
 |         check_indexing_list_int('[1:-1]', consec_list(12)) | 
 |         check_indexing_list_int('[-3:-1]', consec_list(13)) | 
 |         check_indexing_list_int('[1:]', consec_list(15)) | 
 |         check_indexing_list_int('[:1]', consec_list(16)) | 
 |         check_indexing_list_int('[:]', consec_list(17)) | 
 |         check_indexing_list_int('[::]', consec_list(0)) | 
 |         check_indexing_list_int('[1000::]', consec_list(0)) | 
 |         check_indexing_list_int('[:1000:]', consec_list(0)) | 
 |  | 
 |         # negative striding | 
 |         check_indexing_list_int('[::-1]', consec_list(7)) | 
 |         check_indexing_list_int('[:3:-1]', consec_list(7)) | 
 |         check_indexing_list_int('[3::-1]', consec_list(7)) | 
 |         check_indexing_list_int('[1000::-1]', consec_list(7)) | 
 |         check_indexing_list_int('[3:0:-1]', consec_list(7)) | 
 |         check_indexing_list_int('[3:-1000:-1]', consec_list(7)) | 
 |         check_indexing_list_int('[0:0:-1]', consec_list(7)) | 
 |         check_indexing_list_int('[0:-1000:-1]', consec_list(7)) | 
 |  | 
 |         # only step is specified | 
 |         check_indexing_list_int('[::-1]', consec_list(0)) | 
 |         check_indexing_list_int('[::-1]', consec_list(7)) | 
 |         check_indexing_list_int('[::-2]', consec_list(7)) | 
 |         check_indexing_list_int('[::2]', consec_list(7)) | 
 |         check_indexing_list_int('[::42]', consec_list(7)) | 
 |         check_indexing_list_int('[::-42]', consec_list(7)) | 
 |         check_indexing_list_int('[::42]', consec_list(0)) | 
 |         check_indexing_list_int('[::-42]', consec_list(0)) | 
 |         check_indexing_list_int('[::9223372036854775807]', consec_list(42)) | 
 |         check_indexing_list_int('[::-9223372036854775807]', consec_list(42)) | 
 |         with self.assertRaisesRegex(RuntimeError, "out of bounds"): | 
 |             check_indexing_list_int('[::-9223372036854775808]', consec_list(42)) | 
 |         with self.assertRaisesRegex(RuntimeError, "should have non-zero step"): | 
 |             check_indexing_list_int('[::0]', consec_list(42)) | 
 |  | 
 |         # striding strings | 
 |         check_indexing_str('[0]', random_string(6)) | 
 |         check_indexing_str('[1]', random_string(7)) | 
 |         check_indexing_str('[2]', random_string(8)) | 
 |         check_indexing_str('[2]', random_string(9)) | 
 |         check_indexing_str('[-1]', random_string(10)) | 
 |         check_indexing_str('[0:2]', random_string(11)) | 
 |         check_indexing_str('[1:-1]', random_string(12)) | 
 |         check_indexing_str('[-3:-1]', random_string(13)) | 
 |         check_indexing_str('[1:]', random_string(15)) | 
 |         check_indexing_str('[:1]', random_string(16)) | 
 |         check_indexing_str('[:]', random_string(17)) | 
 |         check_indexing_str('[::]', random_string(0)) | 
 |         check_indexing_str('[1000::]', random_string(0)) | 
 |         check_indexing_str('[:1000:]', random_string(0)) | 
 |  | 
 |         check_indexing_str('[::-1]', random_string(7)) | 
 |         check_indexing_str('[:3:-1]', random_string(7)) | 
 |         check_indexing_str('[3::-1]', random_string(7)) | 
 |         check_indexing_str('[1000::-1]', random_string(7)) | 
 |         check_indexing_str('[3:0:-1]', random_string(7)) | 
 |         check_indexing_str('[3:-1000:-1]', random_string(7)) | 
 |         check_indexing_str('[0:0:-1]', random_string(7)) | 
 |         check_indexing_str('[0:-1000:-1]', random_string(7)) | 
 |  | 
 |         check_indexing_str('[::-1]', random_string(0)) | 
 |         check_indexing_str('[::-1]', random_string(7)) | 
 |         check_indexing_str('[::-2]', random_string(7)) | 
 |         check_indexing_str('[::2]', random_string(7)) | 
 |         check_indexing_str('[::42]', random_string(7)) | 
 |         check_indexing_str('[::-42]', random_string(7)) | 
 |         check_indexing_str('[::42]', random_string(0)) | 
 |         check_indexing_str('[::-42]', random_string(0)) | 
 |         check_indexing_str('[::9223372036854775807]', random_string(42)) | 
 |         check_indexing_str('[::-9223372036854775807]', random_string(42)) | 
 |         with self.assertRaisesRegex(RuntimeError, "out of bounds"): | 
 |             check_indexing_str('[::-9223372036854775808]', random_string(42)) | 
 |         with self.assertRaisesRegex(RuntimeError, "should have non-zero step"): | 
 |             check_indexing_str('[::0]', random_string(42)) | 
 |  | 
 |     def test_module_copy_with_attributes(self): | 
 |         class Vocabulary(torch.jit.ScriptModule): | 
 |             def __init__(self, vocab_list): | 
 |                 super(Vocabulary, self).__init__() | 
 |                 self._vocab = torch.jit.Attribute(vocab_list, List[str]) | 
 |                 self.some_idx = torch.jit.Attribute(2, int) | 
 |                 self.idx = torch.jit.Attribute( | 
 |                     {word: i for i, word in enumerate(vocab_list)}, Dict[str, int] | 
 |                 ) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def lookup_indices_1d(self, values): | 
 |                 # type: (List[str]) -> List[int] | 
 |                 result = torch.jit.annotate(List[int], []) | 
 |                 # Direct list iteration not supported | 
 |                 for i in range(len(values)): | 
 |                     value = values[i] | 
 |                     result.append(self.idx.get(value, self.some_idx)) | 
 |                 return result | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, values): | 
 |                 # type: (List[List[str]]) -> List[List[int]] | 
 |                 result = torch.jit.annotate(List[List[int]], []) | 
 |                 # Direct list iteration not supported | 
 |                 for i in range(len(values)): | 
 |                     result.append(self.lookup_indices_1d(values[i])) | 
 |                 return result | 
 |  | 
 |         v = Vocabulary(list('uabcdefg')) | 
 |         v.__copy__() | 
 |  | 
 |     def test_tuple_to_opt_list(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             # type: (Optional[List[int]]) -> int | 
 |             return 1 | 
 |  | 
 |         @torch.jit.script | 
 |         def tuple_call(): | 
 |             return foo((1, 2)) | 
 |  | 
 |     def test_keyword(self): | 
 |         @torch.jit.script | 
 |         def func(x): | 
 |             return torch.sum(x, dim=0) | 
 |  | 
 |         x = torch.rand(10, dtype=torch.float, requires_grad=True) | 
 |         y = func(x) | 
 |         y2 = torch.sum(x, dim=0) | 
 |         self.assertEqual(y, y2) | 
 |  | 
 |     def test_constant_pooling_none(self): | 
 |         @torch.jit.script | 
 |         def typed_nones(a=None, b=None, c=None): | 
 |             # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] | 
 |             return a, b, c | 
 |  | 
 |         @torch.jit.script | 
 |         def test(a): | 
 |             # type: (bool) -> None | 
 |             if a: | 
 |                 print(typed_nones()) | 
 |             else: | 
 |                 print(typed_nones()) | 
 |  | 
 |         graph_str = str(test.graph) | 
 |         self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1) | 
 |  | 
 |     def test_constant_pooling_same_identity(self): | 
 |         def foo(): | 
 |             a = torch.tensor([4]) | 
 |             b = (a,) | 
 |             index = len(a) - 1 | 
 |             c = b[index] | 
 |             d = b[index] | 
 |             return c, d | 
 |  | 
 |         foo_script = torch.jit.script(foo) | 
 |         self.run_pass('constant_propagation', foo_script.graph) | 
 |         self.run_pass('constant_pooling', foo_script.graph) | 
 |         # even though the c & d escape scope, we are still able | 
 |         # pool them into one constant because they are the same object | 
 |         FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph) | 
 |         self.assertEqual(foo(), foo_script()) | 
 |  | 
 |     def test_constant_pooling_introduce_aliasing(self): | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             a = torch.tensor(1) | 
 |             b = torch.tensor(1) | 
 |             return a, b | 
 |  | 
 |         self.run_pass('constant_propagation', foo.graph) | 
 |         self.run_pass('constant_pooling', foo.graph) | 
 |         # dont pool constants bc it would introduce observable alias relationship changing | 
 |         a, b = foo() | 
 |         self.assertIsNot(a, b) | 
 |  | 
 |     def test_literal(self): | 
 |         def func1(a, b): | 
 |             c = a, b | 
 |             d, e = c | 
 |             return d + e | 
 |  | 
 |         def func2(a, b): | 
 |             c = a, (a, b) | 
 |             d, e = c | 
 |             f, g = e | 
 |             return d + f + g | 
 |  | 
 |         def func3(a, b): | 
 |             # type: (float, float) -> float | 
 |             c = 0., (0., 0.) | 
 |             x = True | 
 |             while x: | 
 |                 x = False | 
 |                 c = a, (a, b) | 
 |             d, e = c | 
 |             f, g = e | 
 |             return d + f + g | 
 |  | 
 |         a = torch.rand(1, requires_grad=True) | 
 |         b = torch.rand(1, requires_grad=True) | 
 |         self.checkScript(func1, (a, b), optimize=True) | 
 |         self.checkScript(func2, (a, b), optimize=True) | 
 |         self.checkScript(func3, (a.item(), b.item()), optimize=True) | 
 |  | 
 |     def test_expand(self): | 
 |         @torch.jit.script | 
 |         def func(x, y): | 
 |             return x + y | 
 |  | 
 |         x = torch.rand(2, 3, dtype=torch.float, requires_grad=True) | 
 |         y = torch.rand(3, dtype=torch.float, requires_grad=True) | 
 |         out = func(x, y) | 
 |         self.assertEqual(func(x, y), x + y) | 
 |  | 
 |         grad = torch.randn(2, 3, dtype=torch.float) | 
 |         out.backward(grad) | 
 |         self.assertEqual(x.grad, grad) | 
 |         self.assertEqual(y.grad, grad.sum(dim=0)) | 
 |  | 
 |     def test_sum(self): | 
 |         @torch.jit.script | 
 |         def func(x): | 
 |             return x.sum(dim=[4]) | 
 |  | 
 |         @torch.jit.script | 
 |         def func2(x): | 
 |             return x.sum(dim=4) | 
 |  | 
 |         # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument | 
 |         self.run_pass('constant_propagation', func.graph) | 
 |         self.run_pass('constant_propagation', func2.graph) | 
 |         g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False) | 
 |         g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False) | 
 |  | 
 |     def test_cat(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def func(x): | 
 |                 return torch.cat((x, x), dim=0) | 
 |  | 
 |             x = torch.rand(10, dtype=torch.float, requires_grad=True) | 
 |             self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0)) | 
 |  | 
 |             @torch.jit.script | 
 |             def func2(x, y): | 
 |                 return torch.cat((x, x), y) | 
 |  | 
 |             with disable_autodiff_subgraph_inlining(): | 
 |                 for sizes in ((2, 2), (0, 2)): | 
 |                     x = torch.rand(sizes).requires_grad_() | 
 |                     y = torch.tensor(1) | 
 |  | 
 |                     output = func2(x, y, profile_and_replay=True) | 
 |                     output_ref = torch.cat((x, x), y) | 
 |                     self.assertEqual(output, output_ref) | 
 |  | 
 |                     if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |                         self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], []) | 
 |  | 
 |                         grad = torch.autograd.grad(output.sum(), x) | 
 |                         grad_ref = torch.autograd.grad(output_ref.sum(), x) | 
 |                         self.assertEqual(grad, grad_ref) | 
 |  | 
 |     def test_cat_lifts(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return torch.cat([x, x], dim=1) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo2(x): | 
 |             return torch.cat([], dim=1) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo3(x): | 
 |             return torch.cat([x], dim=1) | 
 |  | 
 |         for g in [foo.graph, foo2.graph, foo3.graph]: | 
 |             FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g)) | 
 |  | 
 |     def test_stack(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def func(x): | 
 |                 return torch.stack((x, x), dim=1) | 
 |             x = torch.rand(10, 10) | 
 |             self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1)) | 
 |  | 
 |             @torch.jit.script | 
 |             def func2(x, y): | 
 |                 return torch.stack((x, y), dim=0) | 
 |  | 
 |             with disable_autodiff_subgraph_inlining(): | 
 |                 x = torch.randn([2, 2]).requires_grad_() | 
 |                 y = torch.randn([2, 2]).requires_grad_() | 
 |  | 
 |                 output = func2(x, y, profile_and_replay=True) | 
 |                 output_ref = torch.stack((x, y), 0) | 
 |                 self.assertEqual(output, output_ref) | 
 |                 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |                     self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], []) | 
 |  | 
 |                     grads = torch.autograd.grad(output.sum(), (x, y)) | 
 |                     grads_ref = torch.autograd.grad(output_ref.sum(), (x, y)) | 
 |                     self.assertEqual(grads, grads_ref) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, | 
 |                      "Profiling executor will be using different heuristics for constructing differentiable graphs") | 
 |     def test_unbind(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def func(x, y): | 
 |                 # type: (Tensor, int) -> List[Tensor] | 
 |                 return torch.unbind(x, y) | 
 |  | 
 |             with disable_autodiff_subgraph_inlining(): | 
 |                 x = torch.rand([2, 2]).requires_grad_() | 
 |                 y = 0 | 
 |                 outputs = func(x, y, profile_and_replay=True) | 
 |                 outputs_ref = torch.unbind(x, dim=y) | 
 |                 self.assertEqual(outputs, outputs_ref) | 
 |                 self.assertAutodiffNode(func.graph_for(x, y), True, [], []) | 
 |  | 
 |                 grad = torch.autograd.grad(_sum_of_list(outputs), x) | 
 |                 grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x) | 
 |                 self.assertEqual(grad, grad_ref) | 
 |  | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, | 
 |                      "Profiling executor fails to recognize that tensors in a list require gradients") | 
 |     def test_meshgrid(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             @torch.jit.script | 
 |             def func(a): | 
 |                 # type: (List[Tensor]) -> List[Tensor] | 
 |                 return torch.meshgrid(a) | 
 |             with disable_autodiff_subgraph_inlining(): | 
 |                 a = torch.tensor([1.0, 2, 3]).requires_grad_() | 
 |                 b = torch.tensor([1.0, 2, 3, 4]).requires_grad_() | 
 |                 inputs = [a, b] | 
 |  | 
 |                 outputs_ref = torch.meshgrid(inputs) | 
 |                 outputs = func(inputs, profile_and_replay=True) | 
 |                 self.assertEqual(outputs, outputs_ref) | 
 |  | 
 |                 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |                     self.assertAutodiffNode(func.graph_for(inputs), True, [], []) | 
 |  | 
 |                     grads = torch.autograd.grad(_sum_of_list(outputs), inputs) | 
 |                     grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs) | 
 |                     self.assertEqual(grads, grads_ref) | 
 |  | 
 |     def test_tensor_len(self): | 
 |         def func(x): | 
 |             return len(x) | 
 |  | 
 |         self.checkScript(func, [torch.ones(4, 5, 6)]) | 
 |  | 
 |     def test_func_call(self): | 
 |         def add(a, b): | 
 |             return a + b | 
 |  | 
 |         def mul(a, x): | 
 |             return a * x | 
 |  | 
 |         def func(alpha, beta, x, y): | 
 |             return add(mul(alpha, x), mul(beta, y)) | 
 |  | 
 |         alpha = torch.rand(1, dtype=torch.float, requires_grad=True) | 
 |         beta = torch.rand(1, dtype=torch.float, requires_grad=True) | 
 |         x = torch.rand(3, dtype=torch.float, requires_grad=True) | 
 |         y = torch.rand(3, dtype=torch.float, requires_grad=True) | 
 |  | 
 |         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs | 
 |         self.checkScript(func, [alpha, beta, x, y], optimize=False) | 
 |  | 
 |     @unittest.skip("bailouts are being deprecated") | 
 |     def test_profiling_graph_executor(self): | 
 |         @torch.jit.script | 
 |         def def_in_one_branch(x, z): | 
 |             # type: (Tensor, bool) -> float | 
 |             y = x | 
 |             if z is False: | 
 |                 y = x + 1 | 
 |  | 
 |             return y.sum() | 
 |  | 
 |         a = torch.rand(2, 3) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             # check prim::profile are inserted | 
 |             profiled_graph_str = str(def_in_one_branch.graph_for(a, True)) | 
 |             FileCheck().check_count("prim::profile", 4).run(profiled_graph_str) | 
 |             # this call is optimized for | 
 |             # the given shape of (2, 3) | 
 |             def_in_one_branch(a, False) | 
 |             # change shape to (3) | 
 |             # so we go down a bailout path | 
 |             a = torch.ones(3) | 
 |             # check prim::BailOuts are inserted | 
 |             bailout_graph_str = str(def_in_one_branch.graph_for(a, True)) | 
 |             FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str) | 
 |             # this triggers all 3 bailouts | 
 |             self.assertEqual(def_in_one_branch(a, False), 6.0) | 
 |             # this triggers 2 bailouts | 
 |             self.assertEqual(def_in_one_branch(a, True), 3.0) | 
 |  | 
 |     @unittest.skip("bailouts are being deprecated") | 
 |     def test_maxpool_guard_elimination(self): | 
 |         @torch.jit.script | 
 |         def my_maxpool(x): | 
 |             return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32]) | 
 |  | 
 |         a = torch.rand(32, 32, 32) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             my_maxpool(a) | 
 |             bailout_graph_str = str(my_maxpool.graph_for(a)) | 
 |             FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str) | 
 |  | 
 |     @unittest.skip("bailouts are being deprecated") | 
 |     def test_slice_guard_elimination(self): | 
 |         @torch.jit.script | 
 |         def my_slice(x): | 
 |             return x[0:16:2] + x[0:16:2] | 
 |  | 
 |         a = torch.rand(32, 4) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             my_slice(a) | 
 |             bailout_graph_str = str(my_slice.graph_for(a)) | 
 |             FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str) | 
 |  | 
 |     @unittest.skip("bailouts are being deprecated") | 
 |     def test_unsqueeze_guard_elimination(self): | 
 |         @torch.jit.script | 
 |         def my_unsqueeze(x): | 
 |             return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0) | 
 |  | 
 |         a = torch.rand(32, 4) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             my_unsqueeze(a) | 
 |             bailout_graph_str = str(my_unsqueeze.graph_for(a)) | 
 |             FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str) | 
 |  | 
 |     def test_resize_input_ops(self): | 
 |         # resize_ and resize_as resize the input tensor. because our shape analysis | 
 |         # is flow invariant, we set any Tensor that can alias a resized Tensor | 
 |         # to the base Tensor Type, without size information. | 
 |  | 
 |         # testing that value which is an input of a graph gets handled | 
 |         def out_op_graph_input(): | 
 |             @torch.jit.script | 
 |             def test(x, y, z): | 
 |                 torch.mul(x, y, out=z) | 
 |                 return z | 
 |  | 
 |             graph = _propagate_shapes(test.graph, | 
 |                                       (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False) | 
 |             self.assertTrue(next(graph.outputs()).type() == TensorType.get()) | 
 |         out_op_graph_input() | 
 |  | 
 |         def test_resize(): | 
 |             @torch.jit.script | 
 |             def test(x): | 
 |                 after_resize_alias = torch.zeros([2]) | 
 |                 for _i in range(5): | 
 |                     b = x + 1 | 
 |                     f = [1] | 
 |                     before_resize_alias = b.sub_(1) | 
 |                     # for i in range(10): | 
 |                     f.append(1) | 
 |                     b.resize_(f) | 
 |                     after_resize_alias = b.add_(1) | 
 |                 return after_resize_alias | 
 |  | 
 |             self.run_pass('constant_propagation', test.graph) | 
 |             g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) | 
 |             resize_node = g.findNode("aten::resize_") | 
 |             # first input and output of b.resize_ is b | 
 |             self.assertTrue(next(resize_node.inputs()).type() == TensorType.get()) | 
 |             self.assertTrue(next(resize_node.outputs()).type() == TensorType.get()) | 
 |  | 
 |             # correctly propagates to b alias set | 
 |             before_resize = g.findNode("aten::sub_") | 
 |             self.assertTrue(next(before_resize.outputs()).type() == TensorType.get()) | 
 |  | 
 |             after_resize = g.findNode("aten::add_") | 
 |             self.assertTrue(next(after_resize.outputs()).type() == TensorType.get()) | 
 |  | 
 |         test_resize() | 
 |  | 
 |         def test_resize_as(): | 
 |             @torch.jit.script | 
 |             def test(x): | 
 |                 b = torch.zeros([2, 2]) | 
 |                 b.resize_as_(x) | 
 |                 return b | 
 |  | 
 |             g = test.graph | 
 |             self.run_pass('constant_propagation', g) | 
 |             g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) | 
 |  | 
 |             # x doesn't alias a resized op so it shouldn't be set to base Tensor type | 
 |             self.assertTrue(next(g.inputs()).type() != TensorType.get()) | 
 |             # return is resized | 
 |             self.assertTrue(next(g.outputs()).type() == TensorType.get()) | 
 |  | 
 |         test_resize_as() | 
 |  | 
 |     def test_uninitialized(self): | 
 |         graph_str = """graph(): | 
 |           %1 : int = prim::Uninitialized() | 
 |           %2 : int = prim::Constant[value=1]() | 
 |           %3 : int = aten::add(%1, %2) | 
 |           return (%3) | 
 |         """ | 
 |         g = parse_ir(graph_str) | 
 |         m = self.createFunctionFromGraph(g) | 
 |         self.getExportImportCopy(m) | 
 |         with self.assertRaisesRegex(RuntimeError, "isInt"): | 
 |             m() | 
 |  | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information") | 
 |     @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled") | 
 |     def test_requires_grad_loop(self): | 
 |         @torch.jit.script | 
 |         def test(x, y, z): | 
 |             # type: (Tensor, Tensor, int) -> Tensor | 
 |             for _ in range(z): | 
 |                 x = y | 
 |             return x | 
 |  | 
 |         # x requires grad, y does not | 
 |         # testing that requires grad analysis correctly exits, with its input | 
 |         # to the loop (x) requiring grad and its output to the loop not requiring grad | 
 |         # and the output of the node conservatively setting grad to true | 
 |  | 
 |         inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10) | 
 |         test(*inps, profile_and_replay=True) | 
 |  | 
 |         graph = test.graph_for(*inps) | 
 |         loop = graph.findNode("prim::Loop") | 
 |         loop_body = next(loop.blocks()) | 
 |         loop_inputs = list(loop_body.inputs()) | 
 |         loop_outputs = list(loop_body.outputs()) | 
 |  | 
 |         if GRAPH_EXECUTOR == ProfilingMode.PROFILING: | 
 |             # TODO: simplify this test as it's very sensitive | 
 |             # the optimized graph will have 3 loops | 
 |             # the original loop is peeled | 
 |             # peeled loop also gets unrolled | 
 |             index_of_x_in_peeled_unrolled_loop = -2 | 
 |             self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad()) | 
 |             bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False) | 
 |             last_bailout_index_on_loops_output = -1 | 
 |             self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad()) | 
 |         else: | 
 |             self.assertTrue(loop_inputs[1].requires_grad()) | 
 |             self.assertTrue(loop.output().requires_grad()) | 
 |             self.assertFalse(loop_outputs[1].requires_grad()) | 
 |  | 
 |     def test_view_shape_prop(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |         def test_view_shape_prop(a): | 
 |             return a.view(size=[-1]) | 
 |         ''') | 
 |         inputs = [torch.zeros(10, 10)] | 
 |         outputs = torch.zeros(100) | 
 |  | 
 |         real_outs = cu.test_view_shape_prop(*inputs) | 
 |         self.assertEqual(real_outs, outputs) | 
 |  | 
 |     def test_view_listconstruct_shape_prop(self): | 
 |         def fn(x): | 
 |             B = x.size(0) | 
 |             C = x.size(1) | 
 |             T = x.size(2) | 
 |             return x.view(T, B, C) | 
 |  | 
 |         x = torch.randn(3, 1, 5, requires_grad=True) | 
 |         fn = torch.jit.script(fn) | 
 |         graph = _propagate_shapes(fn.graph, (x,), False) | 
 |         self.assertTrue(next(graph.outputs()).type().scalarType() == 'Double') | 
 |  | 
 |     def test_shape_prop_promotion(self): | 
 |         @torch.jit.script | 
 |         def fn(x, y): | 
 |             return x + y | 
 |  | 
 |         x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double) | 
 |         graph = _propagate_shapes(fn.graph, (x, y), False) | 
 |         FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph) | 
 |  | 
 |     def test_shape_prop_promote_scalar_arg(self): | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return math.pi + x | 
 |  | 
 |         x = torch.zeros(3, 4, dtype=torch.long) | 
 |         graph = _propagate_shapes(fn.graph, (x,), False) | 
 |         default = torch.get_default_dtype() | 
 |         if(default == torch.float): | 
 |             FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) | 
 |         else: | 
 |             FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) | 
 |  | 
 |     def test_integral_shape_inference(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |         def test_integral_shape_inference(a): | 
 |             return a * a | 
 |         ''') | 
 |         inputs = [torch.ones(10, 10, dtype=torch.long)] | 
 |         outputs = torch.ones(10, 10, dtype=torch.long) | 
 |  | 
 |         self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs) | 
 |  | 
 |     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') | 
 |     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") | 
 |     @enable_cpu_fuser | 
 |     def test_batchnorm_fuser_cpu(self): | 
 |         code = ''' | 
 |             graph(%3 : Tensor, | 
 |                   %7 : Tensor, | 
 |                   %12 : Float(*, *), | 
 |                   %13 : Tensor, | 
 |                   %25 : Tensor): | 
 |                 %23 : int = prim::Constant[value=1]() | 
 |                 %22 : float = prim::Constant[value=1e-05]() | 
 |                 %26 : Tensor = aten::sqrt(%25) | 
 |                 %24 : Tensor = aten::add(%26, %22, %23) | 
 |                 %20 : Tensor = aten::reciprocal(%24) | 
 |                 %norm_invstd : Tensor = aten::mul(%20, %23) | 
 |                 %15 : Tensor = aten::sub(%12, %13, %23) | 
 |                 %11 : Tensor = aten::mul(%15, %norm_invstd) | 
 |                 %8 : Tensor = aten::mul(%11, %7) | 
 |                 %5 : Tensor = aten::add(%8, %3, %23) | 
 |                 %1 : Float(*, *) = aten::relu(%5) | 
 |                 return (%1) | 
 |         ''' | 
 |  | 
 |         graph = parse_ir(code) | 
 |         inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)] | 
 |         code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs) | 
 |         FileCheck().check('sqrtf').run(code) | 
 |  | 
 |     @slowTest | 
 |     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') | 
 |     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") | 
 |     @enable_cpu_fuser | 
 |     def test_fuser_double_float_codegen(self): | 
 |         fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf', | 
 |                'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan', | 
 |                'atan', 'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc', | 
 |                'frac'] | 
 |  | 
 |         def lookup_c_equivalent_fn(aten_fn): | 
 |             return aten_fn | 
 |  | 
 |         def test_dispatch(op, expects, dtype, binary=False): | 
 |             if dtype == torch.double: | 
 |                 dtype_str = 'Double' | 
 |             elif dtype == torch.float: | 
 |                 dtype_str = 'Float' | 
 |             else: | 
 |                 raise RuntimeError('Unknown dtype') | 
 |  | 
 |             if binary: | 
 |                 code = ''' | 
 |                     graph(%3 : Tensor, %4 : Tensor): | 
 |                         %2 : {dtype}(*, *) = aten::{op}(%3, %4) | 
 |                         %1 : {dtype}(*, *) = aten::relu(%2) | 
 |                         return (%1) | 
 |                 '''.format(op=op, dtype=dtype_str) | 
 |             else: | 
 |                 code = ''' | 
 |                     graph(%3 : Tensor): | 
 |                         %2 : {dtype}(*, *) = aten::{op}(%3) | 
 |                         %1 : {dtype}(*, *) = aten::relu(%2) | 
 |                         return (%1) | 
 |                 '''.format(op=op, dtype=dtype_str) | 
 |  | 
 |             graph = parse_ir(code) | 
 |             inputs = (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)] | 
 |             code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs) | 
 |             FileCheck().check(expects).run(code) | 
 |  | 
 |         for fn in fns: | 
 |             test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double) | 
 |             test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float) | 
 |  | 
 |         # 'min', 'max' were previously tested but are now replaced with ternary expressions | 
 |         # instead of fmin() and fmax() | 
 |         binary_fns = ['pow'] | 
 |         for fn in binary_fns: | 
 |             test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True) | 
 |             test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True) | 
 |  | 
 |     @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') | 
 |     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") | 
 |     @enable_cpu_fuser | 
 |     def test_fuser_double_literal_precision(self): | 
 |         code = ''' | 
 |         graph(%2 : Float(*, *)): | 
 |             %4 : int = prim::Constant[value=1]() | 
 |             %3 : float = prim::Constant[value=1.282549830161864]() | 
 |             %5 : Float(*, *) = aten::add(%2, %3, %4) | 
 |             %1 : Float(*, *) = aten::relu(%5) | 
 |             return (%1) | 
 |         ''' | 
 |  | 
 |         graph = parse_ir(code) | 
 |         code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)]) | 
 |         FileCheck().check('1.282549830161864').run(code) | 
 |  | 
 |     def test_fuser_multiple_blocks(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |         def test_fuser_multiple_blocks(this, that, theother, meme): | 
 |             i = 0 | 
 |             while i < 20: | 
 |                 this = torch.cat([this, meme], dim=0) | 
 |                 that = torch.cat([that, meme], dim=0) | 
 |                 theother = torch.cat([theother, meme], dim=0) | 
 |                 i = i + 1 | 
 |             return this, that, theother | 
 |         ''') | 
 |  | 
 |         inputs = [torch.ones(0, 10, 10)] * 3 | 
 |         inputs += [torch.ones(1, 10, 10)] | 
 |         outputs = [torch.ones(20, 10, 10)] * 3 | 
 |  | 
 |         self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs) | 
 |  | 
 |     def test_dropout_script(self): | 
 |  | 
 |         eg = torch.zeros(1, 2, 3, requires_grad=True) | 
 |  | 
 |         @_trace(eg) | 
 |         def foo(x): | 
 |             x = torch.neg(x) | 
 |             return F.dropout(x) | 
 |  | 
 |         class MyDrop(nn.Module): | 
 |             def forward(self, x): | 
 |                 return foo(x) | 
 |  | 
 |         f = io.BytesIO() | 
 |         with warnings.catch_warnings(record=True): | 
 |             torch.onnx.export(MyDrop(), (eg,), f, verbose=False) | 
 |  | 
 |     @unittest.skip("RuntimeError: VariableType::ID() not implemented") | 
 |     def test_cast(self): | 
 |         script = ''' | 
 |         def to_int(x): | 
 |             return int(x) | 
 |         ''' | 
 |         x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True) | 
 |         out = Variable(torch.IntTensor([1, 2]), requires_grad=True) | 
 |         self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int') | 
 |  | 
 |     def test_str_cast(self): | 
 |         @torch.jit.script | 
 |         def to_str(x): | 
 |             # type: (int) -> str | 
 |             return str((x, x)) | 
 |  | 
 |         self.assertEqual("(1, 1)", to_str(1)) | 
 |  | 
 |     def test_int_cast(self): | 
 |         @torch.jit.script | 
 |         def to_int(x): | 
 |             # type: (str) -> int | 
 |             return int(x) | 
 |  | 
 |         self.assertEqual(5, to_int('5')) | 
 |         self.assertEqual(-5, to_int('-5')) | 
 |         self.assertEqual(2147483647, to_int('2147483647')) | 
 |         self.assertEqual(-2147483648, to_int('-2147483648')) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"): | 
 |             to_int('0x20') | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"): | 
 |             to_int('0b0001') | 
 |  | 
 |     def test_python_frontend(self): | 
 |         def fn(x, y, z): | 
 |             q = None | 
 |             q = x + y - z.sigmoid() | 
 |             print(q) | 
 |             w = -z | 
 |             if not x and not y and z: | 
 |                 m = x if not z else y | 
 |             while x < y > z: | 
 |                 q = x | 
 |             assert 1 == 1, "hello" | 
 |             return x | 
 |  | 
 |         ast = torch.jit.frontend.get_jit_def(fn, fn.__name__) | 
 |         self.assertExpected(str(ast)) | 
 |  | 
 |     def test_python_frontend_source_range(self): | 
 |         def fn(): | 
 |             raise Exception("hello") | 
 |         ast = torch.jit.frontend.get_jit_def(fn, fn.__name__) | 
 |         FileCheck().check("SourceRange at:") \ | 
 |                    .check("def fn():") \ | 
 |                    .check("~~~~~~~~~") \ | 
 |                    .check('raise Exception("hello")') \ | 
 |                    .check('~~~~~~~~~~~~~~~~~ <--- HERE') \ | 
 |                    .run(str(ast.range())) | 
 |  | 
 |     def test_python_frontend_py3(self): | 
 |         def fn(): | 
 |             raise Exception("hello") | 
 |         ast = torch.jit.frontend.get_jit_def(fn, fn.__name__) | 
 |         self.assertExpected(str(ast)) | 
 |  | 
 |     def _make_scalar_vars(self, arr, dtype): | 
 |         return [torch.tensor(val, dtype=dtype) for val in arr] | 
 |  | 
 |  | 
 |     def test_string_print(self): | 
 |         def func(a): | 
 |             print(a, "a" 'b' '''c''' """d""", 2, 1.5) | 
 |             return a | 
 |  | 
 |         inputs = self._make_scalar_vars([1], torch.int64) | 
 |         self.checkScript(func, inputs, capture_output=True) | 
 |  | 
 |     def test_while(self): | 
 |         def func(a, b, max): | 
 |             while bool(a < max): | 
 |                 a = a + 1 | 
 |                 b = b + 1 | 
 |             c = a + b | 
 |             return c | 
 |  | 
 |         inputs = self._make_scalar_vars([1, 1, 10], torch.int64) | 
 |         self.checkScript(func, inputs, optimize=True) | 
 |  | 
 |     def test_fibb(self): | 
 |         def func(lim): | 
 |             first = 1 | 
 |             second = 1 | 
 |             i = 1 | 
 |             somenum = 5 | 
 |             dontmutateme = 3 | 
 |             third = 0 | 
 |             while bool(i < lim): | 
 |                 third = first + second | 
 |                 first = second | 
 |                 second = third | 
 |                 j = 0 | 
 |                 while j < 10: | 
 |                     somenum = somenum * 2 | 
 |                     j = j + 1 | 
 |                 i = i + j | 
 |                 i = i + dontmutateme | 
 |  | 
 |             st = second + third | 
 |             fs = first + second | 
 |             return third, st, fs | 
 |  | 
 |         inputs = self._make_scalar_vars([10], torch.int64) | 
 |         self.checkScript(func, inputs, optimize=True) | 
 |  | 
 |     def test_fibb_totally_better(self): | 
 |         def fib(x): | 
 |             # type: (int) -> int | 
 |             prev = 1 | 
 |             v = 1 | 
 |             for i in range(0, x): | 
 |                 save = v | 
 |                 v = v + prev | 
 |                 prev = save | 
 |             return v | 
 |  | 
 |         self.checkScript(fib, (10,)) | 
 |  | 
 |     def test_if(self): | 
 |         def func(a, b): | 
 |             # type: (int, int) -> int | 
 |             d = 3 | 
 |             if bool(a > 10): | 
 |                 a = 3 + d | 
 |             else: | 
 |                 b = 3 + d | 
 |                 d = 4 | 
 |             c = a + b | 
 |             return c | 
 |  | 
 |         inputs = self._make_scalar_vars([1, -1], torch.int64) | 
 |         self.checkScript(func, inputs, optimize=True) | 
 |  | 
 |     def test_if_for_in_range(self): | 
 |         def func(a, b): | 
 |             # type: (int, int) -> int | 
 |             d = 3 | 
 |             for _ in range(20): | 
 |                 if bool(a > 10): | 
 |                     a = 3 + d | 
 |                 else: | 
 |                     b = 3 + d | 
 |                     d = 4 | 
 |                 c = a + b | 
 |             return d | 
 |         inputs = self._make_scalar_vars([1, -1], torch.int64) | 
 |         self.checkScript(func, inputs, optimize=True) | 
 |  | 
 |     def test_if_noelse(self): | 
 |         def func(a, b): | 
 |             if bool(a > 10): | 
 |                 a = 3 + b | 
 |             c = a + b | 
 |             return c | 
 |  | 
 |         inputs = self._make_scalar_vars([-1, 1], torch.int64) | 
 |         self.checkScript(func, inputs, optimize=True) | 
 |  | 
 |     def test_if_is_none_dispatch(self): | 
 |  | 
 |         @torch.jit.script | 
 |         def test_lhs_none_rhs_none(): | 
 |             # LHS, RHS both alwaysNone, dispatch always_none_branch | 
 |             # only emit one prim::Constant | 
 |             if None is None: | 
 |                 return 1 | 
 |             elif None is not None: | 
 |                 return 2 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |         self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_lhs_opt_rhs_none(lhs=None): | 
 |             # type: (Optional[Tensor]) -> int | 
 |             # LHS maybeNone: emit normal if stmt that contains 3 constants | 
 |             if lhs is not None: | 
 |                 return 2 | 
 |             elif lhs is None: | 
 |                 return 1 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |         self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_lhs_none_rhs_opt(rhs=None): | 
 |             # type: (Optional[Tensor]) -> int | 
 |             # RHS maybeNone, emit normal if stmt that contains 3 constants | 
 |             if None is rhs: | 
 |                 return 1 | 
 |             elif None is not rhs: | 
 |                 return 2 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |         self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_lhs_never_rhs_none(lhs): | 
 |             # LHS neverNone, RHS alwaysNone dispatch never_none_branch | 
 |             # only emit one prim::Constant | 
 |             if lhs is None: | 
 |                 return 1 | 
 |             elif lhs is not None: | 
 |                 return 2 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |         self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_lhs_none_rhs_never(rhs): | 
 |             # LHS alwaysNone, RHS neverNone dispatch never_none_branch | 
 |             # only emit one prim::Constant | 
 |             if None is rhs: | 
 |                 return 1 | 
 |             elif None is not rhs: | 
 |                 return 2 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |         self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_bool_arith_and(lhs): | 
 |             if lhs is None and lhs is not None: | 
 |                 return 1 | 
 |             else: | 
 |                 return 2 | 
 |         self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2) | 
 |         self.assertTrue(str(test_bool_arith_and.graph).count('if') == 0) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_bool_arith_or(lhs): | 
 |             if lhs is None or lhs is not None: | 
 |                 return 1 | 
 |             else: | 
 |                 return 2 | 
 |         self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1) | 
 |         self.assertTrue(str(test_bool_arith_or.graph).count('if') == 0) | 
 |  | 
 |  | 
 |         @torch.jit.script | 
 |         def test_bool_arith_not(lhs): | 
 |             if not (lhs is None): | 
 |                 return 1 | 
 |             else: | 
 |                 return 2 | 
 |         self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1) | 
 |         self.assertTrue(str(test_bool_arith_not.graph).count('if') == 0) | 
 |  | 
    def test_conditional_casting(self):
        """Exercise implicit bool conversion of ``if``/``not`` conditions in TorchScript.

        Tensor, int and float conditions are run through ``self.checkScript``;
        a multi-element tensor and tuple conditions are expected to raise the
        errors asserted below.
        """
        def test_bool_cast_tensor(x):
            if x:
                return 1
            else:
                return 0

        # Feed each value both as a 0-dim tensor and wrapped in a one-element list.
        for make_one_dim in [True, False]:
            for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
                inp_val = [inp_val] if make_one_dim else inp_val
                self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))

        # A two-element tensor is expected to be rejected as a condition.
        self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
                                    "Boolean value of Tensor with more than one value")

        def test_not_cast(x):
            if not x:
                return 1
            else:
                return 0

        self.checkScript(test_not_cast, (torch.tensor(1),))
        self.checkScript(test_not_cast, (torch.tensor(0),))

        # `not(x, y)` applies `not` to a tuple; scripting it must fail.
        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
            @torch.jit.script
            def test_mult(x, y):
                return not(x, y)

        def test_cast_int(x):
            # type: (int) -> int
            if x:
                return 1
            else:
                return 0
        self.checkScript(test_cast_int, (1,))
        self.checkScript(test_cast_int, (0,))
        self.checkScript(test_cast_int, (-1,))

        def test_cast_float(x):
            # type: (float) -> int
            if x:
                return 1
            else:
                return 0
        self.checkScript(test_cast_float, (1.,))
        self.checkScript(test_cast_float, (0.,))
        self.checkScript(test_cast_float, (-1.,))

        # A tuple literal used directly as an `if` condition must also be rejected.
        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"):  # noqa: W605

            @torch.jit.script
            def test_bad_conditional(x):
                if (1, 2):  # noqa: F634
                    return
                else:
                    return 0
 |  | 
    def test_while_nonexistent_value(self):
        """Compiling a ``while`` loop whose body reads the undefined name ``x`` must fail."""
        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
            torch.jit.CompilationUnit('''
            def test_while(a, b):
                while bool(a < 10):
                    a = a + x
                    b = b + 1
                return a + b
            ''')
 |  | 
    def test_while_nonexistent_cond_value(self):
        """Check an undefined name in a ``while`` condition, then Optional refinement cases.

        The first part expects compilation to fail because the loop condition
        reads the undefined name ``x``.  The remainder scripts a series of
        functions taking ``Optional[int]`` arguments: the first group only has
        to compile, the last group is expected to fail with
        "Arguments for call are not valid".
        """
        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
            torch.jit.CompilationUnit('''
            def test_while(a, b):
                while a < x:
                    a = a + 1
                    b = b + 1
                return a + b
            ''')

        # The functions below are only scripted, never called: successful
        # compilation is the whole check.
        @torch.jit.script
        def test_ternary(x):
            # type: (Optional[int]) -> int
            x = x if x is not None else 2
            return x

        @torch.jit.script
        def test_not_none(x):
            # type: (Optional[int]) -> None
            if x is not None:
                print(x + 1)

        @torch.jit.script
        def test_and(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is not None and y is not None:
                print(x + y)

        @torch.jit.script
        def test_not(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if not (x is not None and y is not None):
                pass
            else:
                print(x + y)

        @torch.jit.script
        def test_bool_expression(x):
            # type: (Optional[int]) -> None
            if x is not None and x < 2:
                print(x + 1)

        @torch.jit.script
        def test_nested_bool_expression(x, y):
            # type: (Optional[int], Optional[int]) -> int
            if x is not None and x < 2 and y is not None:
                x = x + y
            else:
                x = 5
            return x + 2

        @torch.jit.script
        def test_or(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if y is None or x is None:
                pass
            else:
                print(x + y)

        # backwards compatibility
        @torch.jit.script
        def test_manual_unwrap_opt(x):
            # type: (Optional[int]) -> int
            if x is None:
                x = 1
            else:
                x = torch.jit._unwrap_optional(x)
            return x  # noqa: T484

        # From here on, each snippet uses an Optional[int] arithmetically in a
        # branch where the condition does not prove it non-None, so scripting
        # is expected to fail.
        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
            @torch.jit.script
            def or_error(x, y):
                # type: (Optional[int], Optional[int]) -> None
                if x is None or y is None:
                    print(x + y)  # noqa: T484

        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
            @torch.jit.script
            def and_error(x, y):
                # type: (Optional[int], Optional[int]) -> None
                if x is None and y is None:
                    pass
                else:
                    print(x + y)  # noqa: T484

        # The None check is stored in a named bool before being tested.
        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
            @torch.jit.script
            def named_var(x):
                # type: (Optional[int]) -> None
                x_none = x is not None
                if x_none:
                    print(x + 1)  # noqa: T484

        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
            @torch.jit.script
            def named_var_and(x, y):
                # type: (Optional[int], Optional[int]) -> None
                x_none = x is not None
                if y is not None and x_none:
                    print(x + y)  # noqa: T484
 |  | 
    def test_assertion_optional_refinement(self):
        """An ``assert ... is not None`` on both Optional[int] args allows ``x + y`` to be scripted."""
        @torch.jit.script
        def test(x, y):
            # type: (Optional[int], Optional[int]) -> int
            assert x is not None and y is not None
            return x + y

        self.assertEqual(test(2, 2), 4)
        # The empty pattern matches any message, so only the exception type is checked.
        with self.assertRaisesRegex(Exception, ""):
            test(1, None)
 |  | 
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
    def test_optional_tensor(self):
        """Inspect the last optimized graph after calling with ``None`` vs. a tensor for an Optional[Tensor] arg.

        Skipped unless ``GRAPH_EXECUTOR`` is ``ProfilingMode.LEGACY``.
        """
        @torch.jit.script
        def fn(x, y):
            # type: (Optional[Tensor], int) -> int
            if x is None:
                return y
            else:
                return 0

        # Call with None: the first graph input keeps its Optional type and has no uses.
        res = fn(None, 1)
        self.assertEqual(res, 1)
        g = torch.jit.last_executed_optimized_graph()
        first_input = next(g.inputs())
        # check if input is disconnected
        self.assertEqual(first_input.type().kind(), 'OptionalType')
        self.assertEqual(first_input.uses(), [])
        # Call with a tensor: the first graph input is now typed as a tensor.
        t = torch.ones(1)
        res = fn(t, 1)
        self.assertEqual(res, 0)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')

        # `fn` is rebound to a second script that unwraps `x` only when `b` is False.
        @torch.jit.script
        def fn(x, y, b):
            # type: (Optional[Tensor], Tensor, bool) -> Tensor
            if b:
                res = y
            else:
                res = torch.jit._unwrap_optional(x)
            return res

        t2 = torch.zeros(1)
        res = fn(t, t2, True)
        self.assertEqual(res, t2)
        # Unwrapping a None optional must raise at run time.
        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
            res = fn(None, t2, False)
        res = fn(None, t2, True)
        g = torch.jit.last_executed_optimized_graph()
        self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))
 |  | 
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
    def test_optional_list(self):
        """Inspect the last optimized graph after calling with ``None`` vs. a list for an Optional[List[int]] arg.

        Skipped unless ``GRAPH_EXECUTOR`` is ``ProfilingMode.LEGACY``.
        """
        @torch.jit.script
        def fn(x, y):
            # type: (Optional[List[int]], int) -> int
            if x is None:
                return y
            else:
                res = 0
                for d in x:
                    res += d
                return res

        # Call with None: the first graph input keeps its Optional type and has no uses.
        res = fn(None, 1)
        self.assertEqual(res, 1)
        g = torch.jit.last_executed_optimized_graph()
        first_input = next(g.inputs())
        # check if input is disconnected
        self.assertEqual(first_input.type().kind(), 'OptionalType')
        self.assertEqual(first_input.uses(), [])
        # Call with a real list: the elements are summed and the input is typed as a list.
        l = [2, 3]
        res = fn(l, 1)
        self.assertEqual(res, 5)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(next(g.inputs()).type().kind(), 'ListType')

        # `fn` is rebound to a second script that unwraps `x` only when `b` is True.
        @torch.jit.script
        def fn(x, y, b):
            # type: (Optional[List[int]], List[int], bool) -> List[int]
            if b:
                l = torch.jit._unwrap_optional(x)
            else:
                l = y
            return l

        l2 = [0, 1]
        res = fn(l, l2, True)
        self.assertEqual(res, l)
        # Unwrapping a None optional must raise at run time.
        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
            res = fn(None, l2, True)
        res = fn(None, l2, False)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(next(g.outputs()).type().str(), "int[]")
 |  | 
    def test_alias_covariant_type_containers(self):
        """Script functions whose branches bind one-element tuples holding ``None`` vs. a list.

        There are no assertions; the test passes if both functions compile.
        """
        @torch.jit.script
        def foo(x):
            # type: (bool)
            if x:
                a = (None,)
            else:
                a = ([],)
            return a

        @torch.jit.script
        def foo2(x, li):
            # type: (bool, Tuple[Optional[List[Tensor]]])
            if x:
                li = (None,)
            return li
 |  | 
    def test_while_write_outer_then_read(self):
        """Script a ``while`` loop that reassigns both arguments, which are read again after the loop."""
        def func(a, b):
            while bool(a < 10):
                a = a + 1
                b = a + 1
            return a + b

        # NOTE(review): with a == 42 the loop condition is already false on
        # entry, assuming _make_scalar_vars passes the values through unchanged.
        inputs = self._make_scalar_vars([42, 1337], torch.int64)
        self.checkScript(func, inputs, optimize=True)
 |  | 
    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_while_nest_if(self):
        """Script an ``if``/``else`` nested inside a ``while`` loop over int arguments."""
        def func(a, b):
            # type: (int, int) -> int
            c = 0
            while a < 10:
                a = a + 1
                b = b + 1
                if a > b:
                    c = -a
                else:
                    c = -b
            return c + 1

        inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
        self.checkScript(func, inputs, optimize=True)
 |  | 
    def test_divmod(self):
        """Script ``divmod`` for every int/float argument combination, including division by zero."""
        def func_int(a, b):
            # type: (int, int) -> Tuple[int, int]
            return divmod(a, b)

        def func_float(a, b):
            # type: (float, float) -> Tuple[float, float]
            return divmod(a, b)

        def func_int_float(a, b):
            # type: (int, float) -> Tuple[float, float]
            return divmod(a, b)

        def func_float_int(a, b):
            # type: (float, int) -> Tuple[float, float]
            return divmod(a, b)

        def divmod_test_iterator(func, num, den):
            # Run `func` over the full cross product of numerators and denominators.
            # NOTE(review): frames_up=2 presumably accounts for this extra helper
            # frame between the test method and checkScript — confirm.
            for i in num:
                for j in den:
                    self.checkScript(func, (i, j), frames_up=2)

        # Positive and negative operands cover every sign combination.
        num_int = [1024, -1024]
        den_int = [10, -10]
        num_float = [5.3, -5.3]
        den_float = [2.0, -2.0]
        divmod_test_iterator(func_int, num_int, den_int)
        divmod_test_iterator(func_float, num_float, den_float)
        divmod_test_iterator(func_int_float, num_int, den_float)
        divmod_test_iterator(func_float_int, num_float, den_int)

        # Division by zero must raise. The expected text is a regex, so the
        # trailing "()" in the float patterns is an empty group and only
        # "ZeroDivisionError: float divmod" is actually matched.
        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"):
            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int)))
            cu.func_int(1024, 0)
        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float)))
            cu.func_float(5.3, 0.0)
        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float)))
            cu.func_int_float(1024, 0.0)
        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int)))
            cu.func_float_int(5.3, 0)
 |  | 
 |     def test_math_ops(self): | 
 |         def checkMathWrap(func_name, num_args=1, is_float=True, **args): | 
 |             if is_float: | 
 |                 checkMath(func_name, num_args, True, **args) | 
 |                 checkMath(func_name, num_args, False, **args) | 
 |             else: | 
 |                 checkMath(func_name, num_args, is_float, **args) | 
 |  | 
 |         inf = float("inf") | 
 |         NaN = float("nan") | 
 |         mx_int = 2**31 - 1 | 
 |         mn_int = -2**31 | 
 |         float_vals = ([inf, NaN, 0.0, 1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] + | 
 |                       [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)]) | 
 |         int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2] | 
 |  | 
 |         def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None): | 
 |             funcs_template = dedent(''' | 
 |             def func(a, b): | 
 |                 # type: {args_type} -> {ret_type} | 
 |                 return math.{func}({args}) | 
 |             ''') | 
 |             if num_args == 1: | 
 |                 args = "a" | 
 |             elif num_args == 2: | 
 |                 args = "a, b" | 
 |             else: | 
 |                 raise RuntimeError("Test doesn't support more than 2 arguments") | 
 |             if args_type is None: | 
 |                 args_type = "(float, float)" if is_float else "(int, int)" | 
 |             funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type) | 
 |             scope = {} | 
 |             execWrapper(funcs_str, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(funcs_str) | 
 |             f_script = cu.func | 
 |             f = scope['func'] | 
 |  | 
 |             if vals is None: | 
 |                 vals = float_vals if is_float else int_vals | 
 |                 vals = [(i, j) for i in vals for j in vals] | 
 |  | 
 |             for a, b in vals: | 
 |                 res_python = None | 
 |                 res_script = None | 
 |                 try: | 
 |                     res_python = f(a, b) | 
 |                 except Exception as e: | 
 |                     res_python = e | 
 |                 try: | 
 |                     res_script = f_script(a, b) | 
 |                 except Exception as e: | 
 |                     res_script = e | 
 |                 if debug: | 
 |                     print("in: ", a, b) | 
 |                     print("out: ", res_python, res_script) | 
 |                 # We can't use assertEqual because of a couple of differences: | 
 |                 # 1. nan == nan should return true | 
 |                 # 2. When python functions throw an exception, we usually want to silently ignore them. | 
 |                 # (ie: We want to return `nan` for math.sqrt(-5)) | 
 |                 if res_python != res_script: | 
 |                     if isinstance(res_python, Exception): | 
 |                         continue | 
 |  | 
 |                     if type(res_python) == type(res_script): | 
 |                         if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])): | 
 |                             continue | 
 |                         if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script): | 
 |                             continue | 
 |                     msg = ("Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}" | 
 |                            .format(func_name=func_name, a=a, b=b, res_python=res_python, res_script=res_script)) | 
 |                     self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0) | 
 |  | 
 |         unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf", | 
 |                            "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan", | 
 |                            "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"] | 
 |         binary_float_ops = ["atan2", "fmod", "copysign"] | 
 |         for op in unary_float_ops: | 
 |             checkMathWrap(op, 1) | 
 |         for op in binary_float_ops: | 
 |             checkMathWrap(op, 2) | 
 |  | 
 |         checkMath("modf", 1, ret_type="Tuple[float, float]") | 
 |         checkMath("frexp", 1, ret_type="Tuple[float, int]") | 
 |         checkMath("isnan", 1, ret_type="bool") | 
 |         checkMath("isinf", 1, ret_type="bool") | 
 |         checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)", | 
 |                   vals=[(i, j) for i in float_vals for j in range(-10, 10)]) | 
 |         checkMath("pow", 2, is_float=False, ret_type="float") | 
 |         checkMath("pow", 2, is_float=True, ret_type="float") | 
 |         checkMathWrap("floor", ret_type="int") | 
 |         checkMathWrap("ceil", ret_type="int") | 
 |         checkMathWrap("gcd", 2, is_float=False, ret_type="int") | 
 |         checkMath("isfinite", 1, ret_type="bool") | 
 |         checkMathWrap("remainder", 2) | 
 |         checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)]) | 
 |  | 
    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_if_nest_while(self):
        """Script a ``while`` loop nested inside an ``if`` over int arguments."""
        def func(a, b):
            # type: (int, int) -> int
            c = 0
            if a > b:
                while a > b:
                    b = b + 1
                    c = -b
            return c

        inputs = self._make_scalar_vars([4321, 1234], torch.int64)
        self.checkScript(func, inputs)
 |  | 
    def test_script_optional_none(self):
        """Script functions that assign, return, or default an argument to ``None``."""
        def none_stmt(x):
            output = None
            output = x
            return output

        def none_args(x):
            # type: (Optional[Tensor]) -> Optional[Tensor]
            return None

        self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
        self.checkScript(none_args, [None], optimize=True)

        # test undefined tensor None as default param
        def test_script_optional_tensor_none(x=None):
            # type: (Optional[Tensor]) -> Tensor
            res = torch.zeros(1, dtype=torch.int8)
            if x is None:
                res = res + 1
            else:
                res = x
            return res

        # Compare eager and scripted results with the default and with an explicit tensor.
        fn = test_script_optional_tensor_none
        scripted_fn = torch.jit.script(fn)
        self.assertEqual(fn(), scripted_fn())
        self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))

        # test typical None as default param
        def test_script_optional_other_none(x=None):
            # type: (Optional[float]) -> float
            res = 2.0
            if x is None:
                res = res + 1.0
            else:
                res = x
            return res

        # Same comparison for an Optional[float] parameter.
        fn = test_script_optional_other_none
        scripted_fn = torch.jit.script(fn)
        self.assertEqual(fn(), scripted_fn())
        self.assertEqual(fn(1.0), scripted_fn(1.0))
 |  | 
 |     def test_script_clamp_none(self): | 
 |         def test_script_clamp_max_none(x): | 
 |             return torch.clamp(x, min=2, max=None) | 
 |  | 
 |         def test_script_clamp_max(x): | 
 |             return torch.clamp(x, max=2) | 
 |  | 
 |         def test_script_clamp_min_none(x): | 
 |             return torch.clamp(x, min=None, max=2) | 
 |  | 
 |         def test_script_clamp_min(x): | 
 |             return torch.clamp(x, min=2) | 
 |  | 
 |         input = [torch.arange(0, 3)] | 
 |         self.checkScript(test_script_clamp_max_none, input, optimize=True) | 
 |         self.checkScript(test_script_clamp_max, input, optimize=True) | 
 |         self.checkScript(test_script_clamp_min_none, input, optimize=True) | 
 |         self.checkScript(test_script_clamp_min, input, optimize=True) | 
 |  | 
    def test_script_bool_constant(self):
        """Script a zero-argument function that returns a local bound to ``True``."""
        def test_script_bool_constant():
            a = True
            return a
        self.checkScript(test_script_bool_constant, [])
 |  | 
    def test_ternary(self):
        """Script a conditional expression, with inputs chosen for each side of ``a > 3``."""
        def func(a, b):
            c = 3
            c = a + b if bool(a > 3) else b
            return c

        # NOTE(review): 5 > 3 and 1 <= 3, so the two input sets should take
        # different arms, assuming _make_scalar_vars preserves the values.
        inputs_true = self._make_scalar_vars([5, 2], torch.int64)
        inputs_false = self._make_scalar_vars([1, 0], torch.int64)
        self.checkScript(func, inputs_true, optimize=True)
        self.checkScript(func, inputs_false, optimize=True)
 |  | 
    def test_ternary_module_type_hint(self):
        """Check modules whose ``forward`` returns a ternary with differently-typed arms.

        The arms are unified through an ``Any`` return annotation (M1), an
        ``Any``-annotated local (M2), or an ``Optional[int]`` return annotation
        (M3).  Each module is checked in both train and eval mode because the
        condition is ``self.training``.
        """
        class M1(torch.nn.Module):
            def forward(self) -> Any:
                return 'out' if self.training else {}

        class M2(torch.nn.Module):
            def forward(self) -> Any:
                out: Any = 'out' if self.training else {}
                return out

        class M3(torch.nn.Module):
            def forward(self) -> Optional[int]:
                return None if self.training else 1

        for module in [M1, M2, M3]:
            self.checkModule(module().train(), ())
            self.checkModule(module().eval(), ())
 |  | 
    def test_ternary_static_if(self):
        """Script ternaries conditioned on a ``Final[bool]`` attribute.

        In both modules one arm is ``{}``, which does not match the declared
        ``torch.Tensor`` return type; the flag value selects the tensor arm.
        Scripted and eager ``forward`` results are compared.
        """
        # Test for True branch when condition variable
        # is annotated as Final
        class M1(torch.nn.Module):
            flag: torch.jit.Final[bool]

            def __init__(self):
                super().__init__()
                self.flag = True

            def forward(self) -> torch.Tensor:
                return torch.ones(3) if self.flag else {}

        # Test for False branch when condition variable
        # is annotated as Final
        class M2(torch.nn.Module):
            flag: torch.jit.Final[bool]

            def __init__(self):
                super().__init__()
                self.flag = False

            def forward(self) -> torch.Tensor:
                return {} if self.flag else torch.ones(3)

        model1 = M1()
        model2 = M2()
        script_model_1 = torch.jit.script(model1)
        script_model_2 = torch.jit.script(model2)
        self.assertEqual(model1.forward(), script_model_1.forward())
        self.assertEqual(model2.forward(), script_model_2.forward())
 |  | 
 |     def test_ternary_right_associative(self): | 
 |         def plus_123(x: int): | 
 |             return x + 1 if x == 1 else x + 2 if x == 2 else x + 3 | 
 |         self.checkScript(plus_123, (1,)) | 
 |         self.checkScript(plus_123, (2,)) | 
 |         self.checkScript(plus_123, (3,)) | 
 |  | 
    def test_print(self):
        """Script a function that prints a tensor, ints and int/float lists before returning."""
        def func(x, y):
            q = (x + y).sigmoid()
            print(q, 1, 2, [1, 2], [1.0, 2.0])
            w = -q
            return w * w

        x = torch.arange(4., requires_grad=True)
        y = torch.arange(0., 8, 2, requires_grad=True)
        self.checkScript(func, [x, y], optimize=True, capture_output=True)
 |  | 
    def test_format(self):
        """Script ``str.format`` calls with zero, one and two placeholders."""
        def func(x):
            print("{}, I'm a {}".format("Hello", "test"))
            print("format blank".format())
            print("stuff before {}".format("hi"))
            print("{} stuff after".format("hi"))
            return x + 1

        x = torch.arange(4., requires_grad=True)
        self.checkScript(func, [x], optimize=True, capture_output=True)
 |  | 
    def test_logical_short_circuit(self):
        """Check that ``and``/``or`` skip operands they do not need, and evaluate those they do."""
        @torch.jit.script
        def testNoThrows(t):
            c1 = 1
            if (False and bool(t[1])) or (True or bool(t[1])):
                c1 = 0
            return c1

        # The graph must contain no prim::If. Returning 0 for an empty tensor
        # also shows t[1] was never evaluated: the same input raises below.
        FileCheck().check_not("prim::If").run(testNoThrows.graph)
        self.assertEqual(0, testNoThrows(torch.randn(0)))
        self.assertEqual(0, testNoThrows(torch.randn([2, 3])))

        # Here the constant operand does not decide the result, so t[1] is evaluated.
        @torch.jit.script
        def throwsOr(t):
            c0 = False or bool(t[1])
            print(c0)

        @torch.jit.script
        def throwsAnd(t):
            c0 = True and bool(t[1])
            print(c0)

        t = torch.randn(0)
        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
            throwsOr(t)
        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
            throwsAnd(t)
 |  | 
 |     def test_type_cast(self): | 
 |         template = dedent(''' | 
 |         def func(v): | 
 |             # type: ({from_type}) -> {to_type} | 
 |             return {to_type}(v) | 
 |         ''') | 
 |  | 
 |         def check_cast(from_type, to_type, value, raises=False): | 
 |             code = template.format(from_type=from_type, to_type=to_type) | 
 |             self.checkScript(code, (value,)) | 
 |  | 
 |         check_cast('int', 'float', 1) | 
 |         check_cast('int', 'bool', 1) | 
 |         check_cast('int', 'bool', 0) | 
 |  | 
 |         check_cast('float', 'int', 1.) | 
 |         check_cast('float', 'bool', 1.) | 
 |         check_cast('float', 'bool', 0.) | 
 |  | 
 |         check_cast('bool', 'int', True) | 
 |         check_cast('bool', 'float', True) | 
 |  | 
    def test_multiple_assignment(self):
        """Unpack the two-element result of a plain Python helper inside a scripted function."""
        def outer_func(x):
            return x * 2, x + 2

        @torch.jit.script
        def func(x):
            y, z = outer_func(x)
            return y + z

        x = torch.arange(4)
        self.assertEqual(func(x), x * 2 + x + 2)
 |  | 
    def test_literals(self):
        """Script a call that passes a list literal as the ``size`` keyword of ``Tensor.view``."""
        def func(a):
            return a.view(size=[1, 2, 3])

        a = torch.randn(6)
        self.checkScript(func, [a], optimize=True)
 |  | 
 |     def test_return(self): | 
 |         def no_return(a): | 
 |             a + 1 | 
 |  | 
 |         def void_return(a): | 
 |             return | 
 |  | 
 |         def one_return(a): | 
 |             return a + 1. | 
 |  | 
 |         def multiple_returns(a): | 
 |             return a * 1., a * 2., a * 3. | 
 |  | 
 |         a = torch.randn(1, dtype=torch.float) | 
 |         self.checkScript(no_return, [a], optimize=True) | 
 |         self.checkScript(void_return, [a], optimize=True) | 
 |         self.checkScript(one_return, [a], optimize=True) | 
 |         self.checkScript(multiple_returns, [a], optimize=True) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "does not return along all paths"): | 
 |             torch.jit.CompilationUnit(''' | 
 |             def no_return_bad_annotation(a): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 a + 1 | 
 |             ''') | 
 |  | 
    def test_error(self):
        """Run-time failures inside scripted functions surface as interpreter RuntimeErrors."""
        @torch.jit.script
        def foo(a):
            return a.t()
        # A 3-D input is passed to `.t()`; the call below is expected to raise.
        s = Variable(torch.rand(5, 5, 5))
        # XXX: this should stay quiet in stay propagation and only fail in the interpreter
        # NOTE(review): "stay propagation" is presumably a typo for "shape propagation".
        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
            foo(s)

        @torch.jit.script
        def bar(c, b):
            return c + b

        # Operands of length 10 and 9 are added; the call is expected to raise.
        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
            bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
 |  | 
 |     def test_error_stacktrace(self): | 
 |         @torch.jit.script | 
 |         def baz(c, b): | 
 |             return c + b | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(c, b): | 
 |             return baz(c, b) | 
 |  | 
 |         @torch.jit.script | 
 |         def bar(c, b): | 
 |             return foo(c, b) | 
 |  | 
 |         with self.assertRaises(RuntimeError) as cm: | 
 |             bar(torch.rand(10), torch.rand(9)) | 
 |         FileCheck().check("The following operation failed in the TorchScript interpreter") \ | 
 |                    .check("Traceback") \ | 
 |                    .check("in foo").check("in baz").run(str(cm.exception)) | 
 |  | 
 |     def test_error_stacktrace_interface(self): | 
 |         @torch.jit.script | 
 |         def baz(c, b): | 
 |             return c + b | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(c, b): | 
 |             return baz(c, b) | 
 |  | 
 |         @torch.jit.script | 
 |         def bar(c, b): | 
 |             return foo(c, b) | 
 |  | 
 |         @torch.jit.script | 
 |         class Bar(object): | 
 |             def one(self, x, y): | 
 |                 return bar(x, y) | 
 |  | 
 |         @torch.jit.interface | 
 |         class IFace(object): | 
 |             def one(self, x, y): | 
 |                 # type: (Tensor, Tensor) -> Tensor | 
 |                 pass | 
 |  | 
 |         make_global(IFace) | 
 |  | 
 |         @torch.jit.script | 
 |         def as_interface(x): | 
 |             # type: (IFace) -> IFace | 
 |             return x | 
 |  | 
 |         f = as_interface(Bar()) | 
 |  | 
 |         with self.assertRaises(RuntimeError) as cm: | 
 |             x = f.one(torch.rand(10), torch.rand(9)) | 
 |             bar(torch.rand(10), torch.rand(9)) | 
 |         FileCheck().check("The following operation failed in the TorchScript interpreter") \ | 
 |                    .check("Traceback") \ | 
 |                    .check("in foo").check("in baz").run(str(cm.exception)) | 
 |  | 
 |     def test_operator_precedence(self): | 
 |         def double(x): | 
 |             # type: (int) -> int | 
 |             return 2 * x | 
 |  | 
 |         def complicated_arithmetic_operation(): | 
 |             # TODO we need to test exponent operator '**' and bitwise not | 
 |             # operator '~' once they are properly supported. | 
 |             list = [0, 1, 2, 3] | 
 |             result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4 | 
 |             return result | 
 |  | 
 |         self.checkScript(complicated_arithmetic_operation, ()) | 
 |  | 
 |     def test_in_operator_with_two_strings(self): | 
 |         def fn() -> bool: | 
 |             return "a" in "abcd" | 
 |         self.checkScript(fn, ()) | 
 |  | 
 |     def test_bitwise_ops(self): | 
 |  | 
 |         def int_test(): | 
 |             return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3 | 
 |  | 
 |         self.checkScript(int_test, ()) | 
 |  | 
 |         def bool_test(x, y): | 
 |             # type: (bool, bool) -> Tuple[bool, bool, bool] | 
 |             return x & y, x ^ y, x | y | 
 |  | 
 |         self.checkScript(bool_test, (True, False)) | 
 |         self.checkScript(bool_test, (True, True)) | 
 |  | 
 |         def tensor_test(x, y): | 
 |             return x & y, x ^ y, x | y | 
 |  | 
 |         def tensor_with_int_test(x, y): | 
 |             # type: (Tensor, int) -> Tuple[Tensor, Tensor] | 
 |             return x << y, x >> y | 
 |  | 
 |         x = torch.tensor(2) | 
 |         y = torch.tensor(3) | 
 |  | 
 |         self.checkScript(tensor_test, (x, y)) | 
 |         self.checkScript(tensor_with_int_test, (x, 2)) | 
 |  | 
 |         def not_test(x): | 
 |             return ~x | 
 |  | 
 |         self.checkScript(not_test, (torch.tensor([2, 4]), )) | 
 |  | 
 |     def test_all(self): | 
 |         @torch.jit.script | 
 |         def test_all_tensor(x): | 
 |             return all(x) | 
 |         self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8))) | 
 |         self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8))) | 
 |         self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8))) | 
 |         self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8))) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_all_bool_list(x): | 
 |             # type: (List[bool]) -> bool | 
 |             return all(x) | 
 |         self.assertTrue(test_all_bool_list([True, True])) | 
 |         self.assertTrue(test_all_bool_list([True, 1])) | 
 |         self.assertFalse(test_all_bool_list([True, False])) | 
 |         self.assertFalse(test_all_bool_list([True, 0])) | 
 |         self.assertFalse(test_all_bool_list([False, 0])) | 
 |         self.assertTrue(test_all_bool_list([])) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_all_int_list(x): | 
 |             # type: (List[int]) -> bool | 
 |             return all(x) | 
 |         self.assertTrue(test_all_int_list([3, 6])) | 
 |         self.assertFalse(test_all_int_list([2, 0])) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_all_float_list(x): | 
 |             # type: (List[float]) -> bool | 
 |             return all(x) | 
 |         self.assertTrue(test_all_float_list([3.14, 8.1])) | 
 |         self.assertFalse(test_all_float_list([3.14, 0, 8.9])) | 
 |  | 
 |  | 
 |     def test_number_math(self): | 
 |         ops_template = dedent(''' | 
 |         def func(): | 
 |             return {scalar1} {op} {scalar2} | 
 |         ''') | 
 |         ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//'] | 
 |         funcs_template = dedent(''' | 
 |         def func(): | 
 |             return {func}({scalar1}, {scalar2}) | 
 |         ''') | 
 |         funcs = ['min', 'max'] | 
 |         scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0'] | 
 |         scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars] | 
 |  | 
 |         def run_test(code): | 
 |             scope = {} | 
 |             execWrapper(code, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |  | 
 |             self.assertEqual(cu.func(), scope['func']()) | 
 |  | 
 |         for scalar1, scalar2 in scalar_pairs: | 
 |             for op in ops: | 
 |                 code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2) | 
 |                 run_test(code) | 
 |             for func in funcs: | 
 |                 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2) | 
 |                 run_test(code) | 
 |  | 
 |         # test Scalar overloads | 
 |         for scalar1, scalar2 in scalar_pairs: | 
 |             item1 = 'torch.tensor(' + scalar1 + ').item()' | 
 |             item2 = 'torch.tensor(' + scalar2 + ').item()' | 
 |             for op in ops: | 
 |                 code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2) | 
 |                 run_test(code) | 
 |                 code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2) | 
 |                 run_test(code) | 
 |                 code = ops_template.format(op=op, scalar1=item1, scalar2=item2) | 
 |                 run_test(code) | 
 |             for func in funcs: | 
 |                 code = funcs_template.format(func=func, scalar1=item1, scalar2=scalar2) | 
 |                 run_test(code) | 
 |                 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2) | 
 |                 run_test(code) | 
 |                 code = funcs_template.format(func=func, scalar1=item1, scalar2=item2) | 
 |                 run_test(code) | 
 |  | 
 |     def test_number_abs(self): | 
 |         def func1(x): | 
 |             # type: (float) -> float | 
 |             return abs(x) | 
 |  | 
 |         def func2(x): | 
 |             # type: (int) -> int | 
 |             return abs(x) | 
 |  | 
 |         def func3(x): | 
 |             return abs(x) | 
 |  | 
 |         self.checkScript(func1, (-3.14,)) | 
 |         self.checkScript(func1, (3.14,)) | 
 |         self.checkScript(func2, (-10,)) | 
 |         self.checkScript(func2, (10,)) | 
 |         self.checkScript(func3, (torch.tensor([-5, -10, -20]),)) | 
 |         self.checkScript(func3, (torch.tensor([5, 10, 20]),)) | 
 |         self.checkScript(func3, (torch.tensor([-5, 10, -20]),)) | 
 |  | 
 |     def test_number_div(self): | 
 |         self.assertEqual(div_int_future(), torch.jit.script(div_int_future)()) | 
 |         self.checkScript(div_float_future, ()) | 
 |  | 
 |         self.checkScript(div_int_nofuture, ()) | 
 |         self.checkScript(div_float_nofuture, ()) | 
 |  | 
 |     # Testing bitwise shorthand aug assignment | 
 |     def test_bool_augassign_bitwise_or(self): | 
 |         def func(a: bool, b: bool) -> bool: | 
 |             a |= b | 
 |             return a | 
 |  | 
 |         self.checkScript(func, (True, False), optimize=True) | 
 |         self.checkScript(func, (True, True), optimize=True) | 
 |         self.checkScript(func, (False, False), optimize=True) | 
 |         self.checkScript(func, (False, True), optimize=True) | 
 |  | 
 |     def test_bool_augassign_bitwise_and(self): | 
 |         def func(a: bool, b: bool) -> bool: | 
 |             a &= b | 
 |             return a | 
 |  | 
 |         self.checkScript(func, (True, False), optimize=True) | 
 |         self.checkScript(func, (True, True), optimize=True) | 
 |         self.checkScript(func, (False, False), optimize=True) | 
 |         self.checkScript(func, (False, True), optimize=True) | 
 |  | 
 |     def test_bool_augassign_bitwise_xor(self): | 
 |         def func(a: bool, b: bool) -> bool: | 
 |             a ^= b | 
 |             return a | 
 |  | 
 |         self.checkScript(func, (True, False), optimize=True) | 
 |         self.checkScript(func, (True, True), optimize=True) | 
 |         self.checkScript(func, (False, False), optimize=True) | 
 |         self.checkScript(func, (False, True), optimize=True) | 
 |  | 
 |     def test_number_augassign_bitwise_lshift(self): | 
 |         def func() -> int: | 
 |             z = 8 | 
 |             z <<= 2 | 
 |             return z | 
 |  | 
 |         self.checkScript(func, (), optimize=True) | 
 |  | 
 |     def test_number_augassign_bitwise_rshift(self): | 
 |         def func() -> int: | 
 |             z = 8 | 
 |             z >>= 2 | 
 |             return z | 
 |  | 
 |         self.checkScript(func, (), optimize=True) | 
 |  | 
 |     def test_number_augassign_bitwise_pow(self): | 
 |         def func() -> float: | 
 |             z = 8 | 
 |             z **= 2 | 
 |             return z | 
 |  | 
 |         self.checkScript(func, (), optimize=True) | 
 |  | 
 |     def test_number_augassign(self): | 
 |         def func(): | 
 |             z = 1 | 
 |             z += 2 | 
 |             return z | 
 |  | 
 |         self.checkScript(func, (), optimize=True) | 
 |  | 
 |     def test_nested_select_assign(self): | 
 |         class SubSubModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(SubSubModule, self).__init__() | 
 |                 self.abc = 11 | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.abc | 
 |  | 
 |         class SubModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(SubModule, self).__init__() | 
 |                 self.a = 11 | 
 |                 self.nested = SubSubModule() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.a | 
 |  | 
 |         class TestModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TestModule, self).__init__() | 
 |                 self.sub = SubModule() | 
 |                 self.hi = 1 | 
 |  | 
 |             def forward(self): | 
 |                 self.hi = 5 | 
 |                 self.sub.a = 1 | 
 |                 self.sub.nested.abc = 5 | 
 |                 return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi | 
 |  | 
 |         self.checkModule(TestModule(), ()) | 
 |  | 
 |     def test_number_neg(self): | 
 |         # int -> int | 
 |         def func1(): | 
 |             return -8 | 
 |  | 
 |         # float -> float | 
 |         def func2(): | 
 |             return -3.14 | 
 |  | 
 |         self.checkScript(func1, (), optimize=True) | 
 |         self.checkScript(func2, (), optimize=True) | 
 |  | 
 |     def test_compare_two_bool_inputs(self): | 
 |         def compare_eq(a: bool, b: bool): | 
 |             return a == b | 
 |  | 
 |         def compare_ne(a: bool, b: bool): | 
 |             return a != b | 
 |  | 
 |         scripted_fn_eq = torch.jit.script(compare_eq) | 
 |         scripted_fn_ne = torch.jit.script(compare_ne) | 
 |         self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False)) | 
 |         self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True)) | 
 |         self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True)) | 
 |         self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False)) | 
 |  | 
 |         self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False)) | 
 |         self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True)) | 
 |         self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True)) | 
 |         self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False)) | 
 |  | 
 |  | 
 |     def _test_tensor_number_math(self, device='cpu'): | 
 |         template = dedent(''' | 
 |         def func(t): | 
 |             return {lhs} {op} {rhs} | 
 |         ''') | 
 |  | 
 |         def test(op, tensor, const, swap_args, template=template): | 
 |             args = ('t', const) | 
 |             if swap_args: | 
 |                 args = (const, 't') | 
 |  | 
 |             code = template.format(lhs=args[0], rhs=args[1], op=op) | 
 |             scope = {} | 
 |             execWrapper(code, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |             message = 'with code `{} {} {}` and t={}'.format(args[0], op, args[1], tensor) | 
 |             res1 = cu.func(tensor) | 
 |             res2 = scope['func'](tensor) | 
 |             self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) | 
 |             self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) | 
 |  | 
 |         var_int = [2, -2] | 
 |         var_float = [1.4321, -1.2] | 
 |  | 
 |         ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/'] | 
 |  | 
 |         float_tensor = torch.randn(5, 5, device=device) | 
 |         double_tensor = torch.randn(5, 5, dtype=torch.double, device=device) | 
 |         long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device) | 
 |         long_tensor[long_tensor == 0] = 2 | 
 |  | 
 |         tensors = [float_tensor, double_tensor, long_tensor] | 
 |         consts = var_int + var_float | 
 |  | 
 |         for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]): | 
 |             # FIXME: things like 2 / long_tensor are not implemented correctly | 
 |             # Look in torch/_tensor.py to see how pytorch implements it. | 
 |             if op == '/' and tensor.data_ptr() == long_tensor.data_ptr(): | 
 |                 continue | 
 |  | 
 |             # % operator does not take: const % tensor | 
 |             if op == '%' and swap_args is True: | 
 |                 continue | 
 |  | 
 |             test(op, tensor, const, swap_args) | 
 |  | 
 |     def test_tensor_number_math(self): | 
 |         self._test_tensor_number_math() | 
 |  | 
    def test_torch_tensor_bad_input(self):
        """``torch.tensor`` in scripted code rejects malformed list arguments.

        Three inputs are covered: a list containing ``None``, an unannotated
        empty list, and a ragged nested list. Each must produce a RuntimeError
        whose message matches the given pattern.
        """
        # Scripting and the call are both inside the context, so the error may
        # come from either step.
        with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, "
                                    "or bools, got None"):
            @torch.jit.script
            def test():
                return torch.tensor([None])
            test()

        # Same arrangement for an empty list literal without an annotation.
        with self.assertRaisesRegex(RuntimeError, r"Empty lists default to List\[Tensor\]"):
            @torch.jit.script
            def tmp():
                return torch.tensor([])
            tmp()

        # Here scripting happens outside the context, so it has to succeed;
        # only the call on the ragged list is expected to raise.
        @torch.jit.script
        def foo():
            return torch.tensor([[2, 2], [1]])
        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
            foo()
 |  | 
 |     @suppress_warnings | 
 |     def test_torch_tensor_as_tensor_empty_list(self): | 
 |         tensor_template = dedent(''' | 
 |         def func(): | 
 |             empty_list = torch.jit.annotate(List[int], []) | 
 |             ten1 = torch.{tensor_op}({input}) | 
 |             return ten1 | 
 |         ''') | 
 |         ops = ['tensor', 'as_tensor'] | 
 |         inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]'] | 
 |  | 
 |         for op in ops: | 
 |             for inp in inputs: | 
 |                 code = tensor_template.format(tensor_op=op, input=inp) | 
 |                 scope = {} | 
 |                 exec(code, globals(), scope) | 
 |                 cu = torch.jit.CompilationUnit(code) | 
 |                 t1 = cu.func() | 
 |                 t2 = scope['func']() | 
 |                 if inp == 'empty_list': | 
 |                     # torchscript returns int tensor, python returns float tensor | 
 |                     self.assertNotEqual(t1.dtype, t2.dtype) | 
 |                 self.assertEqual(t1, t2, exact_dtype=False) | 
 |                 self.assertEqual(t1.device, t2.device) | 
 |  | 
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate")
    def test_tensor_as_tensor_shape_prop(self):
        """Check the types shape analysis assigns to ``tensor`` / ``as_tensor`` results.

        Generated snippets are compiled, run through complete shape analysis,
        and their graphs are matched with FileCheck against the expected type
        strings. Two further scripted functions check the types reported by
        ``graph_for`` when the dtype is a runtime argument and when the input
        is already a tensor.
        """
        tensor_template = dedent('''
        def func():
            return torch.{tensor_op}({input})
        ''')
        ops = ['tensor', 'as_tensor']
        inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])']
        # NOTE(review): `inputs` has 8 entries but `expected_shape` only 7, so
        # the zip() below silently drops the annotated empty-list input and it
        # is never checked. Confirm whether an expectation is missing.
        expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)",
                          "Double(*, device=cpu)", "Double(device=cpu)",
                          "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"]

        for op in ops:
            for inp, expect in zip(inputs, expected_shape):
                code = tensor_template.format(tensor_op=op, input=inp)
                # The exec'd Python function is never called; `scope` is not
                # read again in this loop.
                scope = {}
                exec(code, globals(), scope)
                cu = torch.jit.CompilationUnit(code)
                torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
                # The expected type string must appear before the op name.
                FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(cu.func.graph)

        @torch.jit.script
        def test_dtype(inp_dtype: torch.dtype):
            a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
            return a, torch.tensor(1.0, dtype=inp_dtype)

        # NOTE(review): the skipIf above only lets this test run when
        # GRAPH_EXECUTOR is LEGACY, so the PROFILING branches below look
        # unreachable -- confirm GRAPH_EXECUTOR cannot change during a run.
        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            g = test_dtype.graph_for(5, profile_and_replay=True)
            # both should have completed shapes
            FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \
                       .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g)
        else:
            g = test_dtype.graph_for(5)
            # first should have type set second should not
            FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \
                       .check("Tensor(requires_grad=0) = aten::tensor").run(g)

        @torch.jit.script
        def test_as_tensor_tensor_input(input):
            a = torch.as_tensor(input, dtype=input.dtype)
            return a, torch.as_tensor(input, dtype=torch.float)

        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True)
            FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \
                       .check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g)
        else:
            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
            FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
 |  | 
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
    def test_tensor_requires_grad(self):
        """``requires_grad`` passed to ``torch.tensor`` shows up on the graph values.

        The scripted function builds three tensors: one whose flag comes from
        the bool argument, one with ``True`` and one with ``False``. The graph
        returned by ``graph_for(True)`` is then inspected.
        """
        @torch.jit.script
        def test(b):
            # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
            a = torch.tensor(1., requires_grad=b)
            b = torch.tensor(1., requires_grad=True)
            c = torch.tensor(1., requires_grad=False)
            return a, b, c

        g = test.graph_for(True)
        # Take the graph's first output and list the inputs of the node that
        # produces it -- presumably the tuple construction for (a, b, c);
        # TODO confirm.
        out = next(g.outputs())
        out_inp = list(out.node().inputs())

        self.assertTrue(out_inp[0].requires_grad())
        self.assertTrue(out_inp[1].requires_grad())
        self.assertFalse(out_inp[2].requires_grad())
 |  | 
 |     def test_grad_from_script(self): | 
 |         def test(): | 
 |             a = torch.tensor(2.5, requires_grad=True) | 
 |             b = a * 2 | 
 |             return a, b | 
 |  | 
 |         a, b = test() | 
 |         b.backward() | 
 |  | 
 |         a_script, b_script = torch.jit.script(test)() | 
 |         b_script.backward() | 
 |         self.assertEqual(a.grad, a_script.grad) | 
 |  | 
 |     def test_torch_tensor_as_tensor(self): | 
 |         tensor_template = dedent(''' | 
 |         def func(): | 
 |             li = {list_create} | 
 |             ten1 = torch.{tensor_op}(li {options}) | 
 |             return ten1 | 
 |         ''') | 
 |  | 
 |         lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)", | 
 |                  "torch.jit.annotate(List[List[int]], [])", | 
 |                  "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"] | 
 |  | 
 |         dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half", | 
 |                   ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short", | 
 |                   ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat", | 
 |                   ", dtype=torch.cdouble"] | 
 |  | 
 |         ops = ['tensor', 'as_tensor'] | 
 |         devices = ['', ", device='cpu'"] | 
 |         if RUN_CUDA: | 
 |             devices.append(", device='cuda'") | 
 |  | 
 |         option_pairs = [dtype + device for dtype in dtypes for device in devices] | 
 |         for op in ops: | 
 |             for li in lists: | 
 |                 for option in option_pairs: | 
 |                     # tensor from empty list is type float in python and annotated type in torchscript | 
 |                     if "annotate" in li and "dtype" not in option: | 
 |                         continue | 
 |                     # Skip unsigned tensor initializaton for signed values on 3.10 | 
 |                     if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li: | 
 |                         continue | 
 |                     code = tensor_template.format(list_create=li, tensor_op=op, options=option) | 
 |                     scope = {} | 
 |                     exec(code, globals(), scope) | 
 |                     cu = torch.jit.CompilationUnit(code) | 
 |                     t1 = cu.func() | 
 |                     t2 = scope['func']() | 
 |                     if t1.dtype == torch.float16:  # equality NYI for half tensor | 
 |                         self.assertTrue(str(t1) == str(t2)) | 
 |                     else: | 
 |                         self.assertEqual(t1, t2) | 
 |                     self.assertEqual(t1.dtype, t2.dtype) | 
 |                     self.assertEqual(t1.device, t2.device) | 
 |  | 
 |         def test_as_tensor_tensor_input(input): | 
 |             # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor] | 
 |             return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \ | 
 |                 torch.as_tensor(input, dtype=torch.int32) | 
 |  | 
 |         inp = torch.randn(3, 4, dtype=torch.cfloat) | 
 |         self.checkScript(test_as_tensor_tensor_input, (inp,)) | 
 |  | 
 |     def test_torch_tensor_dtype(self): | 
 |         def foo(s: float): | 
 |             return torch.tensor(s), torch.tensor([s, s]) | 
 |  | 
 |         # need to clear function cache so we re run shape analysis | 
 |         with set_default_dtype(torch.double): | 
 |             self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) | 
 |             if GRAPH_EXECUTOR == ProfilingMode.LEGACY: | 
 |                 FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) | 
 |         with set_default_dtype(torch.float): | 
 |             del torch.jit._state._jit_caching_layer[foo] | 
 |             self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) | 
 |             if GRAPH_EXECUTOR == ProfilingMode.LEGACY: | 
 |                 FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) | 
 |         with set_default_dtype(torch.half): | 
 |             del torch.jit._state._jit_caching_layer[foo] | 
 |             self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) | 
 |             if GRAPH_EXECUTOR == ProfilingMode.LEGACY: | 
 |                 FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) | 
 |  | 
 |     def test_shape_analysis_grad_property(self): | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             return torch.sub(x, torch.tanh(x)) | 
 |  | 
 |         torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False) | 
 |  | 
 |         # requires_grad property shouldn't be accidentally set by shape analysis | 
 |         self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None) | 
 |  | 
 |     def test_empty_like_memory_format_bc(self): | 
 |         def f(x): | 
 |             # type: (Tensor) -> Tensor | 
 |             return torch.zeros_like(x, memory_format=None) | 
 |  | 
 |         scripted_f = torch.jit.script(f) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(scripted_f(x), f(x)) | 
 |  | 
    def test_multiline_string_dedents(self):
        """Script a function whose triple-quoted strings break the indentation.

        The continuation lines of the string literals in ``foo`` start at
        column 0, at a shallow indent, or deeper than the surrounding code.
        That layout is the subject of the test, so it must not be reformatted.
        ``foo`` returns ``None`` in both eager and scripted form.
        """
        def foo() -> None:
            multiline_string_dedent_1 = """
This is a string dedent """
            multiline_string_dedent_2 = """ This is a
  string dedent """
            multiline_string_dedent_3 = """
            This is a string
dedent """
            multiline_string_dedent_4 = """ This is a string dedent """

        scripted_foo = torch.jit.script(foo)
        self.assertEqual(scripted_foo(), foo())
 |  | 
    def test_class_with_comment_at_lower_indentation(self):
        """Script a module whose ``forward`` contains a comment dedented past its body.

        The comment inside ``Foo.forward`` deliberately sits at a shallower
        indent than the statements around it; do not re-indent it. The test
        passes if ``torch.jit.script`` completes without raising.
        """
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = torch.neg(x)
        # This comment is at the wrong indent
                return x

        torch.jit.script(Foo())
 |  | 
 |     # adapted from test in test_torch | 
    def test_tensor_to(self):
        """Check ``Tensor.to`` overloads when called from scripted code.

        Covers aliasing versus copying, the device and dtype of the result,
        CUDA round trips when CUDA is available, and gradients through the
        dtype, device and tensor overloads.
        """
        # `cuda`, `device` and `non_blocking` become locals of the generated
        # function so the `to_str` expressions can refer to them by name.
        template = dedent('''
        def func(t):
            cuda = "{cuda}"
            device = "{device}"
            non_blocking = {non_blocking}
            return {to_str}
        ''')

        def s(t, to_str, non_blocking=None, device=None, cuda=None):
            # Fill the template with `to_str`, compile it and run it on `t`.
            # Defaults: device -> str(t.device), non_blocking -> False,
            # cuda -> "cuda".
            device = device if device is not None else str(t.device)
            non_blocking = non_blocking if non_blocking is not None else False
            cuda = "cuda" if cuda is None else cuda
            code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
            # NOTE(review): `scope` is assigned but never used in this helper.
            scope = {}
            cu = torch.jit.CompilationUnit(code)
            return cu.func(t, profile_and_replay=True)

        def test_copy_behavior(t, non_blocking=False):
            # Without copy=True a no-op conversion must return `t` itself;
            # with copy=True it must return a different object.
            self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
            self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
            self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
            self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
            self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
            self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))

            # Besides t.device itself, add another spelling of the same CUDA
            # device when one of the two index conditions below holds.
            devices = [t.device]
            if t.device.type == 'cuda':
                if t.device.index == -1:
                    devices.append('cuda:{}'.format(torch.cuda.current_device()))
                elif t.device.index == torch.cuda.current_device():
                    devices.append('cuda')
            for device in devices:
                self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
                self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
                self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
                self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
                                      non_blocking, device))

        t = torch.tensor(5)
        test_copy_behavior(t)

        # Device, dtype and storage pointer of the results.
        self.assertEqual(t.device, s(t, "t.to('cpu')").device)
        self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
        self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
        self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
        self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
        self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
        self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())

        a = torch.tensor(5)
        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                # Second device string is 'cuda:0' with exactly one device,
                # otherwise 'cuda:1'.
                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
                    b = torch.tensor(5., device=cuda)
                    test_copy_behavior(b, non_blocking)
                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
                    self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
                    self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
                    self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
                    self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
                    self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)

        # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
        t = torch.tensor(5).float().requires_grad_()
        out_ref = t.to(torch.float32)
        out = s(t, "t.to(torch.float32)")
        self.assertEqual(out_ref, out)

        grad_ref = torch.autograd.grad(out_ref.sum(), t)
        grad = torch.autograd.grad(out.sum(), t)
        self.assertEqual(grad_ref, grad)

        # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
        out_ref = t.to('cpu')
        out = s(t, "t.to('cpu')")
        self.assertEqual(out_ref, out)

        grad_ref = torch.autograd.grad(out_ref.sum(), t)
        grad = torch.autograd.grad(out.sum(), t)
        self.assertEqual(grad_ref, grad)

        # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
        @torch.jit.script
        def func2(t, t_ref):
            return t.to(t_ref)

        with disable_autodiff_subgraph_inlining():
            t_ref = torch.tensor(4).double()
            out_ref = t.to(t_ref)
            out = func2(t, t_ref)
            grad_ref = torch.autograd.grad(out_ref.sum(), t)
            grad = torch.autograd.grad(out.sum(), t)
            self.assertEqual(grad_ref, grad)
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "No CUDA") | 
 |     def test_tensor_number_math_cuda(self): | 
 |         # CUDA variant: delegates to the shared `_test_tensor_number_math` helper | 
 |         # with device='cuda'; the decorator skips it when RUN_CUDA is false. | 
 |         self._test_tensor_number_math(device='cuda') | 
 |  | 
 |     def test_not(self): | 
 |         # test not operator in python | 
 |         # TODO: add more tests when bool conversions ready | 
 |         # The tensor comparison is wrapped in an explicit bool() before `not` | 
 |         # is applied to it. | 
 |         def test_not_op(a): | 
 |             return not bool(a > 1) | 
 |  | 
 |         # Hand the function to the checkScript helper with one tensor input | 
 |         # (value 2, so `a > 1` holds) and optimize=True. | 
 |         self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True) | 
 |  | 
 |     def test_is_isnot(self): | 
 |         # test is and is not operator in python | 
 |         template = dedent(''' | 
 |         def func(): | 
 |             # type: () -> bool | 
 |             return {lhs} {op} {rhs} | 
 |         ''') | 
 |  | 
 |         def test(op, args): | 
 |             code = template.format(lhs=args[0], rhs=args[1], op=op) | 
 |             scope = {} | 
 |             execWrapper(code, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |             self.assertEqual( | 
 |                 cu.func(), | 
 |                 scope['func'](), | 
 |                 msg="Failed with op: {}, lhs: {}, rhs: {}" | 
 |                 .format(op, args[0], args[1]) | 
 |             ) | 
 |  | 
 |         ops = ['is', 'is not'] | 
 |         type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5] | 
 |  | 
 |         # do literals product to try any types combinations | 
 |         for op, lhs, rhs in product(ops, type_literals, type_literals): | 
 |             test(op, [lhs, rhs]) | 
 |  | 
 |     def test_isinstance_refinement(self): | 
 |         # Each scripted function below uses a value as an int/Tensor only inside | 
 |         # the branch guarded by an isinstance check. | 
 |  | 
 |         # Optional[int] argument, single isinstance check | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             # type: (Optional[int]) -> int | 
 |             if isinstance(a, int): | 
 |                 return a + 3 | 
 |             else: | 
 |                 return 4 | 
 |         self.assertEqual(foo(4), 7) | 
 |         self.assertEqual(foo(None), 4) | 
 |  | 
 |         # two Optional[int] arguments, negated checks joined with `or`; the | 
 |         # arithmetic lives in the else branch | 
 |         @torch.jit.script | 
 |         def foo2(a, b): | 
 |             # type: (Optional[int], Optional[int]) -> int | 
 |             if not isinstance(a, int) or not isinstance(b, int): | 
 |                 return 0 | 
 |             else: | 
 |                 return a + b | 
 |         self.assertEqual(foo2(3, 4), 7) | 
 |         self.assertEqual(foo2(None, 4), 0) | 
 |         self.assertEqual(foo2(4, None), 0) | 
 |  | 
 |         # arguments annotated as Any, checks joined with `and` | 
 |         @torch.jit.script | 
 |         def any_refinement(a, b): | 
 |             # type: (Any, Any) -> int | 
 |             if isinstance(a, int) and isinstance(b, int): | 
 |                 return a + b | 
 |             return 0 | 
 |  | 
 |         self.assertEqual(any_refinement(3, 4), 7) | 
 |         self.assertEqual(any_refinement(3, "hi"), 0) | 
 |  | 
 |         # Any argument checked against Tensor; non-tensor input falls through to | 
 |         # the constant tensor(3) | 
 |         @torch.jit.script | 
 |         def any_refinement2(a): | 
 |             # type: (Any) -> Tensor | 
 |             if isinstance(a, Tensor): | 
 |                 return a | 
 |             return torch.tensor(3) | 
 |  | 
 |         self.assertEqual(any_refinement2(3), torch.tensor(3)) | 
 |         self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5)) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor") | 
 |     def test_unspecialized_any_binding(self): | 
 |         # Binding a value to an `Any` parameter infers the value's type. If the | 
 |         # inferred type were a specialized tensor type, the isinstance check of | 
 |         # `x` against the Dict type below would fail. | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(x: Any): | 
 |             assert isinstance(x, Dict[str, torch.Tensor]) | 
 |  | 
 |         # a str -> Tensor dict must pass the assert ... | 
 |         foo({"1": torch.tensor(3)}) | 
 |         # ... while a plain int must make the call raise | 
 |         with self.assertRaises(Exception): | 
 |             foo(2) | 
 |  | 
 |     def test_isinstance(self): | 
 |         # test isinstance operator for static type checking | 
 |         template = dedent(''' | 
 |         def func(x): | 
 |             # type: ({type_hint}) -> bool | 
 |             return isinstance(x, {typ}) | 
 |         ''') | 
 |  | 
 |         def test(inp, typ, type_hint): | 
 |             code = template.format(typ=typ, type_hint=type_hint) | 
 |             scope = {} | 
 |             execWrapper(code, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |             self.assertEqual( | 
 |                 cu.func(inp), | 
 |                 scope['func'](inp), | 
 |                 msg="Failed with typ: {}" | 
 |                 .format(typ) | 
 |             ) | 
 |  | 
 |         inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1] | 
 |         type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple', | 
 |                          '(list, tuple)', '(int, float, bool)'] | 
 |         type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]', | 
 |                             'List[int]', 'int'] | 
 |  | 
 |         # do zipping to try different types | 
 |         for inp, typ, type_hint in zip(inputs, type_literals, type_annotations): | 
 |             test(inp, typ, type_hint) | 
 |  | 
 |         # test optional isinstance check | 
 |         @torch.jit.script | 
 |         def opt_func(x): | 
 |             # type: (Optional[int]) -> bool | 
 |             return isinstance(x, int) | 
 |         self.assertTrue(opt_func(3)) | 
 |         self.assertFalse(opt_func(None)) | 
 |  | 
 |     def test_dropout_eval(self): | 
 |         # Compares a scripted module with an eager twin of identical structure | 
 |         # (conv -> batchnorm -> in-place relu, followed by dropout driven by | 
 |         # self.training), first in eval() mode and then in train() mode. | 
 |  | 
 |         # scripted conv + batchnorm + relu building block | 
 |         class ScriptedConv2d(torch.jit.ScriptModule): | 
 |             def __init__(self, in_channels, out_channels, **kwargs): | 
 |                 super(ScriptedConv2d, self).__init__() | 
 |                 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) | 
 |                 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 x = self.conv(x) | 
 |                 x = self.bn(x) | 
 |                 return F.relu(x, inplace=True) | 
 |  | 
 |         # scripted top-level module: building block, then dropout | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 x = self.Conv2d_1a_3x3(x) | 
 |                 return F.dropout(x, training=self.training) | 
 |  | 
 |         # eager counterpart of ScriptedConv2d | 
 |         class EagerConv2d(torch.nn.Module): | 
 |             def __init__(self, in_channels, out_channels, **kwargs): | 
 |                 super(EagerConv2d, self).__init__() | 
 |                 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) | 
 |                 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) | 
 |  | 
 |             def forward(self, x): | 
 |                 x = self.conv(x) | 
 |                 x = self.bn(x) | 
 |                 return F.relu(x, inplace=True) | 
 |  | 
 |         # eager counterpart of ScriptMod | 
 |         class EagerMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(EagerMod, self).__init__() | 
 |                 self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2) | 
 |  | 
 |             def forward(self, x): | 
 |                 x = self.Conv2d_1a_3x3(x) | 
 |                 return F.dropout(x, training=self.training) | 
 |  | 
 |         # both modules get the same data; the eager one works on a clone | 
 |         script_input = torch.rand(4, 3, 299, 299) | 
 |         eager_input = script_input.clone() | 
 |  | 
 |         # NOTE(review): freeze_rng_state is defined outside this chunk; presumably | 
 |         # it lets the script and eager blocks start from the same RNG state so | 
 |         # that module init and dropout draw the same numbers - confirm. | 
 |  | 
 |         # eval() mode | 
 |         with freeze_rng_state(): | 
 |             script_mod = ScriptMod() | 
 |             script_mod.eval() | 
 |             script_output = script_mod(script_input) | 
 |  | 
 |         with freeze_rng_state(): | 
 |             eager_mod = EagerMod() | 
 |             eager_mod.eval() | 
 |             eager_output = eager_mod(eager_input) | 
 |  | 
 |         self.assertEqual(script_output, eager_output) | 
 |  | 
 |         # train() mode, with freshly constructed modules | 
 |         with freeze_rng_state(): | 
 |             script_mod = ScriptMod() | 
 |             script_mod.train() | 
 |             script_output = script_mod(script_input) | 
 |  | 
 |         with freeze_rng_state(): | 
 |             eager_mod = EagerMod() | 
 |             eager_mod.train() | 
 |             eager_output = eager_mod(eager_input) | 
 |  | 
 |         self.assertEqual(script_output, eager_output) | 
 |  | 
 |     def test_nested_breaks(self): | 
 |         # break/continue inside nested loops: scripted results must match eager | 
 |         # Python, and no prim::Loop node may end up with a bool-typed output. | 
 |         def no_bool_loop_outputs(g): | 
 |             # testing that the "did exit" transform values are not loop block | 
 |             # outputs (and thus not affecting one loop from another) | 
 |             loops = g.findAllNodes("prim::Loop") | 
 |             for loop in loops: | 
 |                 for out in loop.outputs(): | 
 |                     self.assertTrue(out.type() != BoolType.get()) | 
 |  | 
 |         # while loop containing a conditional `continue` and an inner for loop | 
 |         # whose body continues unconditionally (its `ret += 1` never runs) | 
 |         def test(y): | 
 |             # type: (int) | 
 |             ret = 0 | 
 |             tensor = torch.tensor(0) | 
 |             while int(tensor.add_(1)) < 4: | 
 |                 if y == 1: | 
 |                     continue | 
 |                 for i in range(y): | 
 |                     continue | 
 |                     ret += 1 | 
 |                 ret += 1 | 
 |             return ret, int(tensor) | 
 |  | 
 |         self.assertEqual(torch.jit.script(test)(1), test(1)) | 
 |         self.assertEqual(torch.jit.script(test)(2), test(2)) | 
 |         no_bool_loop_outputs(torch.jit.script(test).graph) | 
 |  | 
 |         # outer while with its own break/continue, inner for with both a | 
 |         # continue and a break nested under if/else | 
 |         def foo(): | 
 |             y = torch.tensor(0) | 
 |             z = 0 | 
 |             while int(y.add_(1)) < 20: | 
 |                 if int(y) < 10: | 
 |                     for i in range(6): | 
 |                         if i == 3: | 
 |                             continue | 
 |                         else: | 
 |                             if i > 3: | 
 |                                 break | 
 |                         z += 2 | 
 |                 if int(y) == 18: | 
 |                     break | 
 |                 if int(y) == 15: | 
 |                     continue | 
 |                 z += 1 | 
 |             return int(y), z | 
 |  | 
 |         no_bool_loop_outputs(torch.jit.script(foo).graph) | 
 |         self.checkScript(foo, ()) | 
 |  | 
 |         # `continue` as the last statement of the inner loop, `break` as the | 
 |         # last statement of the outer loop | 
 |         def test_nested_two(): | 
 |             i = 0 | 
 |             k = 0 | 
 |             while i < 5: | 
 |                 for j in range(5): | 
 |                     k += 1 | 
 |                     if j == 3: | 
 |                         continue | 
 |                 i += 1 | 
 |                 k += 1 | 
 |                 if i == 4: | 
 |                     break | 
 |             return i, k | 
 |  | 
 |         self.checkScript(test_nested_two, ()) | 
 |         no_bool_loop_outputs(torch.jit.script(test_nested_two).graph) | 
 |  | 
 |     def test_breaks_continues(self): | 
 |         # A collection of small functions with `break` / `continue` in various | 
 |         # positions; each one is handed to the checkScript helper with a few | 
 |         # concrete inputs. Statements that follow an unconditional break or | 
 |         # continue are deliberately unreachable. | 
 |  | 
 |         # single conditional continue in a for loop | 
 |         def foo_continue(cond): | 
 |             # type: (int) | 
 |             j = 1 | 
 |             for i in range(5): | 
 |                 if i == cond: | 
 |                     continue | 
 |                 j += 1 | 
 |             return j | 
 |  | 
 |         # single conditional break in a for loop | 
 |         def foo_break(cond): | 
 |             # type: (int) | 
 |             j = 1 | 
 |             for i in range(5): | 
 |                 if i == cond: | 
 |                     break | 
 |                 j += 1 | 
 |             return j | 
 |  | 
 |         for i in range(1, 4): | 
 |             self.checkScript(foo_continue, (i,)) | 
 |             self.checkScript(foo_break, (i,)) | 
 |  | 
 |         # `x` is optional in the loop condition and is reassigned both before | 
 |         # a `continue` and on the fall-through paths | 
 |         def test_refine_outside_loop(): | 
 |             if 1 == 1: | 
 |                 x = None | 
 |             else: | 
 |                 x = 1 | 
 |             i = 0 | 
 |             j = 0 | 
 |             while (x is None or torch.jit._unwrap_optional(x) > 3): | 
 |                 if i < 3: | 
 |                     if i < 3: | 
 |                         x = torch.jit.annotate(Optional[int], None) | 
 |                         i += 1 | 
 |                         continue | 
 |                     x = 1 | 
 |                 else: | 
 |                     x = 1 if x is None else x | 
 |                 x = x + 1 | 
 |                 j = x + x | 
 |  | 
 |             return x, j | 
 |  | 
 |         self.checkScript(test_refine_outside_loop, ()) | 
 |  | 
 |         # assignment placed after an unconditional break | 
 |         def assign_after_break(y): | 
 |             # type: (int) | 
 |             x = 0 | 
 |             for i in range(y): | 
 |                 x = y * 2 + i | 
 |                 break | 
 |                 x = 4 | 
 |             return x | 
 |  | 
 |         self.checkScript(assign_after_break, (1,)) | 
 |         self.checkScript(assign_after_break, (2,)) | 
 |         self.checkScript(assign_after_break, (3,)) | 
 |  | 
 |         # both branches of the if break, so everything after them (including | 
 |         # the always-false asserts) never runs | 
 |         def assign_after_break_nested(y): | 
 |             # type: (int) | 
 |             x = 0 | 
 |             for i in range(y): | 
 |                 if y == 1: | 
 |                     x = 5 | 
 |                     break | 
 |                     assert 1 == 2 | 
 |                 else: | 
 |                     x = x + 1 | 
 |                     break | 
 |                     assert 1 == 2 | 
 |                 x = -30 | 
 |                 assert 1 == 2 | 
 |             return x | 
 |  | 
 |         self.checkScript(assign_after_break_nested, (1,)) | 
 |         self.checkScript(assign_after_break_nested, (2,)) | 
 |         self.checkScript(assign_after_break_nested, (3,)) | 
 |  | 
 |         # only the else branch breaks | 
 |         def may_break(y): | 
 |             # type: (int) | 
 |             x = 0 | 
 |             for i in range(y): | 
 |                 if y == 1: | 
 |                     x = 5 | 
 |                 else: | 
 |                     x = x + 1 | 
 |                     break | 
 |                 x = -30 | 
 |             return x | 
 |  | 
 |         self.checkScript(may_break, (1,)) | 
 |         self.checkScript(may_break, (2,)) | 
 |         self.checkScript(may_break, (3,)) | 
 |  | 
 |         # while loop with an inner for loop and breaks at two nesting depths | 
 |         def test(x, y): | 
 |             # type: (int, int) | 
 |             a = 1 | 
 |             while (x > 0): | 
 |                 if y == 3: | 
 |                     for i in range(y): | 
 |                         a += (1 % (i + 1)) | 
 |                         x -= 1 | 
 |                 if x == 3: | 
 |                     a = x * 3 | 
 |                     break | 
 |                 if x < 3: | 
 |                     if x == 1: | 
 |                         a -= 2 | 
 |                         x -= 1 | 
 |                         break | 
 |                 a -= 1 | 
 |                 x -= 3 | 
 |             return a, x | 
 |  | 
 |         self.checkScript(test, (10, 3)) | 
 |         self.checkScript(test, (10, 2)) | 
 |         self.checkScript(test, (3, 2)) | 
 |         self.checkScript(test, (5, 3)) | 
 |         self.checkScript(test, (2, 3)) | 
 |  | 
 |         # `b` is only reassigned after an unconditional break | 
 |         def test_delete_after_break(x): | 
 |             # type: (int) | 
 |             a = 1 | 
 |             b = 1 | 
 |             for i in range(x): | 
 |                 a = i * 3 | 
 |                 break | 
 |                 b = i * 5 | 
 |             return a, b | 
 |  | 
 |         self.checkScript(test_delete_after_break, (0,)) | 
 |         self.checkScript(test_delete_after_break, (1,)) | 
 |  | 
 |         # conditional break followed by an unconditional break | 
 |         def test_will_break_after_guard(x): | 
 |             # type: (int) | 
 |             a = 1 | 
 |             for i in range(x): | 
 |                 if i == 4: | 
 |                     a = 3 | 
 |                     break | 
 |                 a -= 1 | 
 |                 break | 
 |                 assert 1 == 2 | 
 |                 a -= -100 | 
 |             return a | 
 |  | 
 |         self.checkScript(test_will_break_after_guard, (0,)) | 
 |         self.checkScript(test_will_break_after_guard, (2,)) | 
 |         self.checkScript(test_will_break_after_guard, (4,)) | 
 |  | 
 |         # `k` is not assigned on the path that breaks out of the loop | 
 |         def test_varexit(cond): | 
 |             # type: (int) | 
 |             m = 0 | 
 |             for i in range(3): | 
 |                 if cond == 2: | 
 |                     if cond == 2: | 
 |                         m = 2 | 
 |                         break | 
 |                     k = 1 | 
 |                 else: | 
 |                     k = 2 | 
 |                 m += k | 
 |             return m | 
 |  | 
 |         # use of k tests the pathway where we have to insert uninitialized | 
 |         self.checkScript(test_varexit, (3,)) | 
 |         self.checkScript(test_varexit, (2,)) | 
 |  | 
 |         # `while True` left via break, followed by a `while False` loop | 
 |         def test_break_true(): | 
 |             i = 0 | 
 |             while True: | 
 |                 i += 1 | 
 |                 if i == 3: | 
 |                     break | 
 |             while False: | 
 |                 i += 1 | 
 |             return i | 
 |  | 
 |         self.checkScript(test_break_true, ()) | 
 |  | 
 |     def test_break_continue_error(self): | 
 |         # Invalid uses of `break` must be rejected with a RuntimeError. | 
 |  | 
 |         # break outside of any loop | 
 |         with self.assertRaisesRegex(RuntimeError, "Syntax"): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def other_func(a): | 
 |                 break | 
 |                 ''') | 
 |  | 
 |         # break inside a nested function definition that itself sits in a loop | 
 |         with self.assertRaisesRegex(RuntimeError, "Syntax"): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def other_func(a): | 
 |                 for i in range(5): | 
 |                     def foo(): | 
 |                         break | 
 |                 ''') | 
 |  | 
 |         # break inside a for loop that iterates over a tuple literal with | 
 |         # elements of different types | 
 |         with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"): | 
 |             @torch.jit.script | 
 |             def foo(x): | 
 |                 i = 0 | 
 |                 for a in (1, "2", 1.5): | 
 |                     b = a | 
 |                     if x: | 
 |                         break | 
 |                 return b | 
 |  | 
 |     def test_python_call(self): | 
 |         # A plain Python function (`pyfunc`) is called by name from source that | 
 |         # is compiled through torch.jit.CompilationUnit. | 
 |         def pyfunc(a): | 
 |             return a * 3.0 | 
 |  | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |         def other_func(a): | 
 |             return a + a | 
 |  | 
 |         def test_call_python(a): | 
 |             b = pyfunc(a) | 
 |             b = other_func(b) | 
 |             i = 0 | 
 |             step = 1 | 
 |             while i < 10: | 
 |                 b = pyfunc(b) | 
 |                 if bool(b > 3.0): | 
 |                     b = pyfunc(b) | 
 |                 i = 11 | 
 |             return b | 
 |         ''') | 
 |         # For input 1: 1 * 3 = 3, doubled to 6, then one loop iteration (i is | 
 |         # set to 11) applies pyfunc twice: 6 -> 18 -> 54. | 
 |         inputs = self._make_scalar_vars([1], torch.float) | 
 |         outputs = self._make_scalar_vars([54], torch.float) | 
 |  | 
 |         self.assertEqual(cu.test_call_python(*inputs), outputs[0]) | 
 |  | 
 |     def test_python_call_failure(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): | 
 |             def pyfunc(a): | 
 |                 return a * 3.0 | 
 |  | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def other_func(a): | 
 |                 return a + a | 
 |  | 
 |             def test_call_python(a): | 
 |                 b = pyfunc(a) | 
 |                 b = other_func(b) | 
 |                 i = 0 | 
 |                 step = 1 | 
 |                 while i < 10: | 
 |                     b = pyfunc2(b) | 
 |                     if b > 3.0: | 
 |                         b = pyfunc(b) | 
 |                     i = 11 | 
 |                 return b | 
 |             ''') | 
 |             inputs = self._make_scalar_vars([1], torch.float) | 
 |             outputs = self._make_scalar_vars([54], torch.float) | 
 |  | 
 |             self.assertEqual(cu.test_call_python(*inputs), outputs) | 
 |  | 
 |     def test_type_call_in_script(self): | 
 |         # Returning `type(x)` from a scripted function is expected to end in a | 
 |         # RuntimeError whose message mentions "value of type _TensorMeta". | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return type(x) | 
 |  | 
 |         # the error is expected when the function is called, not when it is | 
 |         # decorated above | 
 |         with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"): | 
 |             fn(torch.tensor(.5)) | 
 |  | 
 |     def test_python_call_annotation(self): | 
 |         # Same idea as calling Python from a CompilationUnit, but through the | 
 |         # @torch.jit.script decorator: `pyfunc` is a plain Python function. | 
 |         def pyfunc(a): | 
 |             return a * 3.0 | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             return pyfunc(a) + pyfunc(a) | 
 |  | 
 |         # input 1 -> 1 * 3 + 1 * 3 = 6 | 
 |         inputs = self._make_scalar_vars([1], torch.float) | 
 |         outputs = self._make_scalar_vars([6], torch.float) | 
 |         self.assertEqual(foo(*inputs), outputs[0]) | 
 |  | 
 |     def test_python_call_annoytation_failure(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): | 
 |             def pyfunc(a): | 
 |                 return a * 3.0 | 
 |  | 
 |             @torch.jit.script | 
 |             def foo(a): | 
 |                 return pyfunc2(a) + pyfunc(a) | 
 |  | 
 |             inputs = self._make_scalar_vars([1], torch.float) | 
 |             outputs = self._make_scalar_vars([6], torch.float) | 
 |  | 
 |             self.assertEqual(foo(*inputs), outputs[0]) | 
 |  | 
 |     def test_desugar_module(self): | 
 |         # The function below reaches ops through three different spellings: | 
 |         # `torch.abs`, the fully qualified `torch.nn.functional.prelu`, and the | 
 |         # locally imported alias `F.prelu`. | 
 |         import torch.nn.functional as F | 
 |  | 
 |         def fn(x, slope): | 
 |             a = torch.abs(x) | 
 |             b = torch.nn.functional.prelu(x, slope) | 
 |             c = F.prelu(x, slope) | 
 |             return a, b, c | 
 |  | 
 |         # inputs: the values -3..3 and a single-element slope tensor | 
 |         x = torch.arange(-3., 4) | 
 |         slope = torch.tensor([0.5]) | 
 |         self.checkScript(fn, [x, slope], optimize=True) | 
 |  | 
 |     def test_script_docstring(self): | 
 |         # The scripted function must expose its leading docstring as __doc__; | 
 |         # the second string literal in the body is just an expression statement. | 
 |         @torch.jit.script | 
 |         def with_docstring(x): | 
 |             """test str""" | 
 |             y = x | 
 |             """y is the same as x""" | 
 |             return y | 
 |         self.assertEqual(with_docstring.__doc__, 'test str') | 
 |  | 
 |     def test_script_method_docstring(self): | 
 |         # Same check as for scripted functions, but for a script_method on a | 
 |         # ScriptModule: the method's leading docstring must be visible as | 
 |         # __doc__ on the instance's method. | 
 |         class A(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def with_docstring(self, x): | 
 |                 """test str""" | 
 |                 y = x | 
 |                 """y is the same as x""" | 
 |                 return y | 
 |         a = A() | 
 |         self.assertEqual(a.with_docstring.__doc__, 'test str') | 
 |  | 
 |     def test_script_module(self): | 
 |         # Builds a ScriptModule (M2) that combines a scripted submodule, a plain | 
 |         # nn.Module submodule, parameters, script methods and a method defined | 
 |         # from a source string, then checks forward() against a reference | 
 |         # computed directly from the parameters. | 
 |  | 
 |         # scripted submodule: adds its weight to the input | 
 |         class M1(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M1, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         # plain (non-script) nn.Module submodule: matrix-multiplies its | 
 |         # parameter with the input | 
 |         class PModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super(PModule, self).__init__() | 
 |                 self.a = nn.Parameter(torch.randn(2, 3)) | 
 |  | 
 |             def forward(self, a): | 
 |                 return self.a.mm(a) | 
 |  | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 # test submodule | 
 |                 self.sub = M1() | 
 |                 self.sub2 = PModule() | 
 |                 # test parameters | 
 |                 self.weight = nn.Parameter(torch.randn(2, 3)) | 
 |                 self.bias = nn.Parameter(torch.randn(2)) | 
 |                 # test defining a method from a string | 
 |                 self.define(""" | 
 |                     def hi(self, a): | 
 |                         return self.weight.mm(a) | 
 |                 """) | 
 |             # test script methods | 
 |  | 
 |             @torch.jit.script_method | 
 |             def doit(self, input): | 
 |                 # test use of parameter | 
 |                 return self.weight.mm(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def doit2(self, input): | 
 |                 return self.weight.mm(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 a = self.doit(input) | 
 |                 b = self.doit2(input) | 
 |                 c = self.hi(input) | 
 |                 d = self.sub2(input) | 
 |                 return a + b + self.bias + self.sub(a) + c + d | 
 |         with torch.jit.optimized_execution(False): | 
 |             m2 = M2() | 
 |             input = torch.randn(3, 2) | 
 |             # reference: a, b and c mirror doit / doit2 / hi (all weight.mm), | 
 |             # d mirrors sub2, and `m2.sub.weight + a` mirrors self.sub(a) | 
 |             a = m2.weight.mm(input) | 
 |             b = m2.weight.mm(input) | 
 |             c = m2.weight.mm(input) | 
 |             d = m2.sub2.a.mm(input) | 
 |             ref = a + b + m2.bias + m2.sub.weight + a + c + d | 
 |             self.assertEqual(ref, m2.forward(input)) | 
 |             # replace / zero every parameter; forward() on a fresh random | 
 |             # input must then produce all zeros | 
 |             m2.weight = nn.Parameter(torch.zeros_like(m2.weight)) | 
 |             m2.bias = nn.Parameter(torch.zeros_like(m2.bias)) | 
 |             m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight)) | 
 |             m2.sub2.a.data.zero_() | 
 |             self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) | 
 |  | 
 |     def test_irparser(self): | 
 |         graph_str = """graph(%0 : Double(5, 5)): | 
 |           # CHECK: aten::relu | 
 |           %1 : Double(5, 5) = aten::relu(%0) | 
 |           return (%1) | 
 |         """ | 
 |         FileCheck().run(graph_str, parse_ir(graph_str)) | 
 |  | 
 |     def test_parse_tensor_constants(self): | 
 |         def foo(): | 
 |             return torch.zeros([4, 4]) | 
 |  | 
 |         foo_s = torch.jit.script(foo) | 
 |         torch._C._jit_pass_constant_propagation(foo_s.graph) | 
 |  | 
 |         g = str(foo_s.graph) | 
 |         g_parsed = parse_ir(g, parse_tensor_constants=True) | 
 |         self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph))) | 
 |         func = torch._C._create_function_from_graph("forward", g_parsed) | 
 |  | 
 |         out_parsed = func() | 
 |         out_func = foo() | 
 |         # not checking data, just dtype, size etc | 
 |         out_parsed[:] = 0 | 
 |         out_func[:] = 0 | 
 |         self.assertEqual(out_func, out_parsed) | 
 |  | 
 |         with self.assertRaises(RuntimeError): | 
 |             parse_ir(g, parse_tensor_constants=False) | 
 |  | 
 |     def test_parse_nested_names(self): | 
 |         g_str = """ | 
 |     graph(%x.1 : Tensor): | 
 |         %3 : int = prim::Constant[value=1]() | 
 |         %2 : int = prim::Constant[value=2]() | 
 |         %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3) | 
 |         return (%hi.submod.value.5) | 
 |         """ | 
 |         g = parse_ir(g_str) | 
 |         round_trip_g = parse_ir(str(g)) | 
 |         self.assertEqual(canonical(g), canonical(round_trip_g)) | 
 |  | 
 |         func1 = torch._C._create_function_from_graph("forward", g) | 
 |         func2 = torch._C._create_function_from_graph("forward", round_trip_g) | 
 |         self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2]))) | 
 |  | 
 |     def test_is_after_use(self): | 
 |         # Checks the ordering that `isAfter` induces on the uses of a graph's | 
 |         # first input, for straight-line code and for several if/else shapes. | 
 |         def sorted_input_use(g): | 
 |             # Collect the uses of the first graph input and sort them, using the | 
 |             # isAfter method of the use objects' class as the comparator. | 
 |             uses = list(next(g.inputs()).uses()) | 
 |             return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter)) | 
 |  | 
 |         # x is used by the add and twice by the returned tuple | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             a = x + 1 | 
 |             return (x, x, a) | 
 |  | 
 |         uses_sorted = sorted_input_use(foo.graph) | 
 |         # sorts last use to the end | 
 |         self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1])) | 
 |         self.assertTrue(uses_sorted[0].user.kind() == "aten::add") | 
 |         self.assertEqual(uses_sorted[1].offset, 0) | 
 |  | 
 |         # one use per branch of an if/else | 
 |         @torch.jit.script | 
 |         def foo(x, cond: bool): | 
 |             if cond: | 
 |                 return x + 3 | 
 |             else: | 
 |                 return x - 3 | 
 |  | 
 |         uses_sorted = sorted_input_use(foo.graph) | 
 |         self.assertTrue(uses_sorted[0].user.kind() == "aten::add") | 
 |         self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") | 
 |  | 
 |         # three uses written as if / elif / fall-through return | 
 |         @torch.jit.script | 
 |         def foo(x, cond: bool, cond2: bool): | 
 |             if cond: | 
 |                 return x + 3 | 
 |             elif cond2 : | 
 |                 return x - 3 | 
 |  | 
 |             return x / 3 | 
 |  | 
 |         graph1 = foo.graph | 
 |  | 
 |         # the same three uses written with an if nested inside the else | 
 |         @torch.jit.script | 
 |         def foo(x, cond: bool, cond2: bool): | 
 |             if cond: | 
 |                 return x + 3 | 
 |             else: | 
 |                 if cond2 : | 
 |                     return x - 3 | 
 |                 return x / 3 | 
 |  | 
 |         graph2 = foo.graph | 
 |  | 
 |         # both spellings must order the uses as add, sub, div | 
 |         for graph in [graph1, graph2]: | 
 |             uses_sorted = sorted_input_use(graph) | 
 |             self.assertTrue(uses_sorted[0].user.kind() == "aten::add") | 
 |             self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") | 
 |             self.assertTrue(uses_sorted[2].user.kind() == "aten::div") | 
 |  | 
    def test_canonicalize_control_outputs(self):
        """Check the textual order of ``prim::If`` / ``prim::Loop`` outputs in scripted graphs.

        ``test_all_outputs`` is run against the graphs of three scripted
        functions; a final FileCheck pins the order of the loop-carried values
        printed for ``loop_unused``.
        """
        def test_all_outputs(g):
            # For each If/Loop node of ``g`` that has outputs, build a FileCheck
            # that locates the node's last output, skips over the node's body
            # and then expects every output name, in output order, in the text
            # that follows.
            ifs = g.findAllNodes("prim::If")
            loops = g.findAllNodes("prim::Loop")

            def contained_blocks(node):
                # Number of "->" markers to skip for nodes found under ``node``:
                # two per prim::If and one per prim::Loop.
                return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop"))
            for node in ifs + loops:
                outs = list(node.outputs())
                out_name = [x.debugName() for x in outs]
                if len(out_name) == 0:
                    continue
                fc = FileCheck()
                # find the last output, then all subsequent uses
                fc.check(out_name[-1] + " : ")
                # skip past node body
                for i in range(contained_blocks(node)):
                    fc.check("->")
                # skip the node's own "->" markers (two for an If, one for a
                # Loop) and move past the end of that line
                if (node.kind() == "prim::If"):
                    fc.check("->").check("->").check("\n")
                else:
                    fc.check("->").check("\n")
                # the canonical order is the same order as the first use
                # appears in text
                for name in out_name:
                    fc.check(name)
                fc.run(g)

        # Two consecutive ifs; the second one uses both a and b.
        @torch.jit.script
        def test(x):
            # type: (bool) -> Tuple[int, int]
            b = 2
            a = 1
            if x:
                a = 1
                b = 2
                x = False
            if x:
                b = a
            else:
                a = b

            return a, b
        test_all_outputs(test.graph)

        # Same first if, followed by an if whose else branch nests another if.
        @torch.jit.script
        def test2(x):
            # type: (bool) -> Tuple[int, int]
            b = 2
            a = 1
            if x:
                a = 1
                b = 2
                x = False
            if x:
                print(a)
            else:
                if x:
                    print(b)

            return a, b
        test_all_outputs(test2.graph)

        # Loop that reassigns a, b, c and x; c is printed before a and b.
        @torch.jit.script
        def test_loop(x, iter):
            # type: (bool, int) -> (None)
            a = 1
            b = 2
            c = 3
            for i in range(iter):
                a = 4
                b = 5
                c = 6
                x = True
            print(c)
            if x:
                print(a, b)
        test_all_outputs(test_loop.graph)

        # a and b are only used inside the loop body; c is also used after it.
        @torch.jit.script
        def loop_unused(iter):
            # type: (int) -> (None)
            a = 1
            b = 2
            c = 3
            for i in range(iter):
                c = c + 1
                b = b + 1
                a = a + 1
                print(a, b)
            print(c)

        # c is used, then unused should be ordered by alphabetical
        FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph)
 |  | 
 |     def test_filecheck(self): | 
 |         def test_check(): | 
 |             file = "232" | 
 |             FileCheck().check("2").check("3").check("2").run(file) | 
 |             FileCheck().check("232").run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): | 
 |                 FileCheck().check("22").run(file) | 
 |             with self.assertRaisesRegex(RuntimeError, "CHECK: 3"): | 
 |                 FileCheck().check("3").check("3").run(file) | 
 |  | 
 |         test_check() | 
 |  | 
 |         def test_check_count(): | 
 |             file = "22222" | 
 |             FileCheck().check_count("2", 5).run(file) | 
 |             FileCheck().check_count("22", 2).run(file) | 
 |             FileCheck().check_count("222", 1).run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): | 
 |                 FileCheck().check_count("2", 4, exactly=True).run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): | 
 |                 FileCheck().check_count("22", 3).run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"): | 
 |                 FileCheck().check_count("2", 6).run(file) | 
 |  | 
 |         test_check_count() | 
 |  | 
 |         def test_check_same(): | 
 |             file = "22\n33" | 
 |             FileCheck().check_same("22").run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "Expected to not find"): | 
 |                 FileCheck().check_same("33").run(file) | 
 |  | 
 |             file = "22  1  3" | 
 |  | 
 |             FileCheck().check("2").check_same("3").run(file) | 
 |             FileCheck().check_count("2", 2).check_same("3").run(file) | 
 |  | 
 |         test_check_same() | 
 |  | 
 |         def test_check_next(): | 
 |             file = "\n1\n2\n3" | 
 |             FileCheck().check("1").check_next("2").check_next("3").run(file) | 
 |             FileCheck().check_next("1").check_next("2").check_next("3").run(file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "Expected to find"): | 
 |                 FileCheck().check("1").check_next("2").run("12") | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "Expected to not find"): | 
 |                 FileCheck().check("1").check_next("2").run("1\n\n2") | 
 |  | 
 |         test_check_next() | 
 |  | 
 |         def test_check_dag(): | 
 |             fc = FileCheck().check_dag("1").check_dag("2").check_not("2") | 
 |             fc.run("12") | 
 |             fc.run("21") | 
 |  | 
 |             fc = FileCheck() | 
 |             fc.check_not("3").check_dag("1").check_dag("2").check_not("3") | 
 |             fc.run("1 3 2") | 
 |             fc.run("2 3 1") | 
 |  | 
 |             fc = FileCheck().check_dag("1").check_dag("2").check("3") | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'): | 
 |                 fc.run("1 3 2") | 
 |  | 
 |         test_check_dag() | 
 |  | 
 |         def test_check_not(): | 
 |             FileCheck().check_not("2").check("1").run("12") | 
 |             FileCheck().check("2").check_not("2").run("12") | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): | 
 |                 FileCheck().check_not("2").check("1").run("21") | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): | 
 |                 FileCheck().check("2").check_not("1").run("21") | 
 |  | 
 |             # checks with distinct range matchings | 
 |             fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2") | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): | 
 |                 fb.run("22 2 22") | 
 |  | 
 |             fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2) | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): | 
 |                 fb.run("22 1 22") | 
 |  | 
 |     def _dtype_to_jit_name(self, dtype): | 
 |         if(dtype == torch.float32): | 
 |             return "Float" | 
 |         if(dtype == torch.float64): | 
 |             return "Double" | 
 |         if(dtype == torch.int64): | 
 |             return "Long" | 
 |         if(dtype == torch.int32): | 
 |             return "Int" | 
 |         if(dtype == torch.bool): | 
 |             return "Bool" | 
 |         raise RuntimeError('dtype not handled') | 
 |  | 
 |     def _dtype_to_expect(self, dtype, dim=0): | 
 |         param = ', '.join(['*'] * dim + ['device=cpu']) | 
 |         param = '(' + param + ')' | 
 |         jit_type = self._dtype_to_jit_name(dtype) | 
 |         if dim >= 0: | 
 |             return jit_type + param | 
 |         # special case representing wrapped number | 
 |         else: | 
 |             return jit_type.lower() | 
 |  | 
 |  | 
 |     def _test_dtype_op_shape(self, ops, args, input_dims=1): | 
 |         if input_dims < 1: | 
 |             raise RuntimeError("input dims must be at least 1") | 
 |         dtypes = [torch.float32, torch.float64, torch.int64, torch.int32] | 
 |         str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '') | 
 |         tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']') | 
 |         template = dedent(''' | 
 |         def func(): | 
 |             return {return_line} | 
 |         ''') | 
 |  | 
 |         for op in ops: | 
 |             for dtype in (dtypes + [None]): | 
 |                 for tensor_type in dtypes: | 
 |                     # a couple of ops aren't implemented for non-floating types | 
 |                     if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)): | 
 |                         if op in ['mean', 'softmax', 'log_softmax']: | 
 |                             continue | 
 |                     return_line = "torch.tensor({}, dtype={}).{}({}dtype={})".format(tensor_data, tensor_type, op, str_args, dtype) | 
 |                     # uncomment for debugging a failed test: | 
 |                     # print("testing {}".format(return_line)) | 
 |                     code = template.format(return_line=return_line) | 
 |                     scope = {} | 
 |                     exec(code, globals(), scope) | 
 |                     cu = torch.jit.CompilationUnit(code) | 
 |                     graph = cu.func.graph | 
 |                     torch._C._jit_pass_complete_shape_analysis(graph, (), False) | 
 |                     input_array = [1, 2, 3] | 
 |                     for _ in range(1, input_dims): | 
 |                         input_array = [input_array] | 
 |                     t = torch.tensor(input_array, dtype=tensor_type) | 
 |                     attr = getattr(t, op) | 
 |                     kwargs = {'dtype': dtype} | 
 |                     result = attr(*args, **kwargs) | 
 |                     expect = self._dtype_to_expect(result.dtype, result.dim()) | 
 |                     FileCheck().check("aten::tensor").check(expect).run(graph) | 
 |  | 
 |     def test_dtype_op_shape(self): | 
 |         ops = ['prod'] | 
 |         self._test_dtype_op_shape(ops, args=[]) | 
 |         self._test_dtype_op_shape(ops, args=[0, False]) | 
 |         self._test_dtype_op_shape(ops, args=[0, False]) | 
 |         self._test_dtype_op_shape(ops, args=[0, True]) | 
 |  | 
 |     def test_dtype_op_shape2(self): | 
 |         ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax'] | 
 |         self._test_dtype_op_shape(ops, args=[0]) | 
 |  | 
 |         self._test_dtype_op_shape(ops, args=[1], input_dims=4) | 
 |  | 
 |  | 
 |     def _test_binary_op_shape(self, ops, input_dims=1): | 
 |  | 
 |         dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool] | 
 |  | 
 |         if input_dims == 0: | 
 |             shape = '1' | 
 |         else: | 
 |             shape = '[' + ('1,' * 4) + ']' | 
 |             for _ in range(1, input_dims): | 
 |                 shape = '[' + ",".join([shape] * 4) + ']' | 
 |  | 
 |         template = dedent(''' | 
 |         def func(): | 
 |             arg1 = {} | 
 |             arg2 = {} | 
 |             return torch.{}(arg1, arg2) | 
 |         ''') | 
 |  | 
 |         args = [] | 
 |         for dtype in dtypes: | 
 |             args = args + ["torch.tensor({}, dtype={})".format(shape, dtype)] | 
 |         args = args + [1, 1.5] | 
 |  | 
 |         def isBool(arg): | 
 |             return type(arg) == bool or (type(arg) == str and "torch.bool" in arg) | 
 |  | 
 |         for op in ops: | 
 |             for first_arg in args: | 
 |                 for second_arg in args: | 
 |                     # subtract not supported for bool | 
 |                     if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)): | 
 |                         continue | 
 |                     # div is not implemented correctly for mixed-type or int params | 
 |                     if (op == 'div' and (type(first_arg) != type(second_arg) or | 
 |                        isinstance(first_arg, int) or | 
 |                        (isinstance(first_arg, str) and 'int' in first_arg))): | 
 |                         continue | 
 |                     return_line = "torch.{}({}, {})".format(op, first_arg, second_arg) | 
 |                     # uncomment for debugging a failed test: | 
 |                     # print("testing {}".format(return_line)) | 
 |                     code = template.format(first_arg, second_arg, op) | 
 |                     scope = {} | 
 |                     exec(code, globals(), scope) | 
 |                     non_jit_result = scope['func']() | 
 |  | 
 |                     cu = torch.jit.CompilationUnit(code) | 
 |                     graph = cu.func.graph | 
 |                     torch._C._jit_pass_complete_shape_analysis(graph, (), False) | 
 |                     # use dim=-1 to represent a python/jit scalar. | 
 |                     dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim() | 
 |                     dtype = non_jit_result.dtype | 
 |                     # jit only supports int/float scalars. | 
 |                     if dim < 0: | 
 |                         if dtype == torch.int64: | 
 |                             dtype = torch.int32 | 
 |                         if dtype == torch.float64: | 
 |                             dtype = torch.float32 | 
 |                     expect = self._dtype_to_expect(dtype, dim) | 
 |                     jit_output = next(graph.outputs()) | 
 |  | 
 |                     check = FileCheck() | 
 |                     check.check(expect).run(str(jit_output)) | 
 |  | 
 |     def test_binary_op_shape(self): | 
 |         self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0) | 
 |         self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3) | 
 |  | 
 |     def test_no_dtype_shape(self): | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             scalar_number = x.item() | 
 |             return x.add(scalar_number) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo2(x): | 
 |             scalar_number = x.item() | 
 |             return torch.tensor(1).add(scalar_number) | 
 |  | 
 |         t = torch.tensor(5) | 
 |         g = foo.graph_for(t) | 
 |         type = next(g.outputs()) | 
 |         self.assertTrue(type.type() == torch._C.TensorType.get()) | 
 |         g2 = foo2.graph_for(t) | 
 |         type = next(g.outputs()) | 
 |         self.assertTrue(type.type() == torch._C.TensorType.get()) | 
 |  | 
 |  | 
 |     def test_filecheck_parse(self): | 
 |         def test_check(): | 
 |             file = """ | 
 |                 # CHECK: 2 | 
 |                 # CHECK: 3 | 
 |                 # CHECK: 2 | 
 |                 232 | 
 |                 """ | 
 |             FileCheck().run(checks_file=file, test_file=file) | 
 |             file = """ | 
 |                 # CHECK: 232 | 
 |                 232 | 
 |                 """ | 
 |             FileCheck().run(file, "232") | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): | 
 |                 FileCheck().run(file, "22") | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): | 
 |                 FileCheck().run("# CHECK: 22", "23") | 
 |         test_check() | 
 |  | 
 |         def test_check_count(): | 
 |             file = "22222" | 
 |             FileCheck().run("# CHECK-COUNT-5: 2", file) | 
 |             FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) | 
 |             FileCheck().run("# CHECK-COUNT-2: 22", file) | 
 |             FileCheck().run("# CHECK-COUNT-1: 222", file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): | 
 |                 FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) | 
 |         test_check_count() | 
 |  | 
 |         def test_check_same(): | 
 |             file = "22\n33" | 
 |             FileCheck().run("# CHECK-SAME: 22", file) | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "Expected to not find"): | 
 |                 FileCheck().run("# CHECK-SAME: 33", file) | 
 |  | 
 |             file = "22  1  3" | 
 |  | 
 |             FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) | 
 |             FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) | 
 |         test_check_same() | 
 |  | 
 |         def test_bad_input(): | 
 |             with self.assertRaisesRegex(RuntimeError, "Check for bad input"): | 
 |                 FileCheck().run("", "1") | 
 |  | 
 |             with self.assertRaisesRegex(RuntimeError, "Could not parse check"): | 
 |                 FileCheck().run("# CHECK1", "") | 
 |  | 
 |         test_bad_input() | 
 |  | 
    def test_script_module_call_noscript(self):
        """A script method can call a ``@torch.jit.ignore`` method that reads a Python attribute."""
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()
                self.value = 1

            # Marked with torch.jit.ignore; reads self.value on every call.
            @torch.jit.ignore
            def foo(self):
                return torch.ones(2, 2) + self.value

            @torch.jit.script_method
            def forward(self, input):
                return input + self.foo()

        with torch.jit.optimized_execution(False):
            m = M()
            input = torch.randn(2, 2)
            o = m(input)
            self.assertEqual(o, input + torch.ones(2, 2) + 1)
            # check that we can change python attributes
            # and that those changes are picked up in script methods
            m.value = 2
            o = m(input)
            self.assertEqual(o, input + torch.ones(2, 2) + 2)
 |  | 
    def test_script_module_nochange_submodule(self):
        """A submodule of a ScriptModule can be called, but re-assigning it raises RuntimeError."""
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()
                self.sub = nn.Linear(5, 5)

            @torch.jit.script_method
            def forward(self, input):
                return self.sub(input)
        with torch.jit.optimized_execution(False):
            m = M()
            input = torch.randn(1, 5, 5)
            o = m(input)
            # the script method and a direct call to the submodule must agree
            self.assertEqual(o, m.sub(input))
            # replacing the submodule after construction is rejected
            with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"):
                m.sub = nn.Linear(5, 5)
 |  | 
    def test_module_apis(self):
        """Compare scripted and eager results of a method iterating over module containers.

        The exported ``method`` walks ``named_modules()``, ``named_children()``,
        ``modules()`` and ``children()``, applying each yielded module to one
        of its inputs and collecting the yielded names.
        """
        class Sub(torch.nn.Module):
            def __init__(self):
                super(Sub, self).__init__()

            def forward(self, thing):
                return thing - 2

        class Double(torch.nn.Module):
            def __init__(self):
                super(Double, self).__init__()

            def forward(self, thing):
                return thing * 2

        class MyMod(torch.nn.Module):
            def __init__(self):
                super(MyMod, self).__init__()
                self.mod = (Sub())
                self.mod2 = (Sub())
                self.mod3 = nn.Sequential(nn.Sequential(Sub()))
                self.mod4 = nn.Sequential(Sub(), Double())

            @torch.jit.export
            def method(self, x, x1, y, y1):
                mod_names = ""
                for name, mod in self.named_modules():
                    mod_names = mod_names + " " + name
                    x = mod(x)

                children_names = ""
                for name, mod in self.named_children():
                    children_names = children_names + " " + name
                    x1 = mod(x1)

                for mod in self.modules():
                    y = mod(y)

                for mod in self.children():
                    y1 = mod(y1)

                return mod_names, children_names, x, x1, y, y1

            def forward(self, x):
                return x + 2

        mod = torch.jit.script(MyMod())
        # four scalar tensors holding 1, 2, 3 and 4
        inps = tuple([torch.tensor(i) for i in range(1, 5)])
        # a fresh eager MyMod serves as the reference implementation
        self.assertEqual(mod.method(*inps), MyMod().method(*inps))
 |  | 
    def test_script_module_const(self):
        """Attributes listed in ``__constants__`` can be returned from a script method."""
        class M(torch.jit.ScriptModule):

            # 's' is declared as a constant as well but is not read by forward
            __constants__ = ['b', 'i', 'c', 's']

            def __init__(self):
                super(M, self).__init__()
                self.b = False
                self.i = 1
                self.c = 3.5
                self.s = ["hello"]

            @torch.jit.script_method
            def forward(self):
                return self.b, self.i, self.c

        with torch.jit.optimized_execution(False):
            m = M()
            o0, o1, o2 = m()
        # o0 comes from the bool constant False and is compared against 0
        self.assertEqual(o0, 0)
        self.assertEqual(o1, 1)
        self.assertEqual(o2, 3.5)
 |  | 
    def test_script_module_fail_exist(self):
        """Constructing a ScriptModule whose script method reads a missing attribute raises RuntimeError."""
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()

            # self.whatisgoingon is never assigned anywhere in this class
            @torch.jit.script_method
            def forward(self, x):
                return x + self.whatisgoingon
        with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"):
            M()
 |  | 
    @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.")
    def test_script_module_none_exist_fail(self):
        """(Skipped) A script method reading an attribute that was set to ``None`` should fail.

        The expected error is a RuntimeError mentioning that the module has no
        attribute ``my_optional``.
        """
        class M(torch.jit.ScriptModule):
            def __init__(self, my_optional):
                super(M, self).__init__()
                self.my_optional = my_optional

            @torch.jit.script_method
            def forward(self, x):
                if self.my_optional is not None:
                    return torch.neg(x) + self.my_optional
                return torch.neg(x)
        with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"):
            x = torch.rand(3, 4)
            fb = M(None)
            fb(x)
 |  | 
    def test_script_module_invalid_consts(self):
        """Constructing a ScriptModule with an unsupported ``__constants__`` value raises TypeError.

        Each class below stores a different kind of value under the constant
        name ``invalid``.
        """
        # a list containing an nn.Module
        class Foo(torch.jit.ScriptModule):
            __constants__ = ['invalid']

            def __init__(self):
                super(Foo, self).__init__()
                self.invalid = [nn.Linear(3, 4)]

        with self.assertRaisesRegex(
                TypeError,
                "Linear' object in attribute 'Foo.invalid' is not a valid constant"):
            Foo()

        # a type object (the class of the int 1)
        class Foo2(torch.jit.ScriptModule):
            __constants__ = ['invalid']

            def __init__(self):
                super(Foo2, self).__init__()
                self.invalid = type(1)

        with self.assertRaisesRegex(TypeError, "not a valid constant"):
            Foo2()

        # a tuple that contains a dict
        class Foo3(torch.jit.ScriptModule):
            __constants__ = ['invalid']

            def __init__(self):
                super(Foo3, self).__init__()
                self.invalid = (3, 4, {})

        with self.assertRaisesRegex(TypeError, "not a valid constant"):
            Foo3()

        # a numpy scalar
        class Foo4(torch.jit.ScriptModule):
            __constants__ = ['invalid']

            def __init__(self):
                super(Foo4, self).__init__()
                self.invalid = np.int64(5)

        # verify that we capture human understandable class name
        with self.assertRaisesRegex(TypeError, "numpy.int64"):
            Foo4()
 |  | 
    def test_script_module_param_buffer_mutation(self):
        """A script method can mutate a registered buffer in place, depending on ``self.training``."""
        # TODO: add param mutation test case after JIT support it
        class ModuleBufferMutate(torch.jit.ScriptModule):
            def __init__(self):
                super(ModuleBufferMutate, self).__init__()
                self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))

            @torch.jit.script_method
            def forward(self):
                if self.training:
                    self.running_var += 1
                return self.running_var

        with torch.jit.optimized_execution(False):
            m = ModuleBufferMutate()
            # first call increments the buffer from 0 to 1
            self.assertEqual(m(), 1)
            m.eval()
            # after eval() the increment branch is skipped, so the value stays 1
            self.assertEqual(m(), 1)
 |  | 
    def test_script_module_for(self):
        """A script method can iterate over a list attribute declared in ``__constants__``."""
        class M(torch.jit.ScriptModule):
            __constants__ = ['b']

            def __init__(self):
                super(M, self).__init__()
                self.b = [1, 2, 3, 4]

            @torch.jit.script_method
            def forward(self):
                sum = 0
                for i in self.b:
                    sum += i
                return sum

        with torch.jit.optimized_execution(False):
            m = M()
            # 1 + 2 + 3 + 4
            self.assertEqual(m(), 10)
 |  | 
    def test_override_magic(self):
        """An exported ``__len__`` gives the same ``len()`` for the eager and the scripted module.

        Checked for a plain ``nn.Module`` and for an ``nn.Sequential`` subclass.
        """
        class OverrideMagic(nn.Module):
            def __init__(self):
                super(OverrideMagic, self).__init__()

            @torch.jit.export
            def __len__(self):
                return 10

        mod = OverrideMagic()
        self.assertEqual(len(mod), len(torch.jit.script(mod)))

        # Sequential subclass created without any child modules
        class OverrideMagicSeq(nn.Sequential):
            def __init__(self):
                super(OverrideMagicSeq, self).__init__()

            @torch.jit.export
            def __len__(self):
                return 10

        mod = OverrideMagicSeq()
        self.assertEqual(len(mod), len(torch.jit.script(mod)))
        # the scripted module must also be truthy
        self.assertTrue(torch.jit.script(mod))
 |  | 
    def test_script_module_for2(self):
        """A script method can loop over an ``nn.ModuleList`` of script submodules."""
        class Sub(torch.jit.ScriptModule):
            def __init__(self):
                super(Sub, self).__init__()
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()
                self.mods = nn.ModuleList([Sub() for i in range(10)])

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        with torch.jit.optimized_execution(False):
            i = torch.empty(2)
            m = M()
            o = m(i)
            # reference: apply the ten submodules one after another in Python
            v = i
            for sub in m.mods:
                v = sub(v)
            self.assertEqual(o, v)
            # the module itself (as opposed to m.mods) cannot be iterated
            with self.assertRaisesRegex(Exception, "object is not iterable"):
                print(list(m))
 |  | 
    def test_attr_qscheme_script(self):
        """A module attribute holding a ``torch.qscheme`` value can be compared in a scripted forward."""
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.qscheme = torch.per_tensor_affine

            def forward(self):
                if self.qscheme == torch.per_tensor_symmetric:
                    return 3
                else:
                    return 4

        f = Foo()
        scripted = torch.jit.script(f)
        # eager and scripted modules must take the same branch
        self.assertEqual(f(), scripted())
 |  | 
    def test_script_module_const_submodule_fail(self):
        """Iterating over a plain Python list of submodules in a script method fails at construction."""
        class Sub(torch.jit.ScriptModule):
            def __init__(self):
                super(Sub, self).__init__()
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__()
                # a plain list, not an nn.ModuleList
                self.mods = [Sub() for _ in range(10)]

            @torch.jit.script_method
            def forward(self):
                for _ in self.mods:
                    print(1)
                return 4

        with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"):
            M()
 |  | 
    # ScriptModule whose ``derived`` buffer is computed from ``param``.
    # ``_pack`` overwrites the buffer with a one-element random tensor and
    # ``_unpack`` recomputes it as the negated parameter; each of them also
    # sets a flag buffer so that callers can observe that it ran.
    class DerivedStateModule(torch.jit.ScriptModule):
        def __init__(self):
            super(TestScript.DerivedStateModule, self).__init__()
            self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
            # starts out as a detached copy of -param
            self.register_buffer('derived', torch.neg(self.param).detach().clone())

            # This is a flag so we can test that the pack method was called
            self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
            # This is a flag so we can test that the unpack method was called
            self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))

        @torch.jit.script_method
        def _pack(self):
            self.pack_called.set_(torch.ones(1, dtype=torch.long))
            self.derived.set_(torch.rand(1, dtype=torch.float).detach())

        @torch.jit.script_method
        def _unpack(self):
            self.unpack_called.set_(torch.ones(1, dtype=torch.long))
            self.derived.set_(torch.neg(self.param).detach())

        @torch.jit.script_method
        def forward(self, x):
            return x + self.derived
 |  | 
 |     def test_pack_unpack_state(self): | 
 |         sm = TestScript.DerivedStateModule() | 
 |         x = torch.rand(3, 4, dtype=torch.float) | 
 |         torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) | 
 |  | 
 |         # Test save path | 
 |         self.assertFalse(sm.pack_called.item()) | 
 |         self.assertFalse(sm.unpack_called.item()) | 
 |         imported = self.getExportImportCopyWithPacking(sm) | 
 |         # ensure pack was called before serialization | 
 |         self.assertTrue(sm.pack_called.item()) | 
 |         # ensure unpack was called after serialization so as to leave the module in an initialized state | 
 |         self.assertTrue(sm.unpack_called.item()) | 
 |  | 
 |         torch.testing.assert_close(sm.derived, torch.neg(sm.param)) | 
 |  | 
 |         # Test load paths | 
 |         self.assertTrue(imported.unpack_called.item()) | 
 |         torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) | 
 |  | 
    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    @unittest.skipIf(True, "Skipping while landing PR stack")
    def test_torch_functional(self):
        """(Currently always skipped) Script several ``torch`` functional ops and compare with eager.

        Covers stft / istft, lu_unpack, cdist, norm, unique and
        unique_consecutive.
        """
        def stft(input, n_fft):
            # type: (Tensor, int) -> Tensor
            return torch.stft(input, n_fft, return_complex=True)

        inps = (torch.randn(10), 7)
        self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))

        def istft(input, n_fft):
            # type: (Tensor, int) -> Tensor
            return torch.istft(input, n_fft)

        # feed the eager stft output back through istft with the same n_fft
        inps2 = (stft(*inps), inps[1])
        self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

        def lu_unpack(x):
            A_LU, pivots = torch.linalg.lu_factor(x)
            return torch.lu_unpack(A_LU, pivots)

        # square matrices with zero to three leading batch dimensions
        for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
            a = torch.randn(*shape)
            self.checkScript(lu_unpack, (a,))

        def cdist_fn():
            a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
            b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
            return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")

        self.checkScript(cdist_fn, ())

        def norm():
            c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
            return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5)

        self.checkScript(norm, ())

        def torch_unique(dim: Optional[int]):
            ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long))
            a = torch.unique(ten, dim=dim)
            b = torch.unique(ten, return_counts=True, dim=dim)
            c = torch.unique(ten, return_inverse=True, dim=dim)
            d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim)
            return a, b, c, d

        # once without a dim and once along dim 0
        self.checkScript(torch_unique, (None,))
        self.checkScript(torch_unique, (0,))

        def torch_unique_consecutive(dim: Optional[int]):
            ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long))
            a = torch.unique_consecutive(ten, dim=dim)
            b = torch.unique_consecutive(ten, return_counts=True, dim=dim)
            c = torch.unique_consecutive(ten, return_inverse=True, dim=dim)
            d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim)
            return a, b, c, d

        self.checkScript(torch_unique_consecutive, (None,))
        self.checkScript(torch_unique_consecutive, (0,))
 |  | 
 |     def test_torch_functional_tensordot_int(self): | 
 |         def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int): | 
 |             return torch.tensordot(a, b, dims=dims) | 
 |  | 
 |         a = torch.arange(120.).reshape(2, 3, 4, 5) | 
 |         b = torch.arange(840.).reshape(4, 5, 6, 7) | 
 |         dims = 2 | 
 |         self.checkScript(tensordot_dims_int, (a, b, dims)) | 
 |  | 
 |     def test_torch_functional_tensordot_tensor(self): | 
 |         def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor): | 
 |             return torch.tensordot(a, b, dims=dims) | 
 |  | 
 |         a = torch.arange(120.).reshape(2, 3, 4, 5) | 
 |         b = torch.arange(840.).reshape(4, 5, 6, 7) | 
 |         dims = torch.tensor([2]) | 
 |         self.checkScript(tensordot_dims_tensor, (a, b, dims)) | 
 |  | 
 |         a = torch.arange(60.).reshape(3, 4, 5) | 
 |         b = torch.arange(24.).reshape(4, 3, 2) | 
 |         dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long) | 
 |         self.checkScript(tensordot_dims_tensor, (a, b, dims)) | 
 |  | 
 |     def test_torch_functional_tensordot_list(self): | 
 |         def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]): | 
 |             return torch.tensordot(a, b, dims=dims) | 
 |  | 
 |         a = torch.arange(60.).reshape(3, 4, 5) | 
 |         b = torch.arange(24.).reshape(4, 3, 2) | 
 |         dims = [[1, 0], [0, 1]] | 
 |         self.checkScript(tensordot_dims_list, (a, b, dims)) | 
 |  | 
 |     def test_torch_functional_tensordot_tuple(self): | 
 |         def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]): | 
 |             return torch.tensordot(a, b, dims=dims) | 
 |  | 
 |         a = torch.arange(60.).reshape(3, 4, 5) | 
 |         b = torch.arange(24.).reshape(4, 3, 2) | 
 |         dims = ([1, 0], [0, 1]) | 
 |         self.checkScript(tensordot_dims_tuple, (a, b, dims)) | 
 |  | 
 |     def test_missing_getstate(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |                 self.x = 1 | 
 |  | 
 |             def forward(self, x): | 
 |                 return x * self.x | 
 |  | 
 |             @torch.jit.export | 
 |             def __setstate__(self, state): | 
 |                 self.x = state[0] | 
 |                 self.training = state[1] | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "getstate"): | 
 |             scripted = torch.jit.script(Foo()) | 
 |  | 
 |     def test_inlining_cleanup(self): | 
 |         """After the "inline" pass, the graph of ``fee`` must contain no prim::If node.""" | 
 |         # `foo` is a plain Python function; it is compiled when the scripted `fee` calls it. | 
 |         def foo(x): | 
 |             return F.linear(x, x) | 
 |  | 
 |         @torch.jit.script | 
 |         def fee(x): | 
 |             return foo(x) | 
 |  | 
 |         # inlining optimizations should have cleaned up linear if statement | 
 |         self.run_pass("inline", fee.graph) | 
 |         FileCheck().check_not("prim::If").run(fee.graph) | 
 |  | 
 |     def test_pack_unpack_nested(self): | 
 |         """Apply _pack/_unpack across a three-level ScriptModule hierarchy. | 
 |  | 
 |         Each level owns a buffer `buf` (ones * 3, ones * 2, ones * 1 from the | 
 |         innermost to the outermost module) that its forward adds to the input. | 
 |         `_pack` replaces the buffer with a one-element zero tensor and | 
 |         `_unpack` restores the original values. | 
 |         """ | 
 |         class SubSubMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(SubSubMod, self).__init__() | 
 |                 self.register_buffer('buf', torch.ones(3, 4) * 3) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _pack(self): | 
 |                 self.buf.set_(torch.zeros(1, dtype=torch.double)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _unpack(self): | 
 |                 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return x + self.buf | 
 |  | 
 |         class SubMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(SubMod, self).__init__() | 
 |                 self.register_buffer('buf', torch.ones(3, 4) * 2) | 
 |                 self.ssm = SubSubMod() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _pack(self): | 
 |                 self.buf.set_(torch.zeros(1, dtype=torch.double)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _unpack(self): | 
 |                 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.ssm(x + self.buf) | 
 |  | 
 |         class Mod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Mod, self).__init__() | 
 |                 self.submod = SubMod() | 
 |                 self.register_buffer('buf', torch.ones(3, 4) * 1) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _pack(self): | 
 |                 self.buf.set_(torch.zeros(1, dtype=torch.double)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def _unpack(self): | 
 |                 self.buf.set_(torch.ones(3, 4, dtype=torch.double)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.submod(x + self.buf) | 
 |  | 
 |         m = Mod() | 
 |         # Initial state: the three buffers contribute 1 + 2 + 3 = 6 per element. | 
 |         torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) | 
 |         # Pack every module in the hierarchy; all buffers become zero. | 
 |         m.apply(lambda s: s._pack()) | 
 |         torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4)) | 
 |         # Unpack every module; the original result must be restored. | 
 |         m.apply(lambda s: s._unpack()) | 
 |         torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) | 
 |  | 
 |     def test_torch_any(self): | 
 |         def fn(x): | 
 |             return torch.any(x) | 
 |  | 
 |         def fn1(x, dim: int): | 
 |             return torch.any(x, dim) | 
 |  | 
 |         self.checkScript(fn, (torch.randn(3, 4), )) | 
 |         self.checkScript(fn, (torch.empty(3), )) | 
 |         self.checkScript(fn, (torch.empty(1), )) | 
 |         self.checkScript(fn, (torch.ones(3, 4),)) | 
 |         self.checkScript(fn, (torch.zeros(5, 7, 1),)) | 
 |         self.checkScript(fn1, (torch.empty(3, 4), -2)) | 
 |         self.checkScript(fn1, (torch.randn(3, 8), 1)) | 
 |         self.checkScript(fn1, (torch.zeros(3, 6, 9), -3)) | 
 |         self.checkScript(fn1, (torch.empty(5), 0)) | 
 |  | 
 |     def test_any(self): | 
 |         def fn(x: List[int]): | 
 |             return any(x) | 
 |  | 
 |         def fn1(x: List[float]): | 
 |             return any(x) | 
 |  | 
 |         def fn2(x: List[bool]): | 
 |             return any(x) | 
 |  | 
 |         def fn3(x: List[str]): | 
 |             return any(x) | 
 |  | 
 |         self.checkScript(fn, ([0, 0, 0, 0], )) | 
 |         self.checkScript(fn, ([0, 3, 0], )) | 
 |         self.checkScript(fn, ([], )) | 
 |         self.checkScript(fn1, ([1.0, 2.0, 3.0], )) | 
 |         self.checkScript(fn1, ([0.0, 0.0, 0.0], )) | 
 |         self.checkScript(fn1, ([0, 0, 0], )) | 
 |         self.checkScript(fn1, ([], )) | 
 |         self.checkScript(fn2, ([True, False, False], )) | 
 |         self.checkScript(fn2, ([False, False, False], )) | 
 |         self.checkScript(fn2, ([True, True, True, True], )) | 
 |         self.checkScript(fn2, ([], )) | 
 |         self.checkScript(fn3, (["", "", ""], )) | 
 |         self.checkScript(fn3, (["", "", "", "-1"], )) | 
 |         self.checkScript(fn3, ([], )) | 
 |  | 
 |     def test_script_module_not_tuple(self): | 
 |         """Iterating over an int constant in a script method must raise on construction.""" | 
 |         class M(torch.jit.ScriptModule): | 
 |             __constants__ = ['mods'] | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 # `mods` is a plain int, so the `for` loop in forward cannot iterate it. | 
 |                 self.mods = 1 | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 for m in self.mods: | 
 |                     print(m) | 
 |                 return v | 
 |         # The error is expected when the module is instantiated, not when it is called. | 
 |         with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): | 
 |             M() | 
 |  | 
 |     def test_attr_module_constants(self): | 
 |         """Export/import round-trip of a ScriptModule that holds an nn.Sequential attribute.""" | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self, mod_list): | 
 |                 super(M2, self).__init__() | 
 |                 self.mods = mod_list | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.mods.forward(x) | 
 |  | 
 |         # Runs with optimized execution disabled. | 
 |         with torch.jit.optimized_execution(False): | 
 |             m = M2(nn.Sequential(nn.ReLU())) | 
 |             self.assertExportImportModule(m, (torch.randn(2, 2),)) | 
 |  | 
 |     def test_script_sequential_for(self): | 
 |         """Iterating over an nn.Sequential in script must match applying its submodules in order. | 
 |  | 
 |         `forward` loops over the Sequential, `forward2` calls it directly; both | 
 |         are compared against applying the submodules one by one from Python. | 
 |         """ | 
 |         class Sub(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.mods = nn.Sequential(Sub(), Sub(), Sub()) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 for m in self.mods: | 
 |                     v = m(v) | 
 |                 return v | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward2(self, v): | 
 |                 return self.mods(v) | 
 |  | 
 |         with torch.jit.optimized_execution(False): | 
 |             i = torch.empty(2) | 
 |             m = M() | 
 |             o = m(i) | 
 |             # Reference result: feed the input through each submodule in order. | 
 |             v = i | 
 |             for sub in m.mods._modules.values(): | 
 |                 v = sub(v) | 
 |             self.assertEqual(o, v) | 
 |  | 
 |             o2 = m.forward2(i) | 
 |             self.assertEqual(o2, v) | 
 |  | 
 |     def test_script_sequential_sliced_iteration(self): | 
 |         """Indexing and slicing an nn.Sequential inside forward must work under checkModule.""" | 
 |         class seq_mod(nn.Module): | 
 |             def __init__(self): | 
 |                 super(seq_mod, self).__init__() | 
 |                 # The plain list is immediately replaced by an nn.Sequential built from it. | 
 |                 self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()] | 
 |                 self.layers = nn.Sequential(*self.layers) | 
 |  | 
 |             def forward(self, input): | 
 |                 x = self.layers[0].forward(input) | 
 |                 for layer in self.layers[1:3]: | 
 |                     x = layer.forward(x) | 
 |                 for layer in self.layers[2:]: | 
 |                     x = layer.forward(x) | 
 |                 return x | 
 |  | 
 |         seq = seq_mod() | 
 |         self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])]) | 
 |  | 
 |     def test_script_sequential_orderdict(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.mods = nn.Sequential(OrderedDict([ | 
 |                     ("conv", nn.Conv2d(1, 20, 5)), | 
 |                     ("relu", nn.ReLU()) | 
 |                 ])) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 return self.mods(input) | 
 |  | 
 |         m = M() | 
 |         self.assertTrue('mods.conv.weight' in m.state_dict().keys()) | 
 |  | 
 |     def test_script_sequential_multi_output_fail(self): | 
 |         """A Sequential whose middle module returns a 3-tuple must fail with a RuntimeError.""" | 
 |         class Sub(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class ReturnMulti(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ReturnMulti, self).__init__() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return x, x, x | 
 |  | 
 |         class HaveSequential(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(HaveSequential, self).__init__() | 
 |                 # ReturnMulti's tuple output is fed into a Sub that adds its input to a tensor. | 
 |                 self.someseq = nn.Sequential( | 
 |                     Sub(), | 
 |                     ReturnMulti(), | 
 |                     Sub() | 
 |                 ) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.someseq(x) | 
 |  | 
 |         # NOTE: the parentheses in the pattern are a regex group, so this matches | 
 |         # the text "Tensor, Tensor, Tensor" rather than literal parentheses. | 
 |         with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"): | 
 |             with torch.jit.optimized_execution(False): | 
 |                 hs = HaveSequential() | 
 |                 i = torch.empty(2) | 
 |                 hs(i) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_script_sequential_in_mod_list(self): | 
 |         class Sub(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 for mod in self.mods: | 
 |                     v = mod(v) | 
 |                 return v | 
 |  | 
 |         m = M() | 
 |         graph = str(m.graph) | 
 |         self.assertTrue(graph.count("prim::CallMethod") == 2) | 
 |         self.assertTrue("python" not in graph) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_script_nested_mod_list(self): | 
 |         class Sub(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 for mod in self.mods: | 
 |                     for m in mod: | 
 |                         v = m(v) | 
 |                 return v | 
 |  | 
 |         m = M() | 
 |         graph = str(m.graph) | 
 |         self.assertTrue(graph.count("prim::CallMethod") == 4) | 
 |         self.assertTrue("python" not in graph) | 
 |  | 
 |     def test_constant_as_attr(self): | 
 |         """An attribute listed in __constants__ can be used as the ``dim`` argument of torch.cat.""" | 
 |         class M(torch.jit.ScriptModule): | 
 |             __constants__ = ['dim'] | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.dim = 1 | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, v): | 
 |                 return torch.cat([v, v, v], dim=self.dim) | 
 |         v = torch.zeros(1, 1) | 
 |         # The scripted result must equal the eager torch.cat with the same dim. | 
 |         with torch.jit.optimized_execution(False): | 
 |             self.assertEqual(torch.cat([v, v, v], dim=1), M()(v)) | 
 |  | 
 |     class StarTestSumStarred(torch.nn.Module): | 
 |         def __init__(self): | 
 |             super(TestScript.StarTestSumStarred, self).__init__() | 
 |  | 
 |         def forward(self, *inputs): | 
 |             output = inputs[0] | 
 |             for i in range(1, len(inputs)): | 
 |                 output += inputs[i] | 
 |             return output | 
 |  | 
 |     class StarTestReturnThree(torch.nn.Module): | 
 |         def __init__(self): | 
 |             super(TestScript.StarTestReturnThree, self).__init__() | 
 |  | 
 |         def forward(self, rep): | 
 |             return rep, rep, rep | 
 |  | 
 |     def test_script_star_expr(self): | 
 |         """A tuple returned by one traced module can be star-unpacked into a call to another.""" | 
 |  | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 # `m` is traced with three inputs; `g` is traced with one input and returns three outputs. | 
 |                 self.m = torch.jit.trace(TestScript.StarTestSumStarred(), | 
 |                                          (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))) | 
 |                 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, rep): | 
 |                 tup = self.g(rep) | 
 |                 return self.m(*tup) | 
 |  | 
 |         m = M2() | 
 |         self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) | 
 |  | 
 |     def test_script_star_expr_string(self): | 
 |         """Same star-unpacking call as test_script_star_expr, with forward given as a source string.""" | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 self.m = torch.jit.trace(TestScript.StarTestSumStarred(), | 
 |                                          (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))) | 
 |                 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3)) | 
 |  | 
 |                 # forward is compiled from this string via self.define(). | 
 |                 self.define(''' | 
 |             def forward(self, rep): | 
 |                 tup = self.g(rep) | 
 |                 return self.m(*tup) | 
 |                 ''') | 
 |  | 
 |         m = M2() | 
 |         self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) | 
 |  | 
 |     class StarTestSumAndReturnThree(torch.nn.Module): | 
 |         def __init__(self): | 
 |             super(TestScript.StarTestSumAndReturnThree, self).__init__() | 
 |  | 
 |         def forward(self, *inputs): | 
 |             output = inputs[0] | 
 |             for i in range(1, len(inputs)): | 
 |                 output += inputs[i] | 
 |             return output, output, output | 
 |  | 
 |     def test_script_star_assign(self): | 
 |         """Starred assignment ``head, *tail = ...`` works on the tuple output of a traced module.""" | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 # `g` is traced with a single input here. | 
 |                 self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3)) | 
 |                 self.define(''' | 
 |             def forward(self, rep): | 
 |                 head, *tail = self.g(rep) | 
 |                 return head | 
 |                 ''') | 
 |  | 
 |         m = M2() | 
 |         self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) | 
 |  | 
 |     def test_script_module_star_assign2(self): | 
 |         """Starred assignment ``*head, tail = ...`` with a module traced using _force_outplace=True. | 
 |  | 
 |         The same tensor is passed three times and the result is expected to | 
 |         be 3 * ones. | 
 |         """ | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 self.g = torch.jit.trace( | 
 |                     TestScript.StarTestSumAndReturnThree(), | 
 |                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)), | 
 |                     _force_outplace=True) | 
 |                 self.define(''' | 
 |             def forward(self, rep): | 
 |                 *head, tail = self.g(rep, rep, rep) | 
 |                 return tail | 
 |                 ''') | 
 |  | 
 |         m = M2() | 
 |         self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3)) | 
 |  | 
 |     def test_script_module_star_assign2_inplace(self): | 
 |         """Starred assignment ``*head, tail = ...`` with a module traced using _force_outplace=False. | 
 |  | 
 |         The expected result here is 4 * ones; see the comment before the | 
 |         final assertion. | 
 |         """ | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 self.g = torch.jit.trace( | 
 |                     TestScript.StarTestSumAndReturnThree(), | 
 |                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)), | 
 |                     _force_outplace=False) | 
 |                 self.define(''' | 
 |             def forward(self, rep): | 
 |                 *head, tail = self.g(rep, rep, rep) | 
 |                 return tail | 
 |                 ''') | 
 |  | 
 |         m = M2() | 
 |         # since forward() makes three aliases to the input `rep` before passing | 
 |         # it to StarTestSumAndReturnThree(), in-place behavior will be different | 
 |         # than the above out of place. | 
 |         self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3)) | 
 |  | 
 |     def test_script_module_star_assign_fail_pythonop(self): | 
 |         """Starred assignment from a @torch.jit.ignore'd Python function must be rejected. | 
 |  | 
 |         The expected error is a RuntimeError matching "cannot be used as a tuple". | 
 |         """ | 
 |  | 
 |         # Class definition, construction and call all sit inside the assertRaisesRegex block. | 
 |         with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): | 
 |             class M2(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(M2, self).__init__() | 
 |  | 
 |                     @torch.jit.ignore | 
 |                     def myfunc(): | 
 |                         return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3) | 
 |  | 
 |                     self.define(''' | 
 |                 def forward(self, rep): | 
 |                     a, *b = myfunc() | 
 |                     return a | 
 |                     ''') | 
 |  | 
 |             m = M2() | 
 |             m(torch.zeros(4, 3)) | 
 |  | 
 |     def test_script_module_star_assign_fail_builtin(self): | 
 |         """Starred assignment from the result of torch.neg must be rejected. | 
 |  | 
 |         The expected error is a RuntimeError matching "cannot be used as a tuple". | 
 |         """ | 
 |         # Class definition, construction and call all sit inside the assertRaisesRegex block. | 
 |         with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): | 
 |             class M2(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(M2, self).__init__() | 
 |  | 
 |                     self.define(''' | 
 |                 def forward(self, rep): | 
 |                     a, *b = torch.neg(rep) | 
 |                     return a | 
 |                     ''') | 
 |  | 
 |             m = M2() | 
 |             m(torch.zeros(4, 3)) | 
 |  | 
 |     @skipIfCompiledWithoutNumpy | 
 |     def test_pack_padded_pad_packed_trace(self): | 
 |         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | 
 |         T, B, C = 3, 5, 7 | 
 |  | 
 |         class PadPackedWrapper(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(PadPackedWrapper, self).__init__() | 
 |  | 
 |             def forward(self, x, seq_lens): | 
 |                 x = pack_padded_sequence(x, seq_lens) | 
 |                 x, _ = pad_packed_sequence(x) | 
 |                 return x | 
 |  | 
 |         x = np.ones((T, B, C)) | 
 |         seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32) | 
 |         # set padding value so we can test equivalence | 
 |         for b in range(B): | 
 |             if seq_lens[b] < T: | 
 |                 x[seq_lens[b]:, b, :] = 0 | 
 |         seq_lens = torch.from_numpy(seq_lens) | 
 |         x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True) | 
 |  | 
 |         m = PadPackedWrapper() | 
 |         m_traced = torch.jit.trace(m, (x, seq_lens,)) | 
 |  | 
 |         y = m(x, seq_lens) | 
 |         loss = torch.sum(y) | 
 |         loss.backward() | 
 |         grad = x.grad.clone() | 
 |         x.grad.zero_() | 
 |  | 
 |         y_traced = m_traced(x, seq_lens) | 
 |         loss_traced = torch.sum(y_traced) | 
 |         loss_traced.backward() | 
 |         grad_traced = x.grad.clone() | 
 |  | 
 |         self.assertEqual(y_traced, x) | 
 |         self.assertEqual(y_traced, y) | 
 |         self.assertEqual(grad, grad_traced) | 
 |  | 
 |         f = io.BytesIO() | 
 |         torch.onnx._export(m, (x, seq_lens), f, verbose=False) | 
 |  | 
 |     def test_script_pack_padded_sequence(self): | 
 |         """Script pack_padded_sequence / pad_packed_sequence and compare against eager. | 
 |  | 
 |         Part one scripts a free function; part two runs checkModule on an | 
 |         nn.Module that calls both helpers with keyword arguments. | 
 |         """ | 
 |         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | 
 |  | 
 |         def pack_padded_pad_packed_script(x, seq_lens): | 
 |             x = pack_padded_sequence(x, seq_lens) | 
 |             x, lengths = pad_packed_sequence(x) | 
 |             return x, lengths | 
 |  | 
 |         T, B, C = 3, 5, 7 | 
 |         x = torch.ones((T, B, C)) | 
 |         seq_lens = torch.tensor([3, 3, 2, 2, 1]) | 
 |         # set padding value so we can test equivalence | 
 |         for b in range(B): | 
 |             if seq_lens[b] < T: | 
 |                 x[seq_lens[b]:, b, :] = 0 | 
 |  | 
 |         eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens) | 
 |         # Scripting is done with emit hooks disabled. | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script) | 
 |         script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens) | 
 |         self.assertEqual(eager_seq, script_seq) | 
 |         self.assertEqual(eager_lengths, script_lengths) | 
 |  | 
 |         class ExperimentalLSTM(torch.nn.Module): | 
 |             # The constructor arguments are accepted but not stored or used. | 
 |             def __init__(self, input_dim, hidden_dim): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self, input): | 
 |                 # type: (Tensor) | 
 |                 packed = pack_padded_sequence( | 
 |                     input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False | 
 |                 ) | 
 |                 output, lengths = pad_packed_sequence( | 
 |                     sequence=packed, total_length=2 | 
 |                 ) | 
 |                 # lengths is flipped, so is output | 
 |                 return output[0] | 
 |  | 
 |         lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2) | 
 |  | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             self.checkModule(lstm, [torch.ones(2, 2)]) | 
 |  | 
 |     def test_script_pad_sequence_pack_sequence(self): | 
 |         """Check eager/script parity of pad_sequence and pack_sequence wrappers. | 
 |  | 
 |         pad_sequence_func is checked with its defaults, with batch_first=True, | 
 |         and with a custom padding value; pack_sequence_func is checked with | 
 |         enforce_sorted both at its default (True) and False. | 
 |         """ | 
 |         from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence | 
 |  | 
 |         def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0): | 
 |             # type: (List[Tensor], bool, float) -> Tensor | 
 |             return pad_sequence(tensor_list, batch_first, padding_value) | 
 |  | 
 |         def pack_sequence_func(tensor_list, enforce_sorted=True): | 
 |             # type: (List[Tensor], bool) -> Tensor | 
 |             return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0] | 
 |  | 
 |         # Inputs of increasing length for padding ... | 
 |         ones3 = torch.ones(3, 5) | 
 |         ones4 = torch.ones(4, 5) | 
 |         ones5 = torch.ones(5, 5) | 
 |         # ... and of decreasing length for packing. | 
 |         tensor1 = torch.tensor([1, 2, 3]) | 
 |         tensor2 = torch.tensor([4, 5]) | 
 |         tensor3 = torch.tensor([6]) | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             self.checkScript(pad_sequence_func, | 
 |                              ([ones3, ones4, ones5],)) | 
 |             self.checkScript(pad_sequence_func, | 
 |                              ([ones3, ones4, ones5], True)) | 
 |             self.checkScript(pad_sequence_func, | 
 |                              ([ones3, ones4, ones5], True, 2.5)) | 
 |             self.checkScript(pack_sequence_func, | 
 |                              ([tensor1, tensor2, tensor3],)) | 
 |             self.checkScript(pack_sequence_func, | 
 |                              ([tensor1, tensor2, tensor3], False)) | 
 |  | 
 |     def test_script_get_tracing_state(self): | 
 |         """A function branching on torch._C._get_tracing_state() must script and match eager.""" | 
 |         def test_if_tracing(x): | 
 |             if torch._C._get_tracing_state(): | 
 |                 return x + 1 | 
 |             else: | 
 |                 return x - 1 | 
 |  | 
 |         inp = torch.randn(3, 3) | 
 |         self.checkScript(test_if_tracing, (inp,)) | 
 |  | 
 |     def test_script_is_tracing(self): | 
 |         """A function branching on torch.jit.is_tracing() must script and match eager.""" | 
 |         def test_is_tracing(x): | 
 |             if torch.jit.is_tracing(): | 
 |                 return x + 1 | 
 |             else: | 
 |                 return x - 1 | 
 |  | 
 |         inp = torch.randn(3, 3) | 
 |         self.checkScript(test_is_tracing, (inp,)) | 
 |  | 
 |     def test_is_scripting(self): | 
 |         def foo(): | 
 |             return torch.jit.is_scripting() | 
 |  | 
 |         self.assertFalse(foo()) | 
 |         scripted = torch.jit.script(foo) | 
 |         self.assertTrue(scripted()) | 
 |  | 
 |     def test_comment_ignore_indent(self): | 
 |         """A module with a mis-indented comment in its source must still pass checkModule. | 
 |  | 
 |         The comment inside ``__init__`` below is deliberately under-indented; | 
 |         do not "fix" its indentation. | 
 |         """ | 
 |         class Model(torch.nn.Module): | 
 |             def __init__(self): | 
 |     # useless comment that is not indented correctly  # noqa: E115 | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self): | 
 |                 return 5 | 
 |  | 
 |         # should compile without an error | 
 |         self.checkModule(Model(), ()) | 
 |  | 
 |     def test_script_outputs(self): | 
 |         """Invalid tuple unpacking in script functions must raise descriptive RuntimeErrors.""" | 
 |         # Case 1: unpacking the result of `a + a` into two names. | 
 |         with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): | 
 |             @torch.jit.script | 
 |             def foo(a): | 
 |                 c, d = a + a | 
 |                 return c + d | 
 |  | 
 |         @torch.jit.script | 
 |         def return3(): | 
 |             return 1, 2, 3 | 
 |  | 
 |         # Case 2: unpacking a three-element tuple into two names. | 
 |         with self.assertRaisesRegex(RuntimeError, "too many values to unpack"): | 
 |             @torch.jit.script | 
 |             def bind2(): | 
 |                 a, b = return3() | 
 |                 print(a) | 
 |                 print(b) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "requires CUDA") | 
 |     def test_script_get_device_cuda(self): | 
 |         """Tensor.get_device() called from script returns 0 for a tensor created on 'cuda'.""" | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             return a.get_device() | 
 |  | 
 |         v = torch.randn(1, device='cuda') | 
 |         self.assertEqual(foo(v), 0) | 
 |  | 
 |     def test_script_chunk(self): | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             b, c = torch.chunk(a, dim=0, chunks=2) | 
 |             return b | 
 |         v = torch.rand(10, 3) | 
 |         self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v)) | 
 |  | 
 |     def test_script_copy(self): | 
 |         """copy.copy and copy.deepcopy of a scripted module must run without raising. | 
 |  | 
 |         There are no further assertions; the copies are discarded. | 
 |         """ | 
 |         class M(torch.nn.Module): | 
 |             # Declares `val` as an optional tensor; it is initialised to None below. | 
 |             __annotations__ = { | 
 |                 "val": Optional[torch.Tensor] | 
 |             } | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.val = None | 
 |  | 
 |             def some_method(self): | 
 |                 return 3 | 
 |  | 
 |             def forward(self, x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 self.val = x + self.some_method() | 
 |                 return x | 
 |  | 
 |         m = torch.jit.script(M()) | 
 |         # test copy | 
 |         copy.copy(m) | 
 |         copy.deepcopy(m) | 
 |  | 
 |     def test_script_forward_method_replacement(self): | 
 |         # We want to support the use case of attaching a different `forward` method | 
 |         class LowLevelModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(LowLevelModule, self).__init__() | 
 |  | 
 |             def forward(self, input: torch.Tensor): | 
 |                 # Generic forward dispatch | 
 |                 return self.forward_pytorch(input) * 2 | 
 |  | 
 |         class TestModule(LowLevelModule): | 
 |             def __init__(self): | 
 |                 super(TestModule, self).__init__() | 
 |                 # Replace the forward method | 
 |                 self.forward = types.MethodType(LowLevelModule.forward, self) | 
 |  | 
 |             def forward_pytorch(self, input: torch.Tensor): | 
 |                 return torch.tensor(123) | 
 |  | 
 |             def forward(self, input: torch.Tensor): | 
 |                 # Should not use this forward method | 
 |                 raise AssertionError("This method should not be used") | 
 |                 return self.forward_pytorch(input) | 
 |  | 
 |         m = TestModule() | 
 |         self.assertEqual(m(torch.tensor(1)), torch.tensor(246)) | 
 |  | 
 |         m_scripted = torch.jit.script(m) | 
 |         self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246)) | 
 |  | 
 |     # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch. | 
 |     @suppress_warnings | 
 |     @skipIfCompiledWithoutNumpy | 
 |     def test_rnn_trace_override(self): | 
 |         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | 
 |         num_layers = 3 | 
 |         T, B, C = 11, 5, 7 | 
 |  | 
 |         class RNNTraceWrapper(torch.nn.Module): | 
 |             def __init__(self, cell_type): | 
 |                 super(RNNTraceWrapper, self).__init__() | 
 |                 if cell_type == 'RNN': | 
 |                     self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers) | 
 |                 elif cell_type == 'LSTM': | 
 |                     self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers) | 
 |                 elif cell_type == 'GRU': | 
 |                     self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers) | 
 |  | 
 |             def forward(self, x, seq_lens): | 
 |                 x = pack_padded_sequence(x, seq_lens) | 
 |                 x, _ = self.rnn(x) | 
 |                 x, _ = pad_packed_sequence(x) | 
 |                 return x | 
 |  | 
 |         for cell_type in ['RNN', 'LSTM', 'GRU']: | 
 |             x = torch.ones(T, B, C, requires_grad=True) | 
 |             seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32)) | 
 |  | 
 |             m = RNNTraceWrapper(cell_type) | 
 |             m_traced = torch.jit.trace(m, (x, seq_lens,)) | 
 |  | 
 |             y = m(x, seq_lens) | 
 |             loss = torch.sum(y) | 
 |             loss.backward() | 
 |             grad = x.grad.clone() | 
 |             x.grad.zero_() | 
 |  | 
 |             y_traced = m_traced(x, seq_lens) | 
 |             loss_traced = torch.sum(y_traced) | 
 |             loss_traced.backward() | 
 |             grad_traced = x.grad.clone() | 
 |  | 
 |             self.assertEqual(y_traced, y) | 
 |             self.assertEqual(grad, grad_traced) | 
 |  | 
 |             f = io.BytesIO() | 
 |             torch.onnx._export(m, (x, seq_lens), f, verbose=False) | 
 |  | 
 |     def test_python_call_non_tensor(self): | 
 |         # A script function calls a type-comment-annotated Python function whose | 
 |         # arguments and results mix tensors, ints and tuples. | 
 |         def foo(a, b, c): | 
 |             # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor] | 
 |             d, e = c | 
 |             return b + e, a + d | 
 |  | 
 |         @torch.jit.script | 
 |         def bar(): | 
 |             x = torch.ones(3, 4) | 
 |             a, b = foo(x, 3, (x, 3)) | 
 |             return a, b | 
 |  | 
 |         # Expected: 3 + 3 == 6, and ones + ones == ones + 1. | 
 |         self.assertEqual((6, torch.ones(3, 4) + 1), bar()) | 
 |  | 
 |     def test_python_call_non_tensor_wrong(self): | 
 |         # An ignored Python function is annotated as returning Tensor but actually | 
 |         # returns a tuple; defining and calling the script caller must raise. | 
 |         with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"): | 
 |             @torch.jit.ignore | 
 |             def foo(): | 
 |                 # type: () -> Tensor | 
 |                 return ((3, 4),)  # noqa: T484 | 
 |  | 
 |             @torch.jit.script | 
 |             def bar(): | 
 |                 return foo() | 
 |  | 
 |             bar() | 
 |  | 
 |     def test_if_different_type(self): | 
 |         # Case 1: a variable gets different types in the two branches and is | 
 |         # used afterwards -> compile error naming both branch types. | 
 |         with self.assertRaisesRegex(RuntimeError, "c0 is set to type " | 
 |                                     "int in the true branch and type " | 
 |                                     "float in the false branch"): | 
 |             @torch.jit.script | 
 |             def diff_type_used(): | 
 |                 if 1 == 2: | 
 |                     c0 = 1 | 
 |                 else: | 
 |                     c0 = 1.0 | 
 |                 return c0 | 
 |  | 
 |         # Case 2: an existing float variable is reassigned an int inside a branch. | 
 |         with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"): | 
 |             @torch.jit.script | 
 |             def diff_existing_type(x): | 
 |                 c0 = 1.0 | 
 |                 if 1 == 2: | 
 |                     c0 = 1 | 
 |                     print(x) | 
 |                 return x | 
 |  | 
 |         # Case 3: differing branch types are accepted when the variable is only | 
 |         # used inside the branches; this must compile without raising. | 
 |         @torch.jit.script | 
 |         def diff_type_unused(): | 
 |             if 1 == 1: | 
 |                 c0 = 1 | 
 |                 print(c0) | 
 |             else: | 
 |                 c0 = 1.0 | 
 |                 print(c0) | 
 |             return 1 | 
 |  | 
 |     def test_if_not_defined_error(self): | 
 |         # Using a variable after an `if` that defines it in only one branch must | 
 |         # fail to compile, and the message must name the branch lacking it. | 
 |         with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"): | 
 |             @torch.jit.script | 
 |             def test(): | 
 |                 if 1 == 1: | 
 |                     c0 = 1 | 
 |                 return c0 | 
 |         with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"): | 
 |             @torch.jit.script | 
 |             def test2(): | 
 |                 if 1 == 1: | 
 |                     pass | 
 |                 else: | 
 |                     c0 = 1 | 
 |                 return c0 | 
 |  | 
 |     def test_if_list_cat(self): | 
 |         # testing that different length lists don't throw error on cat in shape prop | 
 |         @torch.jit.script | 
 |         def test_list(x): | 
 |             if bool(x.sum() < 1): | 
 |                 c = [x, x] | 
 |             else: | 
 |                 c = [x, x, x] | 
 |             return torch.cat(c) | 
 |  | 
 |         b = torch.zeros(2, 4) | 
 |         _propagate_shapes(test_list.graph, (b,), False) | 
 |  | 
 |     def test_if_supertype(self): | 
 |         @torch.jit.script | 
 |         def tensor_unifying(x, y, z): | 
 |             # testing dynamic is appropriately set for y and z | 
 |             if bool(x): | 
 |                 x, y, z = x + 1, y, z | 
 |             else: | 
 |                 x, y, z = x + 1, x, y | 
 |  | 
 |             return x, y, z | 
 |  | 
 |         a = torch.zeros(2, 2, dtype=torch.float) | 
 |         b = torch.zeros(2, 4, dtype=torch.long) | 
 |         c = torch.zeros(2, 4, dtype=torch.float) | 
 |  | 
 |         graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False) | 
 |         if_outputs = list(graph.findNode("prim::If").outputs()) | 
 |         self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)") | 
 |         self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") | 
 |         self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") | 
 |  | 
 |     def test_list_unify(self): | 
 |         # allowing a unified int?[] would cause a runtime error b/c | 
 |         # the index operation expects int?[] to be a generic list, | 
 |         # but in the true branch the IValue will be a int list | 
 |         with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"): | 
 |             @torch.jit.script | 
 |             def list_optional_fails(x): | 
 |                 # type: (bool) -> Optional[int] | 
 |                 if x: | 
 |                     y = [1] | 
 |                 else: | 
 |                     y = [None]  # noqa: T484 | 
 |                 return y[0] | 
 |  | 
 |         # Lists of tensors with different shapes per branch are allowed. | 
 |         @torch.jit.script | 
 |         def list_tensors(x): | 
 |             # type: (bool) -> Tuple[Tensor, List[Tensor]] | 
 |             if x: | 
 |                 a = torch.zeros([1, 1]) | 
 |                 y = [a] | 
 |             else: | 
 |                 a = torch.zeros([1, 2]) | 
 |                 y = [a] | 
 |             return a, y | 
 |  | 
 |         self.run_pass('constant_propagation', list_tensors.graph) | 
 |         m = self.createFunctionFromGraph(list_tensors.graph) | 
 |         # testing that tensor type of lists is unified | 
 |         # (the export/import round trip must succeed on the resulting function) | 
 |         self.getExportImportCopy(m) | 
 |  | 
 |     @_inline_everything | 
 |     def test_import_constants_not_specialized(self): | 
 |         class Mod(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.cat(2 * [x], dim=0) | 
 |  | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self, mod): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 x = torch.zeros(1, 3) | 
 |                 mod_fn = lambda : mod(x)  # noqa: E731 | 
 |                 self.mod = torch.jit.trace(mod_fn, tuple()) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return self.mod() | 
 |  | 
 |         cm = ScriptMod(Mod()) | 
 |         # specialized tensor in graph | 
 |         FileCheck().check("Double(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph) | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(cm, buffer) | 
 |         buffer.seek(0) | 
 |         # when tensor is loaded as constant it isnt specialized | 
 |         cm_load = torch.jit.load(buffer) | 
 |         FileCheck().check_not("Double(1, 3)").run(cm_load.forward.graph) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_type_annotations_repeated_list(self): | 
 |         # BroadcastingListN parameters: a scalar argument must give the same | 
 |         # result as passing the equivalent N-element list or tuple. | 
 |         @torch.jit.script | 
 |         def float_fn(x, y): | 
 |             # type: (float, BroadcastingList3[float]) -> List[float] | 
 |             return y | 
 |         self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0])) | 
 |         self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0))) | 
 |  | 
 |         # Calling the broadcasting function from another script function must compile. | 
 |         @torch.jit.script | 
 |         def float_fn_call(): | 
 |             print(float_fn(1.0, 1.0)) | 
 |             print(float_fn(1.0, (1.0, 1.0, 1.0))) | 
 |  | 
 |         @torch.jit.script | 
 |         def int_fn(x): | 
 |             # type: (BroadcastingList3[int]) -> List[int] | 
 |             return x | 
 |         self.assertEqual(int_fn(1), int_fn([1, 1, 1])) | 
 |         self.assertEqual(int_fn(1), int_fn((1, 1, 1))) | 
 |  | 
 |         @torch.jit.script | 
 |         def int_fn_call(): | 
 |             print(int_fn(1)) | 
 |             print(int_fn((1, 1, 1))) | 
 |  | 
 |         # A non-numeric suffix ("BroadcastingListx") is rejected. | 
 |         with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"): | 
 |             @torch.jit.script  # noqa: T484 | 
 |             def fn(x): | 
 |                 # type: (BroadcastingListx[int]) -> List[int]  # noqa: T484 | 
 |                 return x | 
 |  | 
 |         # using CU so that flake8 error on int[2] is not raised (noqa not working) | 
 |         with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def nested(x, y): | 
 |                     # type: (int, Tuple[int, int[2]]) -> List[int] | 
 |                     return x  # noqa: T484 | 
 |             ''') | 
 |  | 
 |         # Python 3 style annotation: the scalar is expanded to a list of ints. | 
 |         @torch.jit.script | 
 |         def f(x: BroadcastingList2[int]): | 
 |             return x | 
 |  | 
 |         out = f(1) | 
 |         self.assertTrue(isinstance(out[0], int)) | 
 |         self.assertEqual(out, [1, 1]) | 
 |  | 
 |     def test_ntuple_builtins(self): | 
 |         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple | 
 |  | 
 |         def test_ints(): | 
 |             return _single(1), _pair(2), _triple(3), _quadruple(4) | 
 |  | 
 |         def test_floats(): | 
 |             return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1) | 
 |  | 
 |         self.checkScript(test_ints, ()) | 
 |         self.checkScript(test_floats, ()) | 
 |  | 
 |     def test_embedding_renorm_grad_error(self): | 
 |         # Testing that the builtin call to embedding_renorm_ correctly throws | 
 |         # Error when .backward() is called on its input | 
 |  | 
 |         def embedding_norm(input, embedding_matrix, max_norm): | 
 |             F.embedding(input, embedding_matrix, max_norm=0.01) | 
 |  | 
 |         @torch.jit.script | 
 |         def embedding_norm_script(input, embedding_matrix, max_norm): | 
 |             # type: (Tensor, Tensor, float) -> None | 
 |             F.embedding(input, embedding_matrix, max_norm=0.01) | 
 |  | 
 |         for _ in [embedding_norm, embedding_norm_script]: | 
 |             input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) | 
 |             embedding_matrix = torch.randn(10, 3) | 
 |  | 
 |             var1 = torch.randn(10, 3, requires_grad=True) | 
 |             var2 = var1.detach().requires_grad_() | 
 |             output1 = var1 * embedding_matrix | 
 |             output2 = var2 * embedding_matrix | 
 |  | 
 |             output1.sum().backward() | 
 |  | 
 |             ignore = F.embedding(input, embedding_matrix, max_norm=0.01) | 
 |             with self.assertRaisesRegex(RuntimeError, "modified"): | 
 |                 output2.sum().backward() | 
 |  | 
 |     def test_type_annotations(self): | 
 |         def fn(x, y): | 
 |             # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] | 
 |             return x, x * 2, x * 3 | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): | 
 |             @torch.jit.script | 
 |             def script_fn(x): | 
 |                 x, y, z, w = fn(x, x) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): | 
 |             @torch.jit.script | 
 |             def script_fn2(x): | 
 |                 x, y = fn(x, x) | 
 |  | 
 |         def fn_unpack(x): | 
 |             y, z, w = fn(x, x) | 
 |             return y | 
 |  | 
 |         def fn_index(x): | 
 |             q = fn(x, x) | 
 |             return x | 
 |  | 
 |         def fn_string(str, strpair): | 
 |             # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str] | 
 |             str1, str2 = strpair | 
 |             return str, 2, str1, str2 | 
 |  | 
 |         x = torch.ones(2, 2) | 
 |         self.checkScript(fn_unpack, (x,), optimize=True) | 
 |         self.checkScript(fn_index, (x,), optimize=True) | 
 |         self.checkScript(fn_string, ("1", ("3", "4")), optimize=True) | 
 |  | 
 |     def test_type_annotations_varargs(self): | 
 |         @torch.jit.ignore | 
 |         def fn_varargs(x, *args): | 
 |             return args[0] if args else x | 
 |  | 
 |         def fn1(x, y, z): | 
 |             return fn_varargs(x) | 
 |  | 
 |         def fn2(x, y, z): | 
 |             return fn_varargs(x, y) | 
 |  | 
 |         def fn3(x, y, z): | 
 |             return fn_varargs(x, y, z) | 
 |  | 
 |         x, y, z = [torch.randn(2, 2) for _ in range(3)] | 
 |         self.checkScript(fn1, (x, y, z), optimize=True) | 
 |         self.checkScript(fn2, (x, y, z), optimize=True) | 
 |         self.checkScript(fn3, (x, y, z), optimize=True) | 
 |  | 
 |     def test_type_annotation_py3(self): | 
 |         # Writes a small module with Python 3 annotations to a temporary file, | 
 |         # loads its `fn` via get_fn, marks it ignored, and checks how script | 
 |         # callers are validated against the annotated signature. | 
 |         code = dedent(""" | 
 |         import torch | 
 |         from torch import Tensor | 
 |         from typing import Tuple | 
 |  | 
 |         def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]: | 
 |             return (x, y + z, z) | 
 |         """) | 
 |  | 
 |         with tempfile.TemporaryDirectory() as tmp_dir: | 
 |             script_path = os.path.join(tmp_dir, 'script.py') | 
 |             with open(script_path, 'w') as f: | 
 |                 f.write(code) | 
 |             fn = get_fn('test_type_annotation_py3', script_path) | 
 |             fn = torch.jit.ignore(fn) | 
 |  | 
 |             # Passing a tuple where `x` is annotated as Tensor is rejected. | 
 |             with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument" | 
 |                                                       r" 'x' but instead found type 'Tuple\[Tensor,"): | 
 |                 @torch.jit.script | 
 |                 def bad_fn(x): | 
 |                     x, y = fn((x, x), x, x) | 
 |                     return y | 
 |  | 
 |             # Unpacking the 3-tuple result into two targets is rejected. | 
 |             with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): | 
 |                 @torch.jit.script | 
 |                 def bad_fn2(x): | 
 |                     x, y = fn(x, x, x) | 
 |                     return y | 
 |  | 
 |             # Unpacking the 3-tuple result into four targets is rejected. | 
 |             with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): | 
 |                 @torch.jit.script | 
 |                 def bad_fn3(x): | 
 |                     x, y, z, w = fn(x, x, x) | 
 |                     return y | 
 |  | 
 |             # Correct arity: must pass checkScript. | 
 |             def good_fn(x): | 
 |                 y, z, w = fn(x, x, x) | 
 |                 return y, z, w | 
 |  | 
 |             self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True) | 
 |  | 
 |     def test_type_annotation_module(self): | 
 |         class BaseModule(torch.jit.ScriptModule): | 
 |             @torch.jit.ignore | 
 |             def foo(self, x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return x + 1 | 
 |  | 
 |             @torch.jit.ignore | 
 |             def bar(self, x, y): | 
 |                 # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] | 
 |                 return x + y, y | 
 |  | 
 |             @torch.jit.ignore | 
 |             def baz(self, x, y): | 
 |                 return x | 
 |  | 
 |         class ModuleTooMany(BaseModule): | 
 |             @torch.jit.script_method | 
 |             def method(self, x): | 
 |                 return self.foo(x, x) | 
 |  | 
 |         class ModuleTooFew(BaseModule): | 
 |             @torch.jit.script_method | 
 |             def method(self, x): | 
 |                 return self.bar(x) | 
 |  | 
 |         class ModuleTooManyAssign(BaseModule): | 
 |             @torch.jit.script_method | 
 |             def method(self, x): | 
 |                 y, z, w = self.bar(x, x) | 
 |                 return x | 
 |  | 
 |         class ModuleDefault(BaseModule): | 
 |             @torch.jit.script_method | 
 |             def method(self, x): | 
 |                 y = self.baz(x) | 
 |                 return x | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"): | 
 |             ModuleTooMany() | 
 |         with self.assertRaisesRegex(RuntimeError, "Argument y not provided"): | 
 |             ModuleTooFew() | 
 |         with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"): | 
 |             ModuleTooManyAssign() | 
 |         with self.assertRaisesRegex(RuntimeError, "Argument y not provided."): | 
 |             ModuleDefault() | 
 |  | 
 |     def test_type_inferred_from_empty_annotation(self): | 
 |         """ | 
 |         Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true` | 
 |         """ | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return x | 
 |  | 
 |         graph = fn.graph | 
 |         n = next(graph.inputs()) | 
 |         self.assertTrue(n.type() == torch._C.TensorType.getInferred()) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "Inferred \'x\' to be of type \'Tensor"): | 
 |             fn("1") | 
 |  | 
 |     def test_script_define_order(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |  | 
 |             @torch.jit.script_method | 
 |             def call_foo(self, input): | 
 |                 return self.foo(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self, input): | 
 |                 return input + 1 | 
 |         m = M() | 
 |         self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) | 
 |  | 
 |     def test_script_define_order_recursive_fail(self): | 
 |         # Two script methods that call each other must be rejected when the | 
 |         # module is constructed. | 
 |         class M(torch.jit.ScriptModule): | 
 |  | 
 |             @torch.jit.script_method | 
 |             def call_foo(self, input): | 
 |                 return self.foo(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self, input): | 
 |                 self.call_foo(input) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'called recursively'): | 
 |             M() | 
 |  | 
 |     def test_script_kwargs_fn_call(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |  | 
 |             @torch.jit.script_method | 
 |             def call_foo(self, input): | 
 |                 return self.foo(input=input, bar=1) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self, bar, input): | 
 |                 # type: (int, Tensor) -> Tensor | 
 |                 return input + bar | 
 |         m = M() | 
 |         self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) | 
 |  | 
 |     def test_if_define(self): | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             if bool(a == 0): | 
 |                 b = 1 | 
 |             else: | 
 |                 b = 0 | 
 |             return b + 1 | 
 |  | 
 |         @torch.jit.script | 
 |         def foo2(a): | 
 |             b = 0 | 
 |             if bool(a == 0): | 
 |                 b = 1 | 
 |             return b + 1 | 
 |  | 
 |         @torch.jit.script | 
 |         def foo3(a): | 
 |             b = 1 | 
 |             if bool(a == 0): | 
 |                 c = 4 | 
 |             else: | 
 |                 b = 0 | 
 |             return b + 1 | 
 |  | 
 |         a = torch.ones(1, dtype=torch.long) | 
 |         b = torch.zeros(1, dtype=torch.long) | 
 |         self.assertEqual(1, foo(a)) | 
 |         self.assertEqual(2, foo(b)) | 
 |         self.assertEqual(1, foo2(a)) | 
 |         self.assertEqual(2, foo2(b)) | 
 |         self.assertEqual(1, foo3(a)) | 
 |         self.assertEqual(2, foo3(b)) | 
 |  | 
 |     def test_script_module_export_submodule(self): | 
 |         class M1(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M1, self).__init__() | 
 |                 self.weight = nn.Parameter(torch.randn(2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, thing): | 
 |                 return self.weight + thing | 
 |  | 
 |         class M2(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M2, self).__init__() | 
 |                 # test submodule | 
 |                 self.sub = M1() | 
 |                 self.weight = nn.Parameter(torch.randn(2, 3)) | 
 |                 self.bias = nn.Parameter(torch.randn(2)) | 
 |                 self.define(""" | 
 |                     def hi(self, a): | 
 |                         return self.weight.mm(a) | 
 |                 """) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def doit(self, input): | 
 |                 return self.weight.mm(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def doit2(self, input): | 
 |                 return self.weight.mm(input) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def doit3(self, input): | 
 |                 return input + torch.ones([1], dtype=torch.double) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 a = self.doit(input) | 
 |                 b = self.doit2(input) | 
 |                 c = self.hi(input) | 
 |                 return a + b + self.bias + c | 
 |  | 
 |         with torch.jit.optimized_execution(False): | 
 |             m_orig = M2() | 
 |             m_import = self.getExportImportCopy(m_orig) | 
 |  | 
 |             input = torch.randn(3, 2) | 
 |             self.assertEqual(m_orig.doit(input), m_import.doit(input)) | 
 |             self.assertEqual(m_orig.hi(input), m_import.hi(input)) | 
 |             self.assertEqual(m_orig.doit3(input), m_import.doit3(input)) | 
 |             self.assertEqual(m_orig.forward(input), m_import.forward(input)) | 
 |  | 
 |     @slowTest | 
 |     def test_compile_module_with_constant(self): | 
 |         class Double(nn.Module): | 
 |             def __init__(self, downsample=None): | 
 |                 super(Double, self).__init__() | 
 |  | 
 |             def forward(self, input): | 
 |                 return input * 2 | 
 |  | 
 |         class Mod(nn.Module): | 
 |             __constants__ = ['downsample'] | 
 |  | 
 |             def __init__(self, downsample=None): | 
 |                 super(Mod, self).__init__() | 
 |                 self.downsample = downsample | 
 |  | 
 |             def forward(self, input): | 
 |                 if self.downsample is not None: | 
 |                     return self.downsample(input) | 
 |                 return input | 
 |  | 
 |         none_mod = torch.jit.script(Mod(None)) | 
 |         double_mod = torch.jit.script(Mod(Double())) | 
 |         self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1)) | 
 |         self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2) | 
 |  | 
 |     def test_device_kwarg(self): | 
 |         # Constructing a device with the `type=` keyword must script, both via | 
 |         # the imported name `device` and via `torch.device`. | 
 |         from torch import device | 
 |  | 
 |         def f(): | 
 |             return device(type='cuda'), torch.device(type='cpu') | 
 |         self.checkScript(f, ()) | 
 |  | 
 |     def test_script_module_export_tensor_type(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self, type): | 
 |                 super(M, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_()) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self): | 
 |                 return self.param | 
 |  | 
 |         with torch.jit.optimized_execution(False): | 
 |             for type in [torch.float, torch.double]: | 
 |                 m_orig = M(type) | 
 |                 m_import = self.getExportImportCopy(m_orig) | 
 |                 # check to make sure the storage wasn't resized | 
 |                 self.assertTrue(m_orig.param.storage().size() == 25) | 
 |                 self.assertEqual(m_orig.foo(), m_import.foo()) | 
 |                 self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA") | 
 |     def test_script_module_export_tensor_cuda(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_()) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self): | 
 |                 return self.param | 
 |  | 
 |         m_orig = M() | 
 |         m_import = self.getExportImportCopy(m_orig) | 
 |         # check to make sure the storage wasn't resized | 
 |         self.assertTrue(m_orig.param.storage().size() == 25) | 
 |         self.assertTrue(m_import.foo().device == torch.device('cuda:0')) | 
 |         self.assertEqual(m_orig.foo(), m_import.foo()) | 
 |         self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) | 
 |  | 
 |     def test_script_module_export_blocks(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self, n, m): | 
 |                 super(M, self).__init__() | 
 |                 self.weight = torch.nn.Parameter(torch.rand(n, m)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input): | 
 |                 if bool(input.sum() > 0): | 
 |                     output = self.weight.mv(input) | 
 |                 else: | 
 |                     output = self.weight + input | 
 |                 return output | 
 |  | 
 |         m_orig = M(200, 200) | 
 |         m_import = self.getExportImportCopy(m_orig) | 
 |  | 
 |         t = torch.rand(200) | 
 |         self.assertEqual(m_orig(t), m_import(t)) | 
 |  | 
 |     def test_script_module_export_shared_storage(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.param1 = torch.nn.Parameter(torch.rand(5, 5)) | 
 |                 self.param2 = torch.nn.Parameter(self.param1[3]) | 
 |                 self.param3 = torch.nn.Parameter(torch.rand(5, 5)) | 
 |                 self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def foo(self): | 
 |                 return self.param1 + self.param2 + self.param3 + self.param4 | 
 |  | 
 |         with torch.jit.optimized_execution(False): | 
 |             m_orig = M() | 
 |             m_import = self.getExportImportCopy(m_orig) | 
 |  | 
 |             self.assertEqual(m_orig.foo(), m_import.foo()) | 
 |  | 
 |             self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr()) | 
 |             self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr()) | 
 |  | 
 |     def test_sequential_intermediary_types(self): | 
 |         class A(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(A, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + 3 | 
 |  | 
 |         class B(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(B, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return {"1": x} | 
 |  | 
 |         class C(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(C, self).__init__() | 
 |                 self.foo = torch.nn.Sequential(A(), B()) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.foo(x) | 
 |  | 
 |         self.checkModule(C(), (torch.tensor(1),)) | 
 |  | 
 |     def test_ellipsis_const_mid(self): | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[2, Ellipsis, 0:4, 4:8].size() | 
 |  | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_const_mid_select(self): | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[2, Ellipsis, 4, 4, 4:8, 2].size() | 
 |  | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_const_start(self): | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[Ellipsis, 0:4, 4:8].size() | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_const_end(self): | 
 |         """Script an index expression that ends with the ``Ellipsis`` name. | 
 |  | 
 |         The scripted function indexes ``x[0:4, 2, Ellipsis]`` and returns the | 
 |         ``size()`` of the result. | 
 |         """ | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[0:4, 2, Ellipsis].size() | 
 |         # Five dimensions of extent 8 each. | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_mid(self): | 
 |         """Script an index expression with a literal ``...`` in the middle. | 
 |  | 
 |         Same indexing as ``x[2, Ellipsis, 0:4, 4:8]`` but spelled with the | 
 |         ``...`` token; the function returns the ``size()`` of the result. | 
 |         """ | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[2, ..., 0:4, 4:8].size() | 
 |  | 
 |         # Five dimensions of extent 8 each. | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_mid_select(self): | 
 |         """Script a literal ``...`` followed by integer indices and a slice. | 
 |  | 
 |         The scripted function indexes ``x[2, ..., 4, 4, 4:8, 2]`` and returns | 
 |         the ``size()`` of the result. | 
 |         """ | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[2, ..., 4, 4, 4:8, 2].size() | 
 |  | 
 |         # Seven dimensions of extent 8 each. | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_start(self): | 
 |         """Script an index expression that starts with a literal ``...``. | 
 |  | 
 |         The scripted function indexes ``x[..., 0:4, 4:8]`` and returns the | 
 |         ``size()`` of the result. | 
 |         """ | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[..., 0:4, 4:8].size() | 
 |         # Five dimensions of extent 8 each. | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_ellipsis_end(self): | 
 |         """Script an index expression that ends with a literal ``...``. | 
 |  | 
 |         The scripted function indexes ``x[0:4, 2, ...]`` and returns the | 
 |         ``size()`` of the result. | 
 |         """ | 
 |         def ellipsize(x): | 
 |             # type: (Tensor) -> List[int] | 
 |             return x[0:4, 2, ...].size() | 
 |         # Five dimensions of extent 8 each. | 
 |         dummy = torch.zeros(8, 8, 8, 8, 8) | 
 |         self.checkScript(ellipsize, (dummy,), optimize=True) | 
 |  | 
 |     def test_torch_manual_seed(self): | 
 |         with freeze_rng_state(): | 
 |             def test(): | 
 |                 torch.manual_seed(2) | 
 |                 return torch.rand(1) | 
 |  | 
 |             script = torch.jit.script(test) | 
 |             self.assertEqual(test(), script()) | 
 |             graph = script.graph_for() | 
 |             FileCheck().check("aten::manual_seed").run(graph) | 
 |  | 
 |     def test_index_select_shape_prop(self): | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(x, y): | 
 |             return torch.index_select(x, index=y, dim=1) | 
 |  | 
 |         a = torch.zeros(2, 2) | 
 |         b = torch.zeros(4, dtype=torch.long) | 
 |         torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False) | 
 |         FileCheck().check("Double(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph)) | 
 |  | 
 |     def test_shape_analysis_loop(self): | 
 |         """Regression test: shapes that change between loop iterations. | 
 |  | 
 |         ``foo`` is run through ``checkScript`` with inputs of sizes 1, 4 and 5 | 
 |         and ``optimize=False``. | 
 |         """ | 
 |         def foo(a, b, x): | 
 |             c = a | 
 |             # On the first iteration of the loop it appears that | 
 |             # c should be expanded to the size of b, | 
 |             # but on the second and later iterations there is no broadcast | 
 |             # and the sizes are different. | 
 |             # Previously this would cause the compiler to (1) enter an infinite | 
 |             # loop trying to compute the shape, and (2) insert invalid | 
 |             # broadcasts. | 
 |             # This test ensures we don't regress on these issues. | 
 |             for _ in range(2): | 
 |                 a = c + b | 
 |                 c = x | 
 |                 b = x | 
 |             return a | 
 |  | 
 |         self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False) | 
 |  | 
 |     def test_intlist_args(self): | 
 |         def func_1(x): | 
 |             return torch.nn.functional.adaptive_avg_pool1d(x, 1) | 
 |  | 
 |         def func_2(x): | 
 |             return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1) | 
 |  | 
 |         def func_3(x): | 
 |             return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1]) | 
 |  | 
 |         x = torch.randn(8, 8, 8) | 
 |         self.checkScript(func_1, [x], optimize=True) | 
 |         self.checkScript(func_2, [x], optimize=True) | 
 |         self.checkScript(func_3, [x], optimize=True) | 
 |  | 
 |     def test_wrong_implicit_expand(self): | 
 |  | 
 |         @_trace(torch.zeros(3), torch.zeros(1)) | 
 |         def foo(a, b): | 
 |             return a + b | 
 |  | 
 |         a = torch.rand(4) | 
 |         b = torch.rand(4) | 
 |         self.assertEqual(a + b, foo(a, b)) | 
 |  | 
 |     def test_builtin_args_fails(self): | 
 |         """Scripting builtin calls with bad arguments raises ``RuntimeError``. | 
 |  | 
 |         Each ``with`` block scripts one ill-formed call and matches the | 
 |         expected error message with a regex. | 
 |         """ | 
 |  | 
 |         # Only an unknown keyword is given; `self` is never supplied. | 
 |         with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'): | 
 |             @torch.jit.script | 
 |             def f1(a): | 
 |                 torch.sum(foo=4) | 
 |  | 
 |         # `self` is given both positionally and by keyword. | 
 |         with self.assertRaisesRegex(RuntimeError, 'specified twice'): | 
 |             @torch.jit.script | 
 |             def f2(a): | 
 |                 torch.sum(a, self=a) | 
 |  | 
 |         # Only `dim` is given. | 
 |         with self.assertRaisesRegex(RuntimeError, 'not provided'): | 
 |             @torch.jit.script | 
 |             def f3(a): | 
 |                 torch.sum(dim=4) | 
 |  | 
 |         # A single Tensor is passed for the `tensors` argument. | 
 |         with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'): | 
 |             @torch.jit.script | 
 |             def f4(a): | 
 |                 torch.cat(a) | 
 |  | 
 |         # A list of ints is passed for the `tensors` argument. | 
 |         with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'): | 
 |             @torch.jit.script | 
 |             def f5(a): | 
 |                 torch.cat([3]) | 
 |  | 
 |         # `size` mixes an int with a nested list. | 
 |         with self.assertRaisesRegex(RuntimeError, r'Expected a value of' | 
 |                                     r' type \'List\[int\]\' for argument' | 
 |                                     r' \'size\' but instead found type ' | 
 |                                     r'\'List\[Union\[List\[int\], int\]\]'): | 
 |             @torch.jit.script | 
 |             def f6(a): | 
 |                 a.expand(size=[3, [4]]) | 
 |  | 
 |     def test_builtin_args(self): | 
 |         """Script builtin calls using default, keyword and mixed arguments. | 
 |  | 
 |         Three small functions are each run through ``checkScript``. | 
 |         """ | 
 |  | 
 |         def t0(a): | 
 |             # default arg dim | 
 |             return torch.cat([a, a]) | 
 |  | 
 |         self.checkScript(t0, (torch.zeros(1, 1),)) | 
 |  | 
 |         def t1(a): | 
 |             # keywords out of order | 
 |             return torch.cat(dim=1, tensors=[a, a]) | 
 |  | 
 |         self.checkScript(t1, (torch.zeros(1, 1, 2),)) | 
 |  | 
 |         def t2(a): | 
 |             # mix const/non-const attributes | 
 |             if 1 == 1: | 
 |                 b = 1 | 
 |             else: | 
 |                 b = 0 | 
 |             return torch.sum(a, dim=b, keepdim=False) | 
 |  | 
 |         self.checkScript(t2, (torch.zeros(1, 1, 2),)) | 
 |  | 
 |     def test_parser_type_annotations(self): | 
 |         """Parse inline annotations with nested tuples and check the schema. | 
 |  | 
 |         The string form of ``cu.foo.schema`` is handed to ``assertExpected``. | 
 |         """ | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: | 
 |                 return x, x | 
 |         ''') | 
 |  | 
 |         # NOTE(review): assertExpected is defined outside this chunk; presumably | 
 |         # it compares against a stored expected string -- confirm. | 
 |         self.assertExpected(str(cu.foo.schema)) | 
 |  | 
 |     def test_parser_type_annotations_comment(self): | 
 |         """Parse a ``# type:`` comment signature and check the schema. | 
 |  | 
 |         The string form of ``cu.foo.schema`` is handed to ``assertExpected``. | 
 |         """ | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(x, y): | 
 |                 # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor] | 
 |                 return x, x | 
 |         ''') | 
 |  | 
 |         # NOTE(review): assertExpected is defined outside this chunk; presumably | 
 |         # it compares against a stored expected string -- confirm. | 
 |         self.assertExpected(str(cu.foo.schema)) | 
 |  | 
 |     def test_parser_type_annotations_unknown_type(self): | 
 |         with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: | 
 |                     return x, x | 
 |             ''') | 
 |  | 
 |     def test_parser_type_annotations_subscript_non_ident(self): | 
 |         with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]: | 
 |                     return x, x | 
 |             ''') | 
 |  | 
 |     def test_parser_type_annotations_subscript_tensor(self): | 
 |         with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: | 
 |                     return x, x | 
 |             ''') | 
 |  | 
 |     def test_parser_type_annotations_incompatible_expression(self): | 
 |         with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]: | 
 |                     return x, x | 
 |             ''') | 
 |  | 
 |     def test_gather_dynamic_index(self): | 
 |         """Script indexing with a literal index and with a computed index. | 
 |  | 
 |         ``t`` adds ``x[0]`` to ``x[idx]`` where ``idx`` is computed as ``0 + 1``. | 
 |         """ | 
 |         def t(x): | 
 |             gather1 = x[0] | 
 |             idx = 0 + 1 | 
 |             gather2 = x[idx] | 
 |             return gather1 + gather2 | 
 |  | 
 |         self.checkScript(t, (torch.zeros(3, 2, 3),)) | 
 |  | 
 |     def test_torch_ignore_conversion_to_none(self): | 
 |         """Call ``@torch.jit.ignore`` methods that return nothing from forward. | 
 |  | 
 |         ``A.ignored`` is annotated ``-> None`` while ``B.ignored`` has no return | 
 |         annotation; both modules must script and return ``4 + 5 == 9``. | 
 |         """ | 
 |         class A(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(A, self).__init__() | 
 |  | 
 |             @torch.jit.ignore | 
 |             def ignored(self, a: int) -> None: | 
 |                 l: int = len([2 for i in range(a) if i > 2]) | 
 |                 return | 
 |  | 
 |             def forward(self) -> int: | 
 |                 a: int = 4 | 
 |                 b: int = 5 | 
 |                 self.ignored(a) | 
 |                 return a + b | 
 |  | 
 |         class B(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(B, self).__init__() | 
 |  | 
 |             @torch.jit.ignore | 
 |             def ignored(self, a: int): | 
 |                 l: int = len([2 for i in range(a) if i > 2]) | 
 |                 return | 
 |  | 
 |             def forward(self) -> int: | 
 |                 a: int = 4 | 
 |                 b: int = 5 | 
 |                 self.ignored(a) | 
 |                 return a + b | 
 |  | 
 |         # Explicit `-> None` annotation on the ignored method. | 
 |         modelA = torch.jit.script(A()) | 
 |         self.assertEqual(modelA(), 9) | 
 |  | 
 |         # No return annotation on the ignored method. | 
 |         modelB = torch.jit.script(B()) | 
 |         self.assertEqual(modelB(), 9) | 
 |  | 
 |     def test_addmm_grad(self): | 
 |         """ This test checks several things: | 
 |             1. An expand node was inserted before the addmm operating on the | 
 |                bias term. | 
 |             2. The fused form of addmm appears in the ultimate graph that's | 
 |                executed. | 
 |             3. A sum op was emitted for accumulating gradients along the 0th | 
 |                (expanded) dimension of the bias term. | 
 |             4. The correct symbolic representation for the backward pass of the | 
 |                mm operator was emitted (x.t() -> mm) | 
 |  | 
 |             TODO: we should actually check these conditions once we have a way | 
 |             to dump the GraphExecutor state. Namely the processed forward graph | 
 |             and the backward graph. | 
 |         """ | 
 |         @torch.jit.script | 
 |         def addmm_grad_test(b, x, w): | 
 |             return torch.addmm(b, x, w) | 
 |  | 
 |         # Initialize param and input values | 
 |         w_init = torch.rand(2, 5) | 
 |         b_init = torch.rand(5) | 
 |         x = torch.rand(3, 2) | 
 |  | 
 |         # Clone trainable params | 
 |         b = b_init.clone() | 
 |         b.requires_grad_() | 
 |         w = w_init.clone() | 
 |         w.requires_grad_() | 
 |  | 
 |         # Test symbolic differentiation | 
 |         y = addmm_grad_test(b, x, w) | 
 |         y.sum().backward() | 
 |  | 
 |         # clone params for autograd reference | 
 |         b_ref = b_init.clone() | 
 |         b_ref.requires_grad_() | 
 |         w_ref = w_init.clone() | 
 |         w_ref.requires_grad_() | 
 |         y_ref = torch.addmm(b_ref, x, w_ref) | 
 |         y_ref.sum().backward() | 
 |  | 
 |         self.assertEqual(w.grad, w_ref.grad) | 
 |         self.assertEqual(b.grad, b_ref.grad) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix") | 
 |     def test_batch_norm_inference_backward_cuda(self): | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             class MyBatchNorm(torch.nn.Module): | 
 |                 def __init__(self, num_features, affine, track_running_stats): | 
 |                     super(MyBatchNorm, self).__init__() | 
 |                     self.bn = torch.nn.BatchNorm2d( | 
 |                         num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float() | 
 |  | 
 |                 def forward(self, x: torch.Tensor): | 
 |                     o = self.bn(x) | 
 |                     o = torch.nn.functional.relu(o) | 
 |                     return o | 
 |  | 
 |             batch = 4 | 
 |             c = 2 | 
 |             hw = 3 | 
 |             # Initialize param and input values | 
 |             x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() | 
 |             grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() | 
 |  | 
 |             training = False | 
 |             affine = True | 
 |             track_running_stats = True | 
 |  | 
 |             module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda() | 
 |             ref_module = MyBatchNorm(c, affine, track_running_stats).cuda() | 
 |             module.eval() | 
 |             ref_module.eval() | 
 |  | 
 |             jit_module = torch.jit.script(module) | 
 |             ref_module.load_state_dict(module.state_dict()) | 
 |  | 
 |             x = x_init.detach().clone() | 
 |             x.requires_grad_() | 
 |             x_ref = x_init.detach().clone() | 
 |             x_ref.requires_grad_() | 
 |  | 
 |             # Test symbolic differentiation | 
 |             # Run Forward and Backward thrice to trigger autodiff graph | 
 |             for i in range(0, 3): | 
 |                 y = jit_module(x) | 
 |                 y.backward(grad) | 
 |             x.grad.zero_() | 
 |  | 
 |             module.bn.running_mean.zero_() | 
 |             module.bn.running_var.fill_(1.0) | 
 |             ref_module.bn.running_mean.zero_() | 
 |             ref_module.bn.running_var.fill_(1.0) | 
 |  | 
 |             # run jitted module | 
 |             y = jit_module(x) | 
 |             y.backward(grad) | 
 |             # reference computation | 
 |             y_ref = ref_module(x_ref) | 
 |             y_ref.backward(grad) | 
 |  | 
 |             self.assertEqual(y_ref, y) | 
 |             self.assertEqual(x.grad, x_ref.grad) | 
 |             self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean) | 
 |             self.assertEqual(module.bn.running_var, ref_module.bn.running_var) | 
 |  | 
 |     def test_zeros(self): | 
 |         """Script ``torch.zeros`` with explicit dtype/device/layout and without. | 
 |  | 
 |         The first part calls it from a ``script_method`` using a device stored | 
 |         as a module constant; the second part scripts a plain function. | 
 |         """ | 
 |         class M(torch.jit.ScriptModule): | 
 |             __constants__ = ['d'] | 
 |  | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.d = torch.device('cpu') | 
 |  | 
 |             @torch.jit.script_method | 
 |             def create(self): | 
 |                 return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided) | 
 |  | 
 |         r = M().create() | 
 |         self.assertEqual(r.dtype, torch.float) | 
 |         self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r) | 
 |  | 
 |         # Tuple-shaped size argument with all other arguments defaulted. | 
 |         def fn(): | 
 |             return torch.zeros((1, 2, 3)) | 
 |  | 
 |         self.checkScript(fn, ()) | 
 |  | 
 |     def test_vararg_zeros(self): | 
 |         """Script ``torch.zeros`` with the size given as separate int arguments.""" | 
 |         def foo(): | 
 |             return torch.zeros(3, 4, 5, dtype=torch.int) | 
 |  | 
 |         self.checkScript(foo, ()) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand") | 
 |     def test_rand(self): | 
 |         """Check output dtypes of scripted ``torch.rand`` and ``torch.randint``. | 
 |  | 
 |         Runs only when ``GRAPH_EXECUTOR`` is ``ProfilingMode.LEGACY``. Besides | 
 |         the runtime dtype, the types printed in ``graph_for()`` are checked | 
 |         unless the executor is ``ProfilingMode.SIMPLE``. | 
 |         """ | 
 |         def test_rand(): | 
 |             a = torch.rand([3, 4]) | 
 |             return a + 1.0 - a | 
 |  | 
 |         self.checkScript(test_rand, ()) | 
 |         fn = torch.jit.script(test_rand) | 
 |         out = fn() | 
 |         self.assertEqual(out.dtype, torch.double) | 
 |         g = fn.graph_for() | 
 |         # Testing shape analysis correctly setting type | 
 |         if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |             FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \ | 
 |                        .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g) | 
 |  | 
 |         @torch.jit.script | 
 |         def randint(): | 
 |             return torch.randint(0, 5, [1, 2]) | 
 |         out = randint() | 
 |         self.assertEqual(out.dtype, torch.int64) | 
 |         # The graph must mention a Long type and neither Float nor Double. | 
 |         if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: | 
 |             FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \ | 
 |                        .check_not("Float(*, *, requires_grad=0, device=cpu)") \ | 
 |                        .check_not("Double(*, *, requires_grad=0, device=cpu)") \ | 
 |                        .run(randint.graph_for()) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "no CUDA") | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") | 
 |     def test_autodiff_complex(self): | 
 |         """Scripted and eager complex math agree on whether a grad_fn exists. | 
 |  | 
 |         ``foo`` and ``jitted_foo`` have the same body; over four calls the test | 
 |         asserts that their outputs either both have a ``grad_fn`` or both lack | 
 |         one. | 
 |         """ | 
 |         def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor): | 
 |             return torch.exp(torch.mm(torch.complex(x, y), W.cfloat())) | 
 |  | 
 |         @torch.jit.script | 
 |         def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor): | 
 |             return torch.exp(torch.mm(torch.complex(x, y), W.cfloat())) | 
 |  | 
 |         x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0') | 
 |         y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0') | 
 |         # Only W requires grad; its values are scaled down in place by 4. | 
 |         W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True) | 
 |         W.data /= 4 | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             # Repeated calls; presumably enough to get past the profiling | 
 |             # runs -- TODO confirm. | 
 |             for i in range(4): | 
 |                 self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None)) | 
 |  | 
 |  | 
 |     def test_linear_grad(self): | 
 |         """Compare scripted and eager ``F.linear`` outputs and gradients. | 
 |  | 
 |         The comparison is done twice: once with a bias tensor and once with | 
 |         ``None`` passed for the optional bias. | 
 |         """ | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]): | 
 |                 return torch.nn.functional.linear(x, w, b) | 
 |  | 
 |             x_init = torch.randn(4, 2) | 
 |             w_init = torch.randn(3, 2) | 
 |             b_init = torch.randn(3) | 
 |             grad = torch.randn(4, 3) | 
 |  | 
 |             with disable_autodiff_subgraph_inlining(): | 
 |                 # script module | 
 |                 jit_t = torch.jit.script(t) | 
 |  | 
 |                 # Separate leaf tensors for the scripted and reference paths. | 
 |                 x = x_init.detach().requires_grad_() | 
 |                 w = w_init.detach().requires_grad_() | 
 |                 b = b_init.detach().requires_grad_() | 
 |                 x_ref = x_init.detach().requires_grad_() | 
 |                 w_ref = w_init.detach().requires_grad_() | 
 |                 b_ref = b_init.detach().requires_grad_() | 
 |  | 
 |                 # profiling/optimization runs | 
 |                 jit_o = jit_t(x, w, b) | 
 |                 jit_o.backward(grad) | 
 |                 jit_o = jit_t(x, w, b) | 
 |                 jit_o.backward(grad) | 
 |  | 
 |                 # Clear the gradients accumulated by the warm-up runs. | 
 |                 x.grad.zero_() | 
 |                 w.grad.zero_() | 
 |                 b.grad.zero_() | 
 |                 jit_o = jit_t(x, w, b) | 
 |                 jit_o.backward(grad) | 
 |                 o = t(x_ref, w_ref, b_ref) | 
 |                 o.backward(grad) | 
 |  | 
 |                 self.assertEqual(jit_o, o) | 
 |                 self.assertEqual(x.grad, x_ref.grad) | 
 |                 self.assertEqual(w.grad, w_ref.grad) | 
 |                 self.assertEqual(b.grad, b_ref.grad) | 
 |  | 
 |                 # Second round without a bias: reset the x/w gradients on both | 
 |                 # sides before running again. | 
 |                 x.grad.zero_() | 
 |                 w.grad.zero_() | 
 |                 x_ref.grad.zero_() | 
 |                 w_ref.grad.zero_() | 
 |                 jit_o = jit_t(x, w, None) | 
 |                 jit_o.backward(grad) | 
 |                 o = t(x_ref, w_ref, None) | 
 |                 o.backward(grad) | 
 |  | 
 |                 self.assertEqual(jit_o, o) | 
 |                 self.assertEqual(x.grad, x_ref.grad) | 
 |                 self.assertEqual(w.grad, w_ref.grad) | 
 |  | 
 |     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand") | 
 |     def test_rand_profiling(self): | 
 |         """Profiling-executor variant of the rand/randint dtype checks. | 
 |  | 
 |         Runs only when ``GRAPH_EXECUTOR`` is ``ProfilingMode.PROFILING``. Each | 
 |         function is called once under ``num_profiled_runs(1)`` and the types | 
 |         in ``torch.jit.last_executed_optimized_graph()`` are checked. | 
 |         """ | 
 |         def test_rand(): | 
 |             a = torch.rand([3, 4]) | 
 |             return a + 1.0 - a | 
 |  | 
 |         # Testing shape analysis correctly setting type | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             with num_profiled_runs(1): | 
 |                 fn = torch.jit.script(test_rand) | 
 |                 out = fn() | 
 |                 graph_str = torch.jit.last_executed_optimized_graph() | 
 |                 self.assertEqual(out.dtype, torch.double) | 
 |                 FileCheck().check("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \ | 
 |                            .check_not("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str) | 
 |  | 
 |             # Disabled checkScript variant of the check above: | 
 |             # fn = self.checkScript(test_rand, ()) | 
 |             # out = fn() | 
 |             # self.assertEqual(out.dtype, torch.double) | 
 |  | 
 |         @torch.jit.script | 
 |         def randint(): | 
 |             return torch.randint(0, 5, [1, 2]) | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             with num_profiled_runs(1): | 
 |                 out = randint() | 
 |                 graph_str = torch.jit.last_executed_optimized_graph() | 
 |                 self.assertEqual(out.dtype, torch.int64) | 
 |                 FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) | 
 |  | 
 |  | 
 |     def test_erase_number_types(self): | 
 |         def func(a): | 
 |             b = 7 + 1 + 3 | 
 |             c = a + b | 
 |             c += b | 
 |             return c | 
 |  | 
 |         graph = torch.jit.script(func).graph | 
 |         FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph)) | 
 |         self.run_pass("erase_number_types", graph) | 
 |         FileCheck().check_not("int = prim::Constant").run(str(graph)) | 
 |  | 
 |     def test_refine_tuple_types(self): | 
 |         # TupleConstruct output type is not correct here. | 
 |         graph_str = """ | 
 |         graph(%a : Float(123), %b : Float(4, 5, 6)): | 
 |           %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) | 
 |           return (%c) | 
 |         """ | 
 |         graph = parse_ir(graph_str) | 
 |         torch._C._jit_pass_refine_tuple_types(graph) | 
 |  | 
 |         # After the pass, the output type should've been updated. | 
 |         self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output())) | 
 |  | 
 |     # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser. | 
 |  | 
 |     def test_remove_dropout(self): | 
 |         """Exercise the remove_dropout pass on a scripted linear/dropout/linear. | 
 |  | 
 |         The pass must raise while the module is in training mode; after | 
 |         ``eval()`` and freezing, the graph must contain no ``aten::dropout`` | 
 |         and the output must stay close to the pre-pass result. | 
 |         """ | 
 |         weight_0_shape = (20, 5) | 
 |         weight_1_shape = (20, 20) | 
 |         input_shape = (10, 5) | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape)) | 
 |                 self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape)) | 
 |  | 
 |             def forward(self, x): | 
 |                 o = F.linear(x, self.weight_0) | 
 |                 o = F.dropout(o, training=self.training) | 
 |                 o = F.linear(o, self.weight_1) | 
 |                 return o | 
 |  | 
 |         data = torch.rand(input_shape) | 
 |         m = M() | 
 |         m = torch.jit.script(m) | 
 |         # The module has not been switched to eval() yet at this point. | 
 |         with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'): | 
 |             torch._C._jit_pass_remove_dropout(m._c) | 
 |         m.eval() | 
 |         # Reference output taken before the pass is applied. | 
 |         ref_res = m(data) | 
 |         # Need to inline otherwise we see instances of Function. | 
 |         # We would have to use torch.linear/dropout to get around it otherwise. | 
 |         from torch.jit._recursive import wrap_cpp_module | 
 |         m = wrap_cpp_module(torch._C._freeze_module(m._c)) | 
 |         torch._C._jit_pass_remove_dropout(m._c) | 
 |         res = m(data) | 
 |         FileCheck().check_not("aten::dropout").run(str(m.graph)) | 
 |         torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3) | 
 |  | 
 |     def test_unfold_zero_dim(self): | 
 |         def fn(x): | 
 |             return x.unfold(0, 1, 1) | 
 |  | 
 |         graph = torch.jit.script(fn).graph | 
 |         torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False) | 
 |         out_dims = fn(torch.tensor(0.3923)).ndim | 
 |         self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims) | 
 |  | 
 |     def test_mm_batching(self): | 
 |  | 
 |         with enable_profiling_mode_for_profiling_tests(): | 
 |             lstm_cell = torch.jit.script(LSTMCellS) | 
 |  | 
 |             def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): | 
 |                 for i in range(x.size(0)): | 
 |                     hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) | 
 |                 return hx | 
 |  | 
 |             slstm = torch.jit.script(lstm) | 
 |  | 
 |             inputs = get_lstm_inputs('cpu', training=True, seq_length=10) | 
 |             slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True) | 
 |             if GRAPH_EXECUTOR == ProfilingMode.PROFILING: | 
 |                 slstm(*inputs, profile_and_replay=True).sum().backward() | 
 |  | 
 |             fw_graph = slstm.graph_for(*inputs) | 
 |             if GRAPH_EXECUTOR == ProfilingMode.LEGACY: | 
 |                 bw_graph = backward_graph(slstm, diff_graph_idx=0) | 
 |                 self.assertTrue('prim::MMBatchSide' in str(fw_graph)) | 
 |                 self.assertTrue('prim::MMTreeReduce' in str(bw_graph)) | 
 |  | 
 |             sout = slstm(*inputs) | 
 |             out = lstm(*inputs) | 
 |             self.assertEqual(sout, out) | 
 |             self.assertEqual(torch.autograd.grad(sout.sum(), inputs), | 
 |                              torch.autograd.grad(out.sum(), inputs)) | 
 |  | 
 |     def test_loop_unrolling(self): | 
 |         def fn(x): | 
 |             y = 0 | 
 |             for i in range(int(x)): | 
 |                 y -= i | 
 |             return y | 
 |  | 
 |         graph = torch.jit.script(fn).graph | 
 |         self.run_pass('loop_unrolling', graph) | 
 |         unroll_factor = 8 | 
 |         FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \ | 
 |             .check("prim::Loop").check("aten::sub").run(str(graph)) | 
 |         self.checkScript(fn, (torch.tensor(10),)) | 
 |  | 
 |     def test_loop_unrolling_const(self): | 
 |         def fn(): | 
 |             y = 0 | 
 |             for _ in range(10): | 
 |                 y -= 1 | 
 |             return y | 
 |  | 
 |         def fn2(): | 
 |             y = 0 | 
 |             for i in range(10): | 
 |                 y -= i | 
 |             return y | 
 |  | 
 |         def check(fn, name): | 
 |             graph = torch.jit.script(fn).graph | 
 |             self.run_pass('loop_unrolling', graph) | 
 |             # entirely unrolled | 
 |             FileCheck().check_not("prim::Loop'").run(str(graph)) | 
 |             self.checkScript(fn, ()) | 
 |  | 
 |         check(fn, 'add_const') | 
 |         check(fn2, 'add_iter') | 
 |  | 
 |     def test_loop_unrolling_nested(self): | 
 |         def fn(x): | 
 |             y = 0 | 
 |             for _ in range(10): | 
 |                 for j in range(int(x)): | 
 |                     y -= j | 
 |             return y | 
 |  | 
 |         graph = torch.jit.script(fn).graph | 
 |         self.run_pass('loop_unrolling', graph) | 
 |         # inner loop with 8 subs followed by loop epilogue | 
 |         unroll_factor = 8 | 
 |         FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \ | 
 |             .check("prim::Loop").check("aten::sub").run(str(graph)) | 
 |         self.checkScript(fn, (torch.tensor(10),)) | 
 |  | 
 |     def test_loop_unroll_unused_counter(self): | 
 |         """Unrolling a loop whose counter is unused adds no ``aten::add``. | 
 |  | 
 |         After the pass, the graph text must have a ``prim::Loop`` and no | 
 |         ``aten::add`` between it and ``return``. | 
 |         """ | 
 |         def fn(x): | 
 |             y = 0 | 
 |             for _ in range(int(x)): | 
 |                 y -= 1 | 
 |             return y | 
 |  | 
 |         graph = torch.jit.script(fn).graph | 
 |         self.run_pass('loop_unrolling', graph) | 
 |         FileCheck().check("prim::Loop").check_not("aten::add").check("return") \ | 
 |             .run(str(graph)) | 
 |  | 
 |     def test_loop_unroll_negative(self): | 
 |         def fn(x): | 
 |             y = 0 | 
 |             for _ in range(int(x)): | 
 |                 y += 1 | 
 |             return y | 
 |  | 
 |         self.checkScript(fn, (torch.tensor(-20),)) | 
 |         self.checkScript(fn, (torch.tensor(-2),)) | 
 |         self.checkScript(fn, (torch.tensor(-1),)) | 
 |         self.checkScript(fn, (torch.tensor(0),)) | 
 |         self.checkScript(fn, (torch.tensor(1),)) | 
 |         self.checkScript(fn, (torch.tensor(2),)) | 
 |  | 
 |     def test_where(self): | 
 |         def fn(x, y): | 
 |             return torch.where(x > 0.0, x, y) | 
 |  | 
 |         self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) | 
 |  | 
 |     def test_where_method(self): | 
 |         def fn(x, y): | 
 |             return x.where(x > 0.0, y) | 
 |  | 
 |         self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) | 
 |  | 
 |     def test_union_to_number(self): | 
 |         """``Union[int, complex, float]`` parameters print as ``Scalar``. | 
 |  | 
 |         The check looks for ``: Scalar):`` in the scripted function's graph. | 
 |         """ | 
 |         @torch.jit.script | 
 |         def fn(x: Union[int, complex, float], y: Union[int, complex, float]): | 
 |             return x + y | 
 |         FileCheck().check(": Scalar):").run(fn.graph) | 
 |  | 
 |     def test_reassign_module_lhs(self): | 
 |         """Assigning to ``self`` inside a script method raises ``RuntimeError``.""" | 
 |         with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''): | 
 |             class ReassignSelfLHS(torch.jit.ScriptModule): | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     for _ in range(20): | 
 |                         self = x | 
 |                     return self | 
 |  | 
 |             ReassignSelfLHS() | 
 |  | 
 |     def test_reassign_module_rhs(self): | 
 |         """Assigning ``self`` to a local in a script method raises ``RuntimeError``.""" | 
 |         with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'): | 
 |             class ReassignSelfRHS(torch.jit.ScriptModule): | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     for _ in range(20): | 
 |                         x = self | 
 |                     return self | 
 |  | 
 |             ReassignSelfRHS() | 
 |  | 
 |     def test_unknown_builtin(self): | 
 |         """Calling an unknown method (``x.splork``) raises ``RuntimeError``.""" | 
 |         with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'): | 
 |             @torch.jit.script | 
 |             def unknown_builtin(x): | 
 |                 return x.splork(3) | 
 |  | 
 |     def test_return_tuple(self): | 
 |         """Script a function that returns a tuple nested in a tuple. | 
 |  | 
 |         ``return_tuple`` returns ``((x, x), x)``. | 
 |         """ | 
 |         def return_tuple(x): | 
 |             a = (x, x) | 
 |             return a, x | 
 |         self.checkScript(return_tuple, (torch.rand(4),)) | 
 |  | 
 |     def test_add_tuple_optional(self): | 
 |         """Concatenate tuples whose elements include ``Optional[Tensor]``. | 
 |  | 
 |         ``foo`` rebuilds its input as ``(input[0] + 1,) + input[1:]`` and | 
 |         returns the third element. | 
 |         """ | 
 |         def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]: | 
 |             changed_input = input[0] + 1 | 
 |             value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:] | 
 |             return value[2] | 
 |         # Both optional slots are None in the test input. | 
 |         inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None) | 
 |         self.checkScript(foo, (inp,)) | 
 |  | 
    def test_add_tuple_non_optional(self):
        """Concatenate a 1-tuple with a tuple slice of plain tensors and use the last element."""
        def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
            changed_input = input[0] + 1
            value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:]
            return torch.sum(value[2]) + 4
        inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4))
        self.checkScript(foo, (inp,))
 |  | 
    def test_add_tuple_different_types(self):
        """Add tuples with mixed int/float element types, twice in a row."""
        def foo(a: Tuple[int, float], b: Tuple[int]) -> int:
            c: Tuple[int, float, int] = a + b
            d: Tuple[int, float, int, int] = c + b
            return d[3] + 1
        a = (1, 2.0)
        b = (3,)
        self.checkScript(foo, (a, b))
 |  | 
    def test_add_tuple_same_types(self):
        """Add all-int tuples of different lengths, twice in a row."""
        def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int:
            c: Tuple[int, int, int, int, int] = a + b
            d: Tuple[int, int, int, int, int, int, int, int] = c + b
            return d[6] - 2
        a = (1, 2)
        b = (3, 4, 5)
        self.checkScript(foo, (a, b))
 |  | 
    def test_method_no_self(self):
        """A script method declared without a `self` parameter must be rejected.

        Expects a RuntimeError matching "methods must have a self argument".
        """
        with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
            class MethodNoSelf(torch.jit.ScriptModule):
                @torch.jit.script_method  # noqa: B902
                def forward():  # noqa: B902
                    return torch.zeros(3, 4)

            MethodNoSelf()
 |  | 
    def test_return_stmt_not_at_end(self):
        """Run a function with a `return` in each branch of an if/else through checkScript."""
        def return_stmt(x):
            if bool(x > 3):
                return x + 3
            else:
                return x
        self.checkScript(return_stmt, (torch.rand(1),))
 |  | 
    def test_for_in_range(self):
        """Accumulate over `range(100)` with a single-argument range loop."""
        def fn():
            c = 0
            for i in range(100):
                c += i
            return c
        self.checkScript(fn, ())
 |  | 
    def test_for_in_range_dynamic(self):
        """Nested range loops where the inner bound is the outer loop variable.

        checkScript is invoked with optimize=False.
        """
        def fn():
            c = 0
            for i in range(100):
                acc = 0
                for j in range(i):
                    acc += j
                c += acc
            return c
        self.checkScript(fn, (), optimize=False)
 |  | 
    def test_for_in_range_ast(self):
        """Nested range loops with a variable inner bound, using default checkScript options."""
        def test_script_for_in_range_ast():
            c = 0
            for i in range(100):
                acc = 0
                for j in range(i):
                    acc += j
                c += acc
            return c

        self.checkScript(test_script_for_in_range_ast, ())
 |  | 
    def test_for_in_range_if_ast(self):
        """Grow a tensor by concatenation inside a range loop with an if/else body.

        After 20 iterations the result's leading dimension must be 20.
        """
        @torch.jit.script
        def test_script_for_in_range_if_ast(x):
            output = x
            for i in range(20):
                if i == 0:
                    output = x.unsqueeze(0)
                else:
                    output = torch.cat((output, x.unsqueeze(0)), dim=0)
            return output
        inputs = self._make_scalar_vars([0], torch.int64)

        self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
 |  | 
    def test_for_in_range_start_end(self):
        """Accumulate over a two-argument `range(7, 100)` loop."""
        def fn():
            x = 0
            for i in range(7, 100):
                x += i
            return x
        self.checkScript(fn, ())
 |  | 
 |     def test_for_in_range_start_end_step(self): | 
 |         def fn(start, end, step): | 
 |             # type: (int, int, int) -> int | 
 |             x = 0 | 
 |             for i in range(start, end, step): | 
 |                 x += i | 
 |             return x | 
 |  | 
 |         self.checkScript(fn, (7, 100, 7)) | 
 |         self.checkScript(fn, (7, 100, -7)) | 
 |         self.checkScript(fn, (2, -11, -3)) | 
 |         self.checkScript(fn, (2, -11, 3)) | 
 |         self.checkScript(fn, (2, 10, 3)) | 
 |         self.checkScript(fn, (-2, -10, -10)) | 
 |  | 
    def test_for_in_range_zero_step(self):
        """A zero step in `range` compiles but must raise when the function is called.

        Expects a RuntimeError matching "must not be zero" from the call.
        """
        @torch.jit.script
        def fn():
            x = 0
            for i in range(2, -11, 0):
                x += i
            return x

        with self.assertRaisesRegex(RuntimeError, "must not be zero"):
            fn()
 |  | 
    def test_range_args(self):
        """`range` with no arguments, or with a float argument, must fail to script."""
        with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
            @torch.jit.script
            def range_no_arg(x):
                for _ in range():
                    x += 1
                return x
        with self.assertRaisesRegex(RuntimeError, r'found float'):
            @torch.jit.script
            def range_non_float():
                for i in range(.5):
                    print(i)
 |  | 
 |     def test_parse_empty_tuple_annotation(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(x : Tuple[()]) -> Tuple[()]: | 
 |                 return x | 
 |         ''') | 
 |  | 
 |         foo_code = cu.find_function('foo').code | 
 |         FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code) | 
 |  | 
 |     def test_parse_empty_tuple_annotation_element_error(self): | 
 |         with self.assertRaisesRegex( | 
 |                 RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |                 def foo(x : Tuple[(int,)]) -> Tuple[(int,)]: | 
 |                     return x | 
 |             ''') | 
 |  | 
 |     def test_parse_none_type_annotation(self): | 
 |         cu = torch.jit.CompilationUnit(''' | 
 |             def foo(x : NoneType) -> NoneType: | 
 |                 return x | 
 |         ''') | 
 |  | 
 |         foo_code = cu.find_function('foo').code | 
 |         FileCheck().check(": NoneType").check("-> NoneType").run(foo_code) | 
 |  | 
 |     def test_empty_tuple_str(self): | 
 |         empty_tuple_type = torch._C.TupleType([]) | 
 |         g = {'Tuple' : typing.Tuple} | 
 |         python_type = eval(empty_tuple_type.annotation_str, g) | 
 |         assert python_type is typing.Tuple[()] | 
 |  | 
 |     def test_none_type_str(self): | 
 |         none_type = torch._C.NoneType.get() | 
 |         g = {'NoneType' : type(None)} | 
 |         python_type = eval(none_type.annotation_str, g) | 
 |         assert python_type is type(None) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_zip_enumerate_modulelist(self): | 
 |         class Sub(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Sub, self).__init__() | 
 |  | 
 |             def forward(self, thing): | 
 |                 return thing - 2 | 
 |  | 
 |         class Double(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Double, self).__init__() | 
 |  | 
 |             def forward(self, thing): | 
 |                 return thing * 2 | 
 |  | 
 |         # zipping over two | 
 |         class ZipModLists(torch.nn.Module): | 
 |             def __init__(self, mods, mods2): | 
 |                 super(ZipModLists, self).__init__() | 
 |                 self.mods = mods | 
 |                 self.mods2 = mods2 | 
 |  | 
 |             def forward(self, x): | 
 |                 iter = 0 | 
 |                 for mod1, mod2 in zip(self.mods, self.mods2): | 
 |                     x = mod2(mod1(x)) | 
 |                     iter += 1 | 
 |                 return x, iter | 
 |  | 
 |         class ZipWithValues(torch.nn.Module): | 
 |             __constants__ = ['tup_larger', 'tup_smaller'] | 
 |  | 
 |             def __init__(self, mods, mods2): | 
 |                 super(ZipWithValues, self).__init__() | 
 |                 self.mods = mods | 
 |                 self.mods2 = mods2 | 
 |                 self.tup_larger = list(range(len(mods2) + 1)) | 
 |                 self.tup_smaller = list(range(max(len(mods2) + 1, 1))) | 
 |  | 
 |             def forward(self, x): | 
 |                 iter = 0 | 
 |                 x2 = x | 
 |                 for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2): | 
 |                     x = mod2(mod1(x)) + val | 
 |                     iter += 1 | 
 |                 for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2): | 
 |                     x2 = mod2(mod1(x2)) + val | 
 |                     iter += 1 | 
 |                 return x, iter | 
 |  | 
 |         mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()]) | 
 |         for i in range(len(mods)): | 
 |             for j in range(len(mods)): | 
 |                 mod = ZipModLists(mods[i], mods[j]) | 
 |                 self.checkModule(mod, (torch.tensor(.5),)) | 
 |                 mod2 = ZipWithValues(mods[i], mods[j]) | 
 |                 self.checkModule(mod2, (torch.tensor(.5),)) | 
 |  | 
 |  | 
    def test_enumerate_modlist_range(self):
        """Enumerate a ModuleList, and reject zipping it with a runtime-sized range.

        Mod enumerates its ModuleList and is passed to checkModule. Mod2 and
        Mod3 zip the ModuleList with `range(int(x))` in either order; scripting
        them must raise an error matching
        "that does not have a statically determinable length".
        """
        class Double(torch.nn.Module):
            def forward(self, thing):
                return thing * 2

        class Mod(torch.nn.Module):
            def __init__(self):
                super(Mod, self).__init__()
                self.mods = nn.ModuleList([Double(), Double()])

            def forward(self, x):
                x2 = x
                iter = 0
                for val, mod in enumerate(self.mods):
                    x2 = mod(x2) * val
                    iter += 1
                return iter, x, x2

        self.checkModule(Mod(), (torch.tensor(.5),))

        # variable length, modulelist
        class Mod2(Mod):
            def forward(self, x):
                for val, mod in zip(range(int(x)), self.mods):
                    x = mod(x) * val
                return x

        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
            torch.jit.script(Mod2())

        # modulelist, variable length
        class Mod3(Mod):
            def forward(self, x):
                for val, mod in zip(self.mods, range(int(x))):
                    x = mod(x) * val
                return x

        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
            torch.jit.script(Mod3())
 |  | 
    def test_for_in_enumerate(self):
        """Cover `enumerate` in for loops.

        Checked forms: plain, positional start, keyword start and nested
        enumerate. Calling enumerate with zero or three arguments must fail to
        script with the matching "expected at least/at most" message.
        """
        def fn(x):
            # type: (List[int]) -> int
            sum = 0
            for (i, v) in enumerate(x):
                sum += i * v

            return sum

        self.checkScript(fn, ([1, 2, 3, 4, 5],))

        def fn_enumerate_start_arg(x):
            # type: (List[int]) -> int
            sum = 0
            for (i, v) in enumerate(x, 1):
                sum += i * v

            return sum

        self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],))

        def fn_enumerate_start_kwarg(x):
            # type: (List[int]) -> int
            sum = 0
            for (i, v) in enumerate(x, start=1):
                sum += i * v

            return sum

        self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],))

        def fn_nested_enumerate(x):
            # type: (List[int]) -> int
            sum = 0
            for (i, (j, v)) in enumerate(enumerate(x)):
                sum += i * j * v

            return sum

        self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],))

        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'):
            @torch.jit.script
            def enumerate_no_arg(x):
                # type: (List[int]) -> int
                sum = 0
                for _ in enumerate():
                    sum += 1

                return sum

        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'):
            @torch.jit.script
            def enumerate_too_many_args(x):
                # type: (List[int]) -> int
                sum = 0
                for _ in enumerate(x, x, x):
                    sum += 1

                return sum
 |  | 
    def test_list_comprehension_modulelist(self):
        """List comprehensions over a ModuleList, with and without an annotation.

        M annotates the comprehension as List[Tensor]; M2 leaves it
        unannotated. Both are tried with a two-element and an empty
        ModuleList. Annotating a comprehension as `int` must fail with
        "Expected an annotation of type List".
        """
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self, mod_list):
                super(M, self).__init__()
                self.module_list = mod_list

            def forward(self, x):
                out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list])
                return out

        mod = M(nn.ModuleList([Inner(), Inner()]))
        self.checkModule(mod, (torch.tensor(3),))

        mod = M(nn.ModuleList([]))
        torch.jit.script(mod)

        class M2(M):
            def __init__(self, mod_list):
                super(M2, self).__init__(mod_list)

            def forward(self, x):
                out = [mod(x) for mod in self.module_list]
                return out

        mod = M2(nn.ModuleList([Inner(), Inner()]))
        self.checkModule(mod, (torch.tensor(3),))

        mod = M2(nn.ModuleList([]))
        # defaults to List of Tensor for empty modulelist
        self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), [])

        def bad_type_annotation():
            out = torch.jit.annotate(int, [x for x in [1, 2, 3]])  # noqa: C416
            return out

        with self.assertRaisesRegex(Exception, "Expected an annotation"
                                    " of type List"):
            torch.jit.script(bad_type_annotation)
 |  | 
    def test_list_comprehension_variable_write(self):
        """Compare eager and scripted results when a comprehension reuses an outer variable name."""
        # i in comprehension doesn't write to function scope
        def foo():
            i = 1
            x = [i if i != 5 else 3 for i in range(7)]  # noqa: C416
            return i, x

        self.assertEqual(foo(), torch.jit.script(foo)())
 |  | 
 |     def test_for_in_zip(self): | 
 |         def fn(x, y): | 
 |             # type: (List[int], List[int]) -> int | 
 |             sum = 0 | 
 |             for (i, j) in zip(x, y): | 
 |                 sum += i * j | 
 |  | 
 |             return sum | 
 |  | 
 |         self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) | 
 |  | 
 |         def fn_multi_inputs(x, y, z): | 
 |             # type: (List[int], List[int], List[int]) -> int | 
 |             sum = 0 | 
 |             for (i, j, k) in zip(x, y, z): | 
 |                 sum += i * j * k | 
 |  | 
 |             return sum | 
 |  | 
 |         self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) | 
 |  | 
 |         def fn_nested_zip(x, y, z): | 
 |             # type: (List[int], List[int], List[int]) -> int | 
 |             sum = 0 | 
 |             for (i, (j, k)) in zip(x, zip(y, z)): | 
 |                 sum += i * j * k | 
 |  | 
 |             return sum | 
 |  | 
 |         self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'): | 
 |             @torch.jit.script | 
 |             def zip_no_arg(x): | 
 |                 # type: (List[int]) -> int | 
 |                 sum = 0 | 
 |                 for _ in zip(): | 
 |                     sum += 1 | 
 |  | 
 |                 return sum | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'): | 
 |             @torch.jit.script | 
 |             def fn_nested_zip_wrong_target_assign(x, y, z): | 
 |                 # type: (List[int], List[int], List[int]) -> int | 
 |                 sum = 0 | 
 |                 for (i, (j, k)) in zip(x, y, z): | 
 |                     sum += i * j * k | 
 |  | 
 |                 return sum | 
 |  | 
    def test_for_in_zip_enumerate(self):
        """Combine `zip`, `enumerate` and `range` in a single for loop, in both nesting orders."""
        def fn_zip_enumerate(x, y):
            # type: (List[int], List[int]) -> int
            sum = 0
            for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)):
                sum += i * j * v * k

            return sum

        self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5]))

        def fn_enumerate_zip(x, y):
            # type: (List[int], List[int]) -> int
            sum = 0
            for (i, (j, v)) in enumerate(zip(x, y)):
                sum += i * j * v

            return sum

        self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
 |  | 
 |     def test_for_in_tensors(self): | 
 |         def test_sizes(x): | 
 |             sumz = 0 | 
 |             for s in x: | 
 |                 sumz += 1 | 
 |             return sumz | 
 |         self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) | 
 |         self.checkScript(test_sizes, (torch.rand(777),)) | 
 |         self.checkScript(test_sizes, (torch.rand(0),)) | 
 |  | 
    def test_for_in_tensors_rank0(self):
        """Iterating over a 0-d tensor must raise a RuntimeError matching "of a 0-d tensor"."""
        with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
            @torch.jit.script
            def test_sizes(x):
                sumz = 0
                for s in x:
                    sumz += 1
                return sumz

            test_sizes(torch.tensor(1))
 |  | 
    def test_for_in_tensors_fail_scalar(self):
        """Iterating over a `float` argument must raise "'float' object is not iterable"."""
        with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
            @torch.jit.script
            def test_sizes(x):
                # type: (float) -> int
                sumz = 0
                for s in x:
                    sumz += 1
                return sumz

            test_sizes(0.0)
 |  | 
    def test_for_in_tensors_nested(self):
        """Iterate over a tensor and then over each of its slices in nested for loops."""
        def test_sizes(x):
            sumz = 0
            for n in x:
                for t in n:
                    sumz += 1
            return sumz

        self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
 |  | 
 |     # to avoid defining sum_list in multiple tests | 
    def get_sum_list_fn(self):
        """Return a fresh `sum_list` function that adds up a List[int] with a for loop."""
        def sum_list(a):
            # type: (List[int]) -> int
            sum = 0
            for i in a:
                sum += i

            return sum

        return sum_list
 |  | 
 |     def test_sum_list_diff_elms(self): | 
 |         self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) | 
 |  | 
 |     def test_sum_list_empty(self): | 
 |         self.checkScript(self.get_sum_list_fn(), ([],)) | 
 |  | 
 |     def test_sum_list_one(self): | 
 |         self.checkScript(self.get_sum_list_fn(), ([1],)) | 
 |  | 
    def test_sum_list_literal(self):
        """Iterate over a list literal written directly in the for statement."""

        def sum_list():
            # type: () -> int
            sum = 0
            for i in [1, 2, 3, 4, 5]:
                sum += i

            return sum

        self.checkScript(sum_list, ())
 |  | 
    def test_sum_list_wrong_type(self):
        """Iterating over an `int` argument must raise "'int' object is not iterable"."""

        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
            @torch.jit.script
            def sum_list(a):
                # type: (int) -> int
                sum = 0
                for i in a:  # noqa: T484
                    sum += i

                return sum

            sum_list(1)
 |  | 
 |     def test_list_iterables(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def list_iterables(x): | 
 |                 for i, j in [2, 3, 4], [5, 6, 7]: | 
 |                     x += i | 
 |                     x += j | 
 |                 return x | 
 |             ''') | 
 |  | 
 |     def test_for_in_string(self): | 
 |         def test_strings(x): | 
 |             # type: (str) -> str | 
 |             reverse = "" | 
 |             for c in x: | 
 |                 reverse = c + reverse | 
 |             return reverse | 
 |  | 
 |         self.checkScript(test_strings, ("hello",)) | 
 |         self.checkScript(test_strings, ("",)) | 
 |  | 
 |         def test_list_strings(x): | 
 |             # type: (List[str]) -> str | 
 |             result = "" | 
 |             for sub_str in x: | 
 |                 result += sub_str | 
 |             return result | 
 |  | 
 |         self.checkScript(test_list_strings, (["hello", "world"],)) | 
 |         self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) | 
 |  | 
 |     def test_for_in_dict(self): | 
 |         def test_dicts(x): | 
 |             # type: (Dict[str, int]) -> int | 
 |             sum = 0 | 
 |             for key in x: | 
 |                 sum += x[key] | 
 |             return sum | 
 |  | 
 |         self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) | 
 |  | 
 |         def test_dict_keys_values(x): | 
 |             # type: (Dict[str, int]) -> Tuple[str, int] | 
 |             key_str = "" | 
 |             sum = 0 | 
 |             for key in x.keys(): | 
 |                 key_str += key | 
 |             for val in x.values(): | 
 |                 sum += val | 
 |             return key_str, sum | 
 |  | 
 |         self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) | 
 |  | 
    def test_for_tuple_unpack(self):
        """Unpack loop targets from a list of lists and from a nested zip/enumerate."""
        def for_tuple_unpack(x, y):
            for i, j in [[3, 4], [5, 6], [7, 8]]:
                x += i
                y += j
            return x, y

        self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))

        def nested_tuple_unpack(x, y):
            # type: (List[int], List[int]) -> int
            sum = 0
            for i, (j, k), v in zip(x, enumerate(x), y):
                sum += i + j + k + v
            return sum

        self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
 |  | 
    def test_for_tuple_assign(self):
        """Iterate over tuples: one with mixed int/float elements, one holding tuples."""
        def test_simple_assign(x):
            # type: (Tuple[int, float]) -> float
            sum = 0.0
            for a in x:
                sum += float(a)
            return sum

        self.checkScript(test_simple_assign, ((1, 2.5),))

        def test_tuple_assign(x):
            # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
            sum = 0
            for a in x:
                sum += a[0]
                sum += a[1]
            return sum

        self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
 |  | 
 |     def test_single_starred_lhs(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' | 
 |                                                   ' of another non-starred expression'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def single_starred_lhs(x): | 
 |                 a = (x, x, x) | 
 |                 *b, = a | 
 |                 return b | 
 |             ''') | 
 |  | 
    def test_singleton_tuple_unpack(self):
        """Unpack a one-element tuple with the `b, = (a,)` form."""
        def foo(a):
            b, = (a,)
            return b + 1
        self.checkScript(foo, (torch.rand(3),))
 |  | 
    def test_tuple_assignments(self):
        """Cover tuple-pattern assignment targets.

        Checked forms: nested tuple targets, subscript targets, starred
        targets, and attribute targets on a scripted class. Augmented
        assignment to a tuple element must fail to script with
        "does not support augmented assign".
        """
        def var_tuple_assign(x, y):
            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
            (a, b), c = x, y
            return a + b + c

        tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
        self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))

        def nested_tuple_assign(x, y, z):
            # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
            a, (b, (c, d)), (e, f) = x, y, z
            return a + b + c + d + e + f

        self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))

        def subscript_tuple_assign(a, x, i):
            # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
            a[i], (x[i], b) = 1, (2, 3)
            return a[i] + 1, x + 5, b

        self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))

        def star_tuple_assign():
            # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
            a, (b, *c), *d = 1, (2, 3, 4), 5, 6
            return a, b, c, d

        self.checkScript(star_tuple_assign, ())

        def subscript_tuple_augmented_assign(a):
            # type: (Tuple[int, int]) -> Tuple[int, int]
            a[0] += 1
            return a

        with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
            scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)

        class AttrTupleAssignmentTestClass:
            def __init__(self, a: int, b: int):
                self.a = a
                self.b = b

            def set_ab(self, a: int, b: int):
                self.a, self.b = (a, b)

            def get(self) -> Tuple[int, int]:
                return (self.a, self.b)

        make_global(AttrTupleAssignmentTestClass)

        @torch.jit.script
        def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int):
            o.set_ab(a, b)
            return o

        o = AttrTupleAssignmentTestClass(1, 2)
        self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4))
 |  | 
    def test_multiple_assign(self):
        """Cover chained assignment.

        The scripted function contains a chain with tuple targets, a chain
        whose right-hand side is an in-place op, and a swap-style assignment
        that depends on evaluation order.
        """
        def test():
            a = b, c = d, f = (1, 1)

            # side effect
            ten = torch.tensor(1)
            ten1 = ten2 = ten.add_(1)

            # ordering
            x = 1
            y = 3
            x, y = y, x + y

            return a, b, c, d, f, ten, ten1, ten2, x, y

        self.checkScript(test, ())
 |  | 
 |     def test_multi_reduction(self): | 
 |         with self.assertRaisesRegex( | 
 |                 RuntimeError, | 
 |                 'augmented assignment can only have one LHS expression'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def multi_reduction(x): | 
 |                 a, b += x | 
 |                 return a, b | 
 |             ''') | 
 |  | 
    def test_invalid_call_arguments(self):
        """Calling `torch.unsqueeze` with ill-typed arguments must fail to script.

        Expects a RuntimeError matching "but instead found type ".
        """
        with self.assertRaisesRegex(RuntimeError, 'but instead found type '):
            @torch.jit.script
            def invalid_call_arguments(x):
                return torch.unsqueeze(3, 4, 5, 6, 7, 8)
 |  | 
 |     def test_invalid_lhs_assignment(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'unexpected expression'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def invalid_lhs_assignment(x): | 
 |                 x + 1 = x | 
 |                 return x | 
 |             ''') | 
 |  | 
 |     def test_multi_starred_expr_lhs(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def multi_starred_expr_lhs(): | 
 |                 a, *b, *c = [1, 2, 3, 4, 5, 6] | 
 |                 return a | 
 |             ''') | 
 |  | 
 |     def test_pack_tuple_into_non_var(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def pack_tuple_into_non_var(x): | 
 |                 a, *1 = (3, 4, 5) | 
 |                 return x | 
 |             ''') | 
 |  | 
 |     def test_print_kwargs(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def print_kwargs(x): | 
 |                 print(x, flush=True) | 
 |                 return x | 
 |             ''') | 
 |  | 
    def test_builtin_use_as_value(self):
        """Returning an uncalled builtin method (`x.unsqueeze`) must fail to script."""
        with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
            @torch.jit.script
            def builtin_use_as_value(x):
                return x.unsqueeze
 |  | 
    def test_wrong_use_as_tuple(self):
        """Unpacking a Python function object as a tuple must fail to script.

        Expects a RuntimeError matching "cannot be used as a tuple".
        """
        with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
            def test_fn():
                return 3

            @torch.jit.script
            def wrong_use_as_tuple(self):
                a, b = test_fn
                return a
 |  | 
    def test_wrong_attr_lookup(self):
        """Looking up an attribute on a builtin method must fail to script.

        Expects a RuntimeError matching
        "attribute lookup is not defined on builtin".
        """
        with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
            @torch.jit.script
            def wrong_attr_lookup(self, x):
                a = x.unsqueeze.myattr
                return a
 |  | 
    def test_wrong_use_as_callable(self):
        """Calling the function argument `x` as if it were callable must fail with "cannot call a value"."""
        with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
            @torch.jit.script
            def wrong_use_as_callable(x):
                return x(3, 4, 5)
 |  | 
    def test_python_val_doesnt_have_attr(self):
        """Accessing a missing attribute on a Python module must fail to script.

        Uses `shutil.abcd`; expects a RuntimeError matching
        "object has no attribute abcd".
        """
        with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):

            @torch.jit.script
            def python_val_doesnt_have_attr():
                # this has to be a module otherwise attr lookup would not be
                # allowed in the first place
                return shutil.abcd
 |  | 
    def test_wrong_module_attr_lookup(self):
        """Returning a Python class (`io.BytesIO`) from a scripted function must fail.

        Expects a RuntimeError matching
        "python value of type 'type' cannot be used as a value".
        """
        with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'):
            import io

            @torch.jit.script
            def wrong_module_attr_lookup():
                return io.BytesIO
 |  | 
 |     def test_wrong_method_call_inputs(self): | 
 |         # `forward` calls `self.foo(x)` although `foo` requires both `x` and | 
 |         # `y`. The "Argument y not provided" error is expected to surface no | 
 |         # later than the `SomeModule()` instantiation inside the `with` block. | 
 |         with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'): | 
 |             class SomeModule(torch.jit.ScriptModule): | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def foo(self, x, y): | 
 |                     return x | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x, y): | 
 |                     return self.foo(x) | 
 |             SomeModule() | 
 |  | 
 |     def test_single_starred_expr_for_loop(self): | 
 |         with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'): | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def test(): | 
 |                 x = 0 | 
 |                 for *a in [1, 2, 3]: | 
 |                     x = x + 1 | 
 |                 return x | 
 |             ''') | 
 |  | 
 |     def test_call_ge(self): | 
 |         # `foo` is traced with a single example input, so calling it from a | 
 |         # scripted function with three arguments is expected to fail with an | 
 |         # argument-count error. | 
 |         with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'): | 
 |             @_trace(torch.zeros(1, 2, 3)) | 
 |             def foo(x): | 
 |                 return x | 
 |  | 
 |             @torch.jit.script | 
 |             def test_fn(): | 
 |                 return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3)) | 
 |  | 
 |     def test_wrong_return_type(self): | 
 |         # `somefunc` is an ignored Python function whose type comment declares | 
 |         # a nested tuple but which actually returns a flat 2-tuple of tensors. | 
 |         # The mismatch is expected to raise by the time the scripted caller is | 
 |         # invoked (both compilation and the call sit inside the `with`). | 
 |         with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'): | 
 |             @torch.jit.ignore | 
 |             def somefunc(): | 
 |                 # type: () -> Tuple[Tuple[Tensor, Tensor]] | 
 |                 return torch.zeros(3, 4), torch.zeros(4, 5)  # noqa: T484 | 
 |  | 
 |             @torch.jit.script | 
 |             def wrong_return_type(): | 
 |                 return somefunc() | 
 |             wrong_return_type() | 
 |  | 
 |     # Tests for calling between different front-end modes | 
 |     def test_call_python_fn_from_tracing_fn(self): | 
 |         def python_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         @_trace(torch.rand(3, 4)) | 
 |         def traced_fn(x): | 
 |             return python_fn(x) + 1 | 
 |  | 
 |         # The neg op in the python function should be properly inlined to the | 
 |         # graph | 
 |         FileCheck().check("aten::neg").run(str(traced_fn.graph)) | 
 |  | 
 |     def test_call_python_mod_from_tracing_fn(self): | 
 |         class PythonMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(PythonMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         pm = PythonMod() | 
 |  | 
 |         @_trace(torch.rand(3, 4)) | 
 |         def traced_fn(x): | 
 |             return pm(x) + 1.0 | 
 |  | 
 |         # Note: the parameter self.param from the Python module is inlined | 
 |         # into the graph | 
 |         self.assertTrue(len(list(traced_fn.graph.inputs())) == 1) | 
 |         FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph)) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_call_traced_fn_from_tracing_fn(self): | 
 |         # Traces a function that calls another traced function and checks | 
 |         # that the outer graph text contains "traced_fn", then a | 
 |         # prim::CallFunction node, then aten::add. | 
 |         @_trace(torch.rand(3, 4)) | 
 |         def traced_fn1(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         @_trace(torch.rand(3, 4)) | 
 |         def traced_fn(x): | 
 |             return traced_fn1(x) + 1 | 
 |  | 
 |         FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \ | 
 |             .run(str(traced_fn.graph)) | 
 |  | 
 |     @unittest.skip("error in first class mode") | 
 |     def test_call_traced_mod_from_tracing_fn(self): | 
 |         # Currently skipped. When enabled, it expects that tracing a function | 
 |         # which calls an already-traced module raises | 
 |         # "must be registered as submodules". | 
 |         class TracedModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TracedModule, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): | 
 |             @_trace(torch.rand(3, 4)) | 
 |             def traced_fn(x): | 
 |                 return tm(x) + 1.0 | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_call_script_fn_from_tracing_fn(self): | 
 |         # Traces a function that calls a scripted function and checks that | 
 |         # the traced graph text contains a prim::CallFunction node followed | 
 |         # by aten::add. | 
 |         @torch.jit.script | 
 |         def script_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         @_trace(torch.rand(3, 4)) | 
 |         def traced_fn(x): | 
 |             return script_fn(x) + 1 | 
 |  | 
 |         FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph)) | 
 |  | 
 |     @unittest.skip("error in first class mode") | 
 |     def test_call_script_mod_from_tracing_fn(self): | 
 |         # Currently skipped. When enabled, it expects that tracing a function | 
 |         # which calls a ScriptModule instance raises | 
 |         # "must be registered as submodules". | 
 |         with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): | 
 |             class ScriptMod(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(ScriptMod, self).__init__() | 
 |                     self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False) | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     for _i in range(4): | 
 |                         x += self.param | 
 |                     return x | 
 |  | 
 |             sm = ScriptMod() | 
 |  | 
 |             @_trace(torch.rand(3, 4)) | 
 |             def traced_fn(x): | 
 |                 return sm(x) + 1.0 | 
 |  | 
 |  | 
 |     def test_call_python_fn_from_traced_module(self): | 
 |         def python_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         class TracedModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TracedModule, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.mm(python_fn(x), self.param) | 
 |  | 
 |         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) | 
 |  | 
 |         # Note: parameter self.param from the traced module should appear as | 
 |         # an input to the graph and the neg op from the Python function should | 
 |         # be properly inlined | 
 |         self.assertTrue(len(list(tm.graph.inputs())) == 2) | 
 |         FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph)) | 
 |  | 
 |     def test_call_python_mod_from_traced_module(self): | 
 |         # Traces a module that owns an eager submodule. The outer graph must | 
 |         # contain no embedded tensor constant, an aten::mm, a CallMethod to | 
 |         # "forward", and an aten::add; the submodule's own graph must contain | 
 |         # aten::mm. | 
 |         class PythonModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(PythonModule, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(5, 7)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         class TracedModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TracedModule, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 5)) | 
 |                 self.mod = PythonModule() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.mod(torch.mm(x, self.param)) + 1.0 | 
 |  | 
 |         tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) | 
 |  | 
 |         FileCheck().check_not("value=<Tensor>").check("aten::mm")\ | 
 |             .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \ | 
 |             .run(str(tm.graph)) | 
 |         FileCheck().check("aten::mm").run(str(tm.mod.graph)) | 
 |  | 
 |     def test_op_dtype(self): | 
 |  | 
 |         def check_equal_and_dtype(a, b): | 
 |             self.assertEqual(a, b) | 
 |             self.assertEqual(a.dtype, b.dtype) | 
 |  | 
 |         def fn(): | 
 |             a = torch.arange(10) | 
 |             b = torch.arange(10, dtype=torch.float) | 
 |             c = torch.arange(1, 10, 2) | 
 |             d = torch.arange(1, 10, 2, dtype=torch.float) | 
 |             e = torch.arange(1, 10., 2) | 
 |             f = torch.arange(1, 10., 2, dtype=torch.float) | 
 |             return a, b, c, d, e, f | 
 |  | 
 |         scripted_fn = torch.jit.script(fn) | 
 |         eager_out = fn() | 
 |         script_out = scripted_fn() | 
 |         for a, b in zip(eager_out, script_out): | 
 |             check_equal_and_dtype(a, b) | 
 |  | 
 |     def test_floor_div(self): | 
 |         @torch.jit.script | 
 |         def foo(a, b): | 
 |             # type: (int, int) -> int | 
 |             return a // b | 
 |         for i in range(-8, 8): | 
 |             for j in range(-8, 8): | 
 |                 if j != 0: | 
 |                     self.assertEqual(foo(i, j), i // j) | 
 |  | 
 |     def test_floordiv(self): | 
 |         funcs_template = dedent(''' | 
 |         def fn(): | 
 |             ten = {a_construct} | 
 |             ten_or_scalar = {b_construct} | 
 |             return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar) | 
 |         ''') | 
 |  | 
 |         lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"] | 
 |         rhs = ["1.5", "2", "4", "1.1"] + lhs | 
 |         for tensor in lhs: | 
 |             for tensor_or_scalar in rhs: | 
 |                 funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar) | 
 |                 scope = {} | 
 |                 execWrapper(funcs_str, globals(), scope) | 
 |                 cu = torch.jit.CompilationUnit(funcs_str) | 
 |                 f_script = cu.fn | 
 |                 f = scope['fn'] | 
 |                 self.assertEqual(f_script(), f()) | 
 |  | 
 |     def test_call_python_fn_from_script_fn(self): | 
 |         # A scripted function calls a @torch.jit.ignore'd Python function; | 
 |         # the result must be correct (neg(1) + 1 == 0) and the callee's name | 
 |         # must appear in the graph text. | 
 |         @torch.jit.ignore | 
 |         def python_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         @torch.jit.script | 
 |         def script_fn(x): | 
 |             return python_fn(x) + 1 | 
 |  | 
 |         # Note: the call to python_fn appears as `^python_fn()` and is called | 
 |         # as a PythonOp in the interpreter | 
 |         a = torch.tensor(1) | 
 |         self.assertEqual(script_fn(a), torch.tensor(0)) | 
 |         FileCheck().check("python_fn").run(str(script_fn.graph)) | 
 |  | 
 |     def test_call_python_mod_from_script_fn(self): | 
 |         # A scripted function calls an eager nn.Module instance captured from | 
 |         # the enclosing scope; the graph text must contain "python_value" | 
 |         # followed by aten::add. | 
 |         class PythonModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(PythonModule, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(5, 7)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         pm = PythonModule() | 
 |  | 
 |         @torch.jit.script | 
 |         def script_fn(x): | 
 |             return pm(x) + 1 | 
 |  | 
 |         # Note: call to pm(x) appears as ^<python_value>() in the scripted | 
 |         # graph. Parameters are NOT inlined. | 
 |         FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph)) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_call_script_fn_from_script_fn(self): | 
 |         # A scripted function calling another scripted function must show a | 
 |         # prim::CallFunction node in its graph text. | 
 |         @torch.jit.script | 
 |         def script_fn1(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         @torch.jit.script | 
 |         def script_fn(x): | 
 |             return script_fn1(x) + 1 | 
 |  | 
 |         FileCheck().check("prim::CallFunction").run(str(script_fn.graph)) | 
 |  | 
 |     def test_call_script_mod_from_script_fn(self): | 
 |         # A free scripted function that calls a ScriptModule instance captured | 
 |         # from the enclosing scope is expected to be rejected, because the | 
 |         # module is not a submodule of the caller. | 
 |         with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): | 
 |             class ScriptMod(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(ScriptMod, self).__init__() | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     return torch.mm(x, torch.zeros([4, 3])) | 
 |  | 
 |             sm = ScriptMod() | 
 |  | 
 |             @torch.jit.script | 
 |             def script_fn(x): | 
 |                 return sm(x) + 1 | 
 |  | 
 |     def test_call_python_fn_from_script_module(self): | 
 |         # A script method calls a @torch.jit.ignore'd Python function; the | 
 |         # forward graph text must contain aten::mm followed by "python_fn". | 
 |         @torch.jit.ignore | 
 |         def python_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return python_fn(torch.mm(x, self.param)) | 
 |  | 
 |         sm = ScriptMod() | 
 |         FileCheck().check("aten::mm").check("python_fn") \ | 
 |             .run(str(sm.forward.graph)) | 
 |  | 
 |     def test_call_python_mod_from_script_module(self): | 
 |         # A ScriptModule owns an eager submodule whose forward is marked | 
 |         # @torch.jit.ignore; the graph text must contain aten::mm followed by | 
 |         # "forward". | 
 |         class PythonMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(PythonMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(3, 5)) | 
 |  | 
 |             @torch.jit.ignore | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3)) | 
 |                 self.pm = PythonMod() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.pm(torch.mm(x, self.param)) | 
 |  | 
 |         sm = ScriptMod() | 
 |         # Note: the call into PythonMod appears as ^forward(). Parameters | 
 |         # are NOT inlined | 
 |         FileCheck().check("aten::mm").check("forward").run(str(sm.graph)) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_call_script_fn_from_script_module(self): | 
 |         # A script method calls a free scripted function; the forward graph | 
 |         # text must contain aten::mm followed by a prim::CallFunction node. | 
 |         @torch.jit.script | 
 |         def script_fn(x): | 
 |             return torch.neg(x) | 
 |  | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return script_fn(torch.mm(x, self.param)) | 
 |  | 
 |         sm = ScriptMod() | 
 |         graph = (sm.forward.graph) | 
 |         FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph)) | 
 |  | 
 |     @_tmp_donotuse_dont_inline_everything | 
 |     def test_call_script_mod_from_script_module(self): | 
 |         # A ScriptModule calls a ScriptModule submodule from its forward; the | 
 |         # outer graph text is checked with the FileCheck pattern below. | 
 |         class ScriptMod1(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod1, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(3, 5)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return torch.mm(x, self.param) | 
 |  | 
 |         class ScriptMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(ScriptMod, self).__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(4, 3)) | 
 |                 self.tm = ScriptMod1() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.tm(torch.mm(x, self.param)) | 
 |  | 
 |         sm = ScriptMod() | 
 |         # The pattern below expects three '%' occurrences, then a ':', then a | 
 |         # single "mm", then a prim::CallMethod node. | 
 |         # NOTE(review): an earlier note here said the parameters of both | 
 |         # modules are flattened into the graph inputs and that ScriptMod1's mm | 
 |         # is inlined ("two mms in body"). That does not match this pattern, | 
 |         # which looks for one mm plus a CallMethod -- presumably because the | 
 |         # decorator on this test disables inlining; confirm. | 
 |         FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph)) | 
 |  | 
 |     def test_module_with_params_called_fails(self): | 
 |         # Same expectation as calling any ScriptModule from a free scripted | 
 |         # function, here with a module that owns a parameter: the call is | 
 |         # expected to be rejected because the module is not a submodule of | 
 |         # the caller. | 
 |         with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): | 
 |             class ScriptMod(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(ScriptMod, self).__init__() | 
 |                     self.param = torch.nn.Parameter(torch.rand(3, 3)) | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     return torch.mm(x, self.param) | 
 |  | 
 |             sm = ScriptMod() | 
 |  | 
 |             @torch.jit.script | 
 |             def some_func(x): | 
 |                 return sm(x) | 
 |  | 
 |     def test_tuple_index_to_list(self): | 
 |         # Indexing a tuple with a non-constant index works when all elements | 
 |         # share one type ((0, 1)), and is expected to fail to compile when | 
 |         # they do not ((0, 1.1)) because the output type cannot be resolved. | 
 |         def test_non_constant_input(a): | 
 |             # type: (bool) -> int | 
 |             if a: | 
 |                 b = 1 | 
 |             else: | 
 |                 b = 0 | 
 |             c = (0, 1) | 
 |             return c[b] | 
 |  | 
 |         self.checkScript(test_non_constant_input, (True,)) | 
 |         self.checkScript(test_non_constant_input, (False,)) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"): | 
 |             @torch.jit.script | 
 |             def test_non_constant_input(a): | 
 |                 # type: (bool) -> None | 
 |                 if a: | 
 |                     b = 1 | 
 |                 else: | 
 |                     b = 0 | 
 |                 c = (0, 1.1) | 
 |                 print(c[b]) | 
 |  | 
 |     def test_tuple_indexing(self): | 
 |         def tuple_index(a): | 
 |             if bool(a): | 
 |                 b = (1, 2) | 
 |             else: | 
 |                 b = (0, 2) | 
 |             return b[-2], b[1] | 
 |  | 
 |         self.checkScript(tuple_index, (torch.tensor([0]),)) | 
 |         self.checkScript(tuple_index, (torch.tensor([1]),)) | 
 |         self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True) | 
 |         tuple_comp = torch.jit.script(tuple_index) | 
 |         FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph)) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "index must be an integer"): | 
 |             @torch.jit.script | 
 |             def test_indexing_float(): | 
 |                 c = (1, 2) | 
 |                 return c[0.1] | 
 |  | 
 |         def test_indexing_out_of_bounds_pos(): | 
 |             c = (1, 2) | 
 |             return c[2] | 
 |  | 
 |         self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, | 
 |                                     "out of range") | 
 |  | 
 |         def test_indexing_out_of_bounds_neg(): | 
 |             c = (1, 2) | 
 |             return c[-3] | 
 |  | 
 |         self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, | 
 |                                     "out of range") | 
 |  | 
 |         def negative_index(): | 
 |             tup = (1, 2, 3, 4) | 
 |             return tup[-1] | 
 |  | 
 |         self.checkScript(negative_index, []) | 
 |  | 
 |         def really_negative_index(): | 
 |             tup = (1, 2, 3, 4) | 
 |             return tup[-100] | 
 |  | 
 |         self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range") | 
 |  | 
 |         def negative_slice(): | 
 |             tup = (1, 2, 3, 4) | 
 |             return tup[-3:4] | 
 |  | 
 |         self.checkScript(negative_slice, []) | 
 |  | 
 |         def really_slice_out_of_bounds(): | 
 |             tup = (1, 2, 3, 4) | 
 |             return tup[-300:4000] | 
 |  | 
 |         self.checkScript(really_slice_out_of_bounds, []) | 
 |  | 
 |     def test_namedtuple_attr(self): | 
 |         # The `.indices` field of the result of `max(dim=...)` must be usable | 
 |         # in script, while an unknown field name and `__doc__` on a plain | 
 |         # tuple are both expected to fail to compile. | 
 |         def f(x): | 
 |             return x.max(dim=1).indices + torch.max(x, dim=1).indices | 
 |  | 
 |         self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): | 
 |             @torch.jit.script | 
 |             def g1(x): | 
 |                 return x.max(dim=1).unknown_symbol | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): | 
 |             @torch.jit.script | 
 |             def g2(x): | 
 |                 print((x, x, x).__doc__) | 
 |                 return x | 
 |  | 
 |     def test_tuple_len(self): | 
 |         # len() of a heterogeneous tuple literal must be 3 in script, and a | 
 |         # slice starting at the tuple's end (c[2:10] on a 2-tuple) must | 
 |         # yield the empty tuple. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             return len((1, "str", None)) | 
 |  | 
 |         self.assertEqual(foo(), 3) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_indexing_end_out_of_bounds(): | 
 |             c = (1, 2) | 
 |             return c[2:10] | 
 |  | 
 |         self.assertEqual(test_indexing_end_out_of_bounds(), ()) | 
 |  | 
 |     def test_lower_nested_tuples(self): | 
 |         # After constant propagation the nested tuple literal must appear as | 
 |         # a prim::Constant with no TupleConstruct left; the lower_all_tuples | 
 |         # pass is then run on the same graph. | 
 |         @torch.jit.script | 
 |         def test(): | 
 |             return ((1, 2), 3) | 
 |  | 
 |         self.run_pass('constant_propagation', test.graph) | 
 |         FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph) | 
 |         # fails if a tuple can't be lowered | 
 |         self.run_pass('lower_all_tuples', test.graph) | 
 |  | 
 |     def test_unwrap_optional_builtin(self): | 
 |         # Exercises torch.jit._unwrap_optional: a non-None value round-trips; | 
 |         # None raises AssertionError in eager mode and RuntimeError in | 
 |         # script; unwrapping a non-optional literal compiles; and unwrapping | 
 |         # a bare None literal is expected to fail to compile. | 
 |         def test(x): | 
 |             # type: (Optional[int]) -> int | 
 |             x = torch.jit._unwrap_optional(x) | 
 |             x = x + x  # noqa: T484 | 
 |             return x | 
 |  | 
 |         self.checkScript(test, (3,)) | 
 |  | 
 |         with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"): | 
 |             test(None) | 
 |  | 
 |         test_script = torch.jit.script(test) | 
 |         with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"): | 
 |             test_script(None) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_test(): | 
 |             return torch.jit._unwrap_optional(1) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"): | 
 |             @torch.jit.script | 
 |             def test_no_type(): | 
 |                 # type: () -> int | 
 |                 return torch.jit._unwrap_optional(None) | 
 |  | 
 |     def test_indexing_error(self): | 
 |         # Subscripting an int inside a scripted function is expected to fail | 
 |         # with "'int' object is not subscriptable". | 
 |         with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"): | 
 |             @torch.jit.script | 
 |             def test_wrong_type(): | 
 |                 a = 8 | 
 |                 return a[0] | 
 |  | 
 |     def test_unsupported_builtin_error(self): | 
 |         # Calling math.hypot from scripted code is expected to produce an | 
 |         # error that names the builtin. `math` is presumably imported at the | 
 |         # top of this file. | 
 |         with self.assertRaisesRegex(RuntimeError, | 
 |                                     "Python builtin <built-in function hypot> is currently"): | 
 |             @torch.jit.script | 
 |             def test_unsupported(a): | 
 |                 return math.hypot(a, 2.0) | 
 |  | 
 |     def test_annotated_script_fn(self): | 
 |         # The schema string of a scripted function with nested-tuple type | 
 |         # comments is compared against the stored expected output. | 
 |         @torch.jit.script | 
 |         def foo(x, y, z): | 
 |             # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor | 
 |             return x | 
 |  | 
 |         self.assertExpected(str(foo.schema)) | 
 |  | 
 |     def test_annotated_script_method(self): | 
 |         # The schema string of a type-commented script method is compared | 
 |         # against the stored expected output via assertExpectedStripMangled. | 
 |         class SM(torch.jit.ScriptModule): | 
 |             @torch.jit.script_method | 
 |             def forward(self, x, y): | 
 |                 # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor] | 
 |                 return y, y, y | 
 |  | 
 |         sm = SM() | 
 |  | 
 |         self.assertExpectedStripMangled(str(sm.forward.schema)) | 
 |  | 
 |     def test_annotated_script_fn_return_mismatch(self): | 
 |         # The type comment declares a nested-tuple return while the body | 
 |         # returns a flat pair; scripting is expected to fail. | 
 |         with self.assertRaisesRegex(RuntimeError, "but is actually of type"): | 
 |             @torch.jit.script | 
 |             def return_tup(x): | 
 |                 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] | 
 |                 return x, x  # noqa: T484 | 
 |  | 
 |     def test_annotated_script_fn_arg_mismatch(self): | 
 |         # The argument is annotated as a tuple of tensors but is used as | 
 |         # `x + 1`; scripting is expected to fail with an invalid-arguments | 
 |         # error. | 
 |         with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"): | 
 |             @torch.jit.script | 
 |             def tuple_arg(x): | 
 |                 # type: (Tuple[Tensor, Tensor]) -> Tensor | 
 |                 return x + 1  # noqa: T484 | 
 |  | 
 |     def test_script_non_tensor_args_outputs(self): | 
 |         # A scripted function taking a float argument and returning a float: | 
 |         # a 2x2 tensor of ones plus 1, summed, must come back as the Python | 
 |         # float 8.0. | 
 |         @torch.jit.script | 
 |         def fn(x, y): | 
 |             # type: (Tensor, float) -> float | 
 |             return float((x + y).sum()) | 
 |  | 
 |         x = torch.ones(2, 2) | 
 |         z = fn(x, 1) | 
 |         self.assertIsInstance(z, float) | 
 |         self.assertEqual(z, 8.) | 
 |  | 
 |     @unittest.skip('https://github.com/pytorch/pytorch/issues/9595') | 
 |     def test_inline_and_run_annotated_script_fn(self): | 
 |         # Currently skipped (see the linked issue). When enabled, it calls a | 
 |         # tuple-annotated scripted function from another scripted function | 
 |         # and expects the input tensor back unchanged. | 
 |         @torch.jit.script | 
 |         def to_inline(x, y): | 
 |             # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor | 
 |             return y | 
 |  | 
 |         @torch.jit.script | 
 |         def some_func(x): | 
 |             return to_inline((x, x), x) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(some_func(x), x) | 
 |  | 
 |     def test_file_format_serialization(self): | 
 |         filename = tempfile.mktemp() | 
 |         writer = torch._C.PyTorchFileWriter(filename) | 
 |         buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]] | 
 |         offsets = [] | 
 |         for i, buf in enumerate(buffers): | 
 |             writer.write_record(str(i), buf, len(buf)) | 
 |             offsets.append(i) | 
 |         serialized_offsets = pickle.dumps(offsets) | 
 |         writer.write_record("meta", serialized_offsets, len(serialized_offsets)) | 
 |         writer.write_end_of_file() | 
 |  | 
 |         reader = torch._C.PyTorchFileReader(filename) | 
 |         serialized_offsets_read = reader.get_record("meta") | 
 |         parsed_serialized_offsets = pickle.loads(serialized_offsets) | 
 |  | 
 |         for i, offset in enumerate(parsed_serialized_offsets): | 
 |             data = reader.get_record(str(offset)) | 
 |             assert(data == buffers[i]) | 
 |  | 
 |     # for each type, the input type annotation and corresponding return type annotation | 
 |     def type_input_return_pairs(self): | 
 |         return [ | 
 |             ('Tensor', 'Tensor'), | 
 |             ('torch.Tensor', 'Tensor'), | 
 |             ('str', 'str'), | 
 |             ('int', 'int'), | 
 |             ('bool', 'bool'), | 
 |             ('BroadcastingList3[float]', 'List[float]'), | 
 |             ('BroadcastingList2[int]', 'List[int]'), | 
 |             ('List[int]', 'List[int]'), | 
 |             ('Optional[int]', 'Optional[int]'), | 
 |         ] | 
 |  | 
 |     # replacing code input & return type pair | 
 |     def format_code(self, code, pair): | 
 |         return code.format(input=pair[0], output=pair[1]) | 
 |  | 
 |     # ***** Type annotation tests **** | 
 |     # Test combinations of: | 
 |     # {String frontend, Python AST Frontend} | 
 |     # {Python 3-style type annotations, MyPy-style type comments} | 
 |     # {Script method, Script function} | 
 |  | 
 |     #  String frontend , Python 3-style type annotations , Script function | 
 |     def test_annot_string_py3_fn(self): | 
 |         code = ''' | 
 |             def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: | 
 |                 return x, x | 
 |         ''' | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             cu = torch.jit.CompilationUnit(self.format_code(code, pair)) | 
 |             test_str.append(str(cu.foo.schema)) | 
 |         self.assertExpected("\n".join(test_str) + "\n") | 
 |  | 
 |     #  String frontend , Python 3-style type annotations , Script method | 
 |     def test_annot_string_py3_method(self): | 
 |         class TestModule(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(TestModule, self).__init__() | 
 |  | 
 |         code = ''' | 
 |             def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: | 
 |                 return x, x | 
 |         ''' | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             # clear the class registry as we will be defining foo multiple times | 
 |             jit_utils.clear_class_registry() | 
 |             tm = TestModule() | 
 |             tm.define(self.format_code(code, pair)) | 
 |             test_str.append(str(tm.foo.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     #  String frontend , MyPy-style type comments , Script function | 
 |     def test_annot_string_mypy_fn(self): | 
 |         code = ''' | 
 |             def foo(x, y): | 
 |                 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] | 
 |                 return x, x | 
 |         ''' | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             cu = torch.jit.CompilationUnit(self.format_code(code, pair)) | 
 |             test_str.append(str(cu.foo.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     #  String frontend , MyPy-style type comments , Script method | 
 |     def test_annot_string_mypy_method(self): | 
 |         class TestModule(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(TestModule, self).__init__() | 
 |  | 
 |         code = ''' | 
 |         def foo(self, x, y): | 
 |             # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] | 
 |             return x, x | 
 |         ''' | 
 |  | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             # clear the class registry as we will be defining foo multiple times | 
 |             jit_utils.clear_class_registry() | 
 |             tm = TestModule() | 
 |             tm.define(self.format_code(code, pair)) | 
 |             test_str.append(str(tm.foo.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     #  Python AST Frontend , Python 3-style type annotations , Script function | 
 |     def test_annot_ast_py3_fn(self): | 
 |         code = dedent(''' | 
 |             from typing import Tuple, List, Optional | 
 |             from torch import Tensor | 
 |             from torch.jit.annotations import BroadcastingList2, BroadcastingList3 | 
 |             import torch | 
 |             @torch.jit.script | 
 |             def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: | 
 |                 return x, x | 
 |         ''') | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') | 
 |             test_str.append(str(fn.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     def test_multiline_annot_ast_py3_fn(self): | 
 |         code = dedent(''' | 
 |             from typing import Tuple, List, Optional | 
 |             from torch import Tensor | 
 |             from torch.jit.annotations import BroadcastingList2, BroadcastingList3 | 
 |             import torch | 
 |             @torch.jit.script | 
 |             def foo(x,  # type: {input} | 
 |                     y   # type: Tuple[Tensor, Tensor] | 
 |                     ): | 
 |                 # type: (...) -> Tuple[{output}, {output}] | 
 |                 return x, x | 
 |         ''') | 
 |         test_str = [] | 
 |  | 
 |         for pair in self.type_input_return_pairs(): | 
 |             fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') | 
 |             args = fn.schema.arguments | 
 |             returns = fn.schema.returns | 
 |             self.assertEqual(str(args[0].type), pair[1]) | 
 |             self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]") | 
 |             self.assertEqual(str(returns[0].type), "Tuple[{}, {}]".format(pair[1], pair[1])) | 
 |  | 
 |     def test_bad_multiline_annotations(self): | 
 |         # Each case scripts a function whose multi-line type comments are | 
 |         # malformed and expects a RuntimeError whose message matches the | 
 |         # given pattern. | 
 |  | 
 |         # Case 1: per-argument comments combined with a full signature | 
 |         # comment plus an extra, unparsable type comment line. | 
 |         with self.assertRaisesRegex(RuntimeError, "Return type line"): | 
 |             @torch.jit.script | 
 |             def bad_type_line(a,  # type: Tensor | 
 |                               b,  # type: Tensor | 
 |                               c   # type: Tensor | 
 |                               ): | 
 |                 # type: (int, int, int) -> Tensor | 
 |                 # type: bad type line  # noqa: F723 | 
 |  | 
 |                 return a + b + c | 
 |  | 
 |         # Case 2: per-argument comments (with `b` left unannotated) combined | 
 |         # with a signature comment that spells out the argument types | 
 |         # instead of using "(...)". | 
 |         with self.assertRaisesRegex(RuntimeError, "Return type line"): | 
 |             @torch.jit.script | 
 |             def bad_return_line(a,  # type: Tensor | 
 |                                 b, | 
 |                                 c   # type: Tensor | 
 |                                 ): | 
 |                 # type: (int, int, int) -> Tensor | 
 |                 return a + b + c | 
 |  | 
 |         # Case 3: the return line uses "(...)" but `b` has no per-argument | 
 |         # comment, so there are fewer annotations than arguments. | 
 |         # TODO: this should be supported but is difficult to parse | 
 |         with self.assertRaisesRegex(RuntimeError, "Number of type annotations"): | 
 |             @torch.jit.script | 
 |             def missing_type(a,  # type: Tensor | 
 |                              b, | 
 |                              c   # type: Tensor | 
 |                              ): | 
 |                 # type: (...) -> Tensor | 
 |                 return a + b + c | 
 |  | 
 |     #  Python AST Frontend , Python 3-style type annotations , Script method | 
 |     def test_annot_ast_py3_method(self): | 
 |         code = dedent(''' | 
 |             from typing import Tuple, List, Optional | 
 |             from torch import Tensor | 
 |             from torch.jit.annotations import BroadcastingList2, \\ | 
 |                 BroadcastingList3 | 
 |             import torch | 
 |             class FooModule(torch.jit.ScriptModule): | 
 |                 @torch.jit.script_method | 
 |                 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: | 
 |                     return x, x | 
 |             instance = FooModule() | 
 |         ''') | 
 |  | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') | 
 |             test_str.append(str(fn.foo.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     #  Python AST Frontend , MyPy-style type comments , Script function | 
 |     def test_annot_ast_mypy_fn(self): | 
 |         code = dedent(''' | 
 |             import torch | 
 |             @torch.jit.script | 
 |             def foo(x, y): | 
 |                 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] | 
 |                 return x, x | 
 |         ''') | 
 |  | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') | 
 |             test_str.append(str(fn.schema)) | 
 |         self.assertExpected("\n".join(test_str) + "\n") | 
 |  | 
 |     #  Python AST Frontend , MyPy-style type comments , Script method | 
 |     def test_annot_ast_mypy_method(self): | 
 |         code = dedent(''' | 
 |             import torch | 
 |             class FooModule(torch.jit.ScriptModule): | 
 |                 @torch.jit.script_method | 
 |                 def foo(self, x, y): | 
 |                     # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] | 
 |                     return x, x | 
 |             instance = FooModule() | 
 |         ''') | 
 |  | 
 |         test_str = [] | 
 |         for pair in self.type_input_return_pairs(): | 
 |             fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') | 
 |             test_str.append(str(fn.foo.schema)) | 
 |         self.assertExpectedStripMangled("\n".join(test_str) + "\n") | 
 |  | 
 |     # Tests that "# type: ignore[*]" is supported in type lines and is | 
 |     # properly ignored. | 
 |     def test_mypy_type_ignore(self): | 
 |         # Compile-only test: scripting must succeed both for the bare ignore | 
 |         # marker and for the bracketed error-code form on the `def` line. | 
 |         # Neither scripted function is called and nothing else is asserted. | 
 |         @torch.jit.script | 
 |         def foo(x):  # type: ignore | 
 |             return x | 
 |  | 
 |         @torch.jit.script | 
 |         def bar(x):  # type: ignore[no-redef] | 
 |             return x | 
 |  | 
 |     def test_method_casts_script(self): | 
 |         cast_types = [ | 
 |             'byte', 'char', 'double', 'float', 'int', 'long', 'short' | 
 |         ] | 
 |  | 
 |         for cast_type in cast_types: | 
 |             cu = torch.jit.CompilationUnit(''' | 
 |             def cast_to(x): | 
 |                 return x.{cast_type}() | 
 |             '''.format(cast_type=cast_type)) | 
 |  | 
 |             x = torch.rand(3, 4, 5) * 128 | 
 |             cu_result = cu.cast_to(x) | 
 |             reference = getattr(x, cast_type)() | 
 |             self.assertEqual(cu_result, reference) | 
 |  | 
 |     def test_string_frontend_elif(self): | 
 |         code = ''' | 
 |             def func(niter): | 
 |                 # type: (int) | 
 |                 rv = 0 | 
 |                 for i in range(niter): | 
 |                     if i % 3 == 0 and i % 5 == 0: | 
 |                         rv += 35 | 
 |                     elif i % 3 == 0: | 
 |                         rv += 3 | 
 |                     elif i % 5 == 0: | 
 |                         rv += 5 | 
 |                     else: | 
 |                         rv += i | 
 |                 return rv | 
 |         ''' | 
 |  | 
 |         self.checkScript(dedent(code), (101,)) | 
 |  | 
 |     def test_module_parameters_and_buffers(self): | 
 |         # A ScriptModule holds two instances of a plain nn.Module that has | 
 |         # parameters and a buffer, and calls one of them twice. The result is | 
 |         # compared with the same computation done with torch.nn.Linear, and | 
 |         # the module is passed to assertExportImportModule. | 
 |         weights = torch.randn(10, 10) | 
 |         bias = torch.randn(10) | 
 |         weights2 = torch.randn(10, 10) | 
 |         bias2 = torch.randn(10) | 
 |  | 
 |         # Linear-like module: forward returns F.linear(...) plus the | 
 |         # `counter` buffer, which is registered as a tensor of ones. | 
 |         class TestLinear(torch.nn.Module): | 
 |             def __init__(self, in_features, out_features): | 
 |                 super(TestLinear, self).__init__() | 
 |                 self.in_features = in_features | 
 |                 self.out_features = out_features | 
 |                 self.weight = torch.nn.Parameter(torch.empty(out_features, in_features)) | 
 |                 self.bias = torch.nn.Parameter(torch.empty(out_features)) | 
 |                 self.register_buffer('counter', torch.ones(out_features)) | 
 |                 self.reset_parameters() | 
 |  | 
 |             def reset_parameters(self): | 
 |                 torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | 
 |                 if self.bias is not None: | 
 |                     fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) | 
 |                     bound = 1 / math.sqrt(fan_in) | 
 |                     torch.nn.init.uniform_(self.bias, -bound, bound) | 
 |  | 
 |             def forward(self, input): | 
 |                 return F.linear(input, self.weight, self.bias) + self.counter | 
 |  | 
 |         # Initialize a ScriptModule that uses the weak module above multiple times | 
 |         # The initial weight/bias of each submodule are replaced with the | 
 |         # tensors created at the top of the test. | 
 |         class Strong(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(Strong, self).__init__() | 
 |                 self.fc1 = TestLinear(10, 10) | 
 |                 self.fc1.weight = torch.nn.Parameter(weights) | 
 |                 self.fc1.bias = torch.nn.Parameter(bias) | 
 |                 self.fc2 = TestLinear(10, 10) | 
 |                 self.fc2.weight = torch.nn.Parameter(weights2) | 
 |                 self.fc2.bias = torch.nn.Parameter(bias2) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return x + self.fc1(x) + self.fc1(x) + self.fc2(x) | 
 |  | 
 |         strong_mod = Strong() | 
 |  | 
 |         # Run same calculation as module | 
 |         inp = torch.ones(10) | 
 |         lin = torch.nn.Linear(10, 10) | 
 |         lin.weight = torch.nn.Parameter(weights) | 
 |         lin.bias = torch.nn.Parameter(bias) | 
 |         lin2 = torch.nn.Linear(10, 10) | 
 |         lin2.weight = torch.nn.Parameter(weights2) | 
 |         lin2.bias = torch.nn.Parameter(bias2) | 
 |         # fc1 contributes twice and fc2 once; every contribution includes the | 
 |         # ones-valued `counter` buffer, hence the torch.ones(10) terms. | 
 |         expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10) | 
 |  | 
 |         self.assertEqual(strong_mod(inp), expected_result) | 
 |         self.assertExportImportModule(strong_mod, (inp,)) | 
 |  | 
 |     def test_module_copying(self): | 
 |         # Assigns a plain nn.Module (with a parameter, a buffer and a nested | 
 |         # submodule) as an attribute of a ScriptModule, then checks which | 
 |         # pieces are shared with the original Python module and which are | 
 |         # not. | 
 |         class Submodule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Submodule, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + 100 | 
 |  | 
 |         class Weak(torch.nn.Module): | 
 |             def __init__(self, in_features, out_features): | 
 |                 super(Weak, self).__init__() | 
 |                 self.weight = torch.nn.Parameter(torch.ones(out_features, in_features)) | 
 |                 self.bias = torch.nn.Parameter(torch.ones(out_features)) | 
 |                 self.register_buffer("buffer", torch.ones(out_features)) | 
 |                 self.submodule = Submodule() | 
 |  | 
 |             def forward(self, x): | 
 |                 return F.linear(x, self.weight, self.bias) \ | 
 |                     + self.buffer + self.submodule(x) | 
 |  | 
 |         class Strong(torch.jit.ScriptModule): | 
 |             def __init__(self, weak): | 
 |                 super(Strong, self).__init__() | 
 |                 self.weak = weak | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.weak(x) | 
 |  | 
 |         inp = torch.ones(5, 5) * 5 | 
 |         weak_mod = Weak(5, 5) | 
 |         strong_mod = Strong(weak_mod) | 
 |  | 
 |         # The attribute held by the ScriptModule is a ScriptModule, while the | 
 |         # object passed in remains a plain nn.Module. | 
 |         self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule)) | 
 |         self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule)) | 
 |  | 
 |         # Parameter and buffer objects are the very same objects on both sides. | 
 |         self.assertIs(strong_mod.weak.weight, weak_mod.weight) | 
 |         self.assertIs(strong_mod.weak.buffer, weak_mod.buffer) | 
 |         # strong_mod.weak.submodule has been recursively scripted | 
 |         self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule) | 
 |  | 
 |         # An in-place update of the shared parameter's data is seen by both | 
 |         # modules, so their outputs still agree. | 
 |         weak_mod.weight.data += torch.ones(5, 5) * 100 | 
 |         self.assertTrue(strong_mod(inp).allclose(weak_mod(inp))) | 
 |  | 
 |         # Re-assignment is not tracked | 
 |         weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100) | 
 |         self.assertFalse(strong_mod(inp).allclose(weak_mod(inp))) | 
 |  | 
 |     def test_backend_cudnn_enabled(self): | 
 |         # Only test that this compiles | 
 |         # `fn` is never called; applying the decorator to a function that | 
 |         # branches on torch.backends.cudnn.enabled is the whole test. | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             if torch.backends.cudnn.enabled: | 
 |                 x = x + 2 | 
 |             else: | 
 |                 x = x + 3 | 
 |             return x | 
 |  | 
 |     def test_inplace_add(self): | 
 |         # Runs checkScript on a function that mutates a freshly computed | 
 |         # tensor with add_ and returns it. | 
 |  | 
 |         def foo(a, b): | 
 |             c = a + b | 
 |             c.add_(b) | 
 |             return c | 
 |         self.checkScript(foo, (torch.rand(3), torch.rand(3))) | 
 |  | 
 |     def test_add_out(self): | 
 |         # Runs checkScript on a function that writes the result of torch.add | 
 |         # into an existing tensor through the out= keyword and returns that | 
 |         # tensor. | 
 |         def foo(a, b): | 
 |             c = a + b | 
 |             e = 2 * a | 
 |             torch.add(c, b, out=e) | 
 |             return e | 
 |         self.checkScript(foo, (torch.rand(3), torch.rand(3))) | 
 |  | 
 |     def test_tuple_error_msg(self): | 
 |         def fn(t: Any): | 
 |             if isinstance(t, tuple): | 
 |                 a, b = t | 
 |             return a + b | 
 |         with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"): | 
 |             s = torch.jit.script(fn) | 
 |  | 
 |     def test_augmented_assign(self): | 
 |         # Runs checkScript on a function that applies +=, -=, /= and *= to | 
 |         # its first argument in sequence and returns both arguments. | 
 |         def foo(a, b): | 
 |             a += b | 
 |             a -= b | 
 |             a /= b | 
 |             a *= b | 
 |             return a, b | 
 |         self.checkScript(foo, (torch.rand(3), torch.rand(3))) | 
 |  | 
 |     def test_ignored_props(self): | 
 |         # Scripts a module whose two properties are both listed in | 
 |         # __jit_ignored_attributes__. The `ignored` property raises if it is | 
 |         # ever evaluated; `call` is marked with torch.jit.ignore and reads | 
 |         # the other property. | 
 |         class A(nn.Module): | 
 |             __jit_ignored_attributes__ = ["ignored", "ignored_return_val"] | 
 |  | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             @property | 
 |             def ignored(self): | 
 |                 raise ValueError("shouldn't be called") | 
 |  | 
 |             @property | 
 |             def ignored_return_val(self): | 
 |                 return 1 | 
 |  | 
 |             @torch.jit.ignore | 
 |             def call(self): | 
 |                 return self.ignored_return_val | 
 |  | 
 |         f = torch.jit.script(A()) | 
 |         # jank way to test if there is no error | 
 |         self.assertTrue(isinstance(f, torch.jit.ScriptModule)) | 
 |         # The value returned by call() is asserted to be a `property` object. | 
 |         self.assertTrue(isinstance(f.call(), property)) | 
 |  | 
 |  | 
 |     def test_pass(self): | 
 |         # Runs checkScript on a function where `pass` is the only statement | 
 |         # in a loop body and in both branches of an if/else. | 
 |         def foo(x): | 
 |             # type: (bool) -> int | 
 |             for _i in range(3): | 
 |                 pass | 
 |             if x: | 
 |                 pass | 
 |             else: | 
 |                 pass | 
 |             return 3 | 
 |  | 
 |         self.checkScript(foo, (True,)) | 
 |  | 
 |     def test_lhs_indexing(self): | 
 |         # Runs checkScript on a function that assigns `b` into index 0 of a | 
 |         # clone of `a`. | 
 |         def foo(a, b): | 
 |             a = a.clone() | 
 |             a[0] = b | 
 |             return a | 
 |         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) | 
 |  | 
 |     def test_lhs_advanced_indexing_assignment(self): | 
 |         # Runs checkScript on a function that uses the result of `x == 1` as | 
 |         # an index on both sides of an assignment. | 
 |         def foo(x, y): | 
 |             a = torch.exp(x) | 
 |             b = x == 1 | 
 |             a[b] = y[b] | 
 |             return a | 
 |         self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) | 
 |  | 
 |     def test_lhs_advanced_indexing_augmented_assignment(self): | 
 |         # Runs checkScript on a function that uses the result of `x == 1` as | 
 |         # an index on both sides of a `+=` assignment. | 
 |         def foo(x, y): | 
 |             a = torch.exp(x) | 
 |             b = x == 1 | 
 |             a[b] += y[b] | 
 |             return a | 
 |         self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) | 
 |  | 
 |     def test_lhs_indexing_list(self): | 
 |         # Runs checkScript on a function that replaces element 0 of a | 
 |         # one-element list of tensors and returns the list. | 
 |         def foo(a, b): | 
 |             ls = [a] | 
 |             ls[0] = b | 
 |             return ls | 
 |         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) | 
 |  | 
 |     def test_inplace_copy_script(self): | 
 |         # Runs checkScript on a function that overwrites a freshly created | 
 |         # random tensor with copy_ from its argument, so the returned values | 
 |         # come from the argument. | 
 |         def foo(x): | 
 |             a = torch.rand(3, 4) | 
 |             a.copy_(x) | 
 |             return a | 
 |         self.checkScript(foo, (torch.rand(3, 4),)) | 
 |  | 
 |     def test_lhs_indexing_increment(self): | 
 |         # Runs checkScript on a function that applies `+=` to index 0 of its | 
 |         # argument directly (no clone is taken first). | 
 |         def foo(a, b): | 
 |             a[0] += b | 
 |             return a | 
 |         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) | 
 |  | 
 |     def test_lhs_indexing_increment_list(self): | 
 |         # Runs checkScript on a function that applies `+=` to element 0 of a | 
 |         # list of tensors; `a` is cloned before being placed in the list. | 
 |         def foo(a, b): | 
 |             a = a.clone() | 
 |             ls = [a, b] | 
 |             ls[0] += b | 
 |             return ls | 
 |         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) | 
 |  | 
 |     def test_lhs_indexing_increment_list_prim(self): | 
 |         # Runs checkScript on a zero-argument function that applies `+=` to | 
 |         # element 0 of a list of ints. | 
 |         def foo(): | 
 |             ls = [1, 2, 3] | 
 |             ls[0] += 5 | 
 |             return ls | 
 |         self.checkScript(foo, ()) | 
 |  | 
 |     def test_lhs_indexing_multi(self): | 
 |         # Runs checkScript on a function whose tuple-unpacking assignment | 
 |         # has a subscript (`a[0]`) as its middle target. Inside the function | 
 |         # the local name `foo` shadows the function's own name. | 
 |         def foo(a, b): | 
 |             a = a.clone() | 
 |             foo, a[0], bar = (1, b, 3) | 
 |             return foo, a, bar | 
 |         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) | 
 |  | 
 |     def test_bool_dispatch(self): | 
 |         # Runs checkScript on several ways of calling F.max_pool1d. The | 
 |         # declared return type is Tensor when return_indices is False or | 
 |         # omitted and Tuple[Tensor, Tensor] when it is True. | 
 |         with torch._jit_internal._disable_emit_hooks():  # TODO: Python print broadcasting list | 
 |             # return_indices passed as the only keyword argument. | 
 |             def kwarg_false(x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return F.max_pool1d(x, 1, 1, return_indices=False) | 
 |             self.checkScript(kwarg_false, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             def kwarg_true(x): | 
 |                 # type: (Tensor) -> Tuple[Tensor, Tensor] | 
 |                 return F.max_pool1d(x, 1, 1, return_indices=True) | 
 |             self.checkScript(kwarg_true, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             # return_indices passed by keyword together with ceil_mode. | 
 |             def full_kwarg_false(x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False) | 
 |             self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             def full_kwarg_true(x): | 
 |                 # type: (Tensor) -> Tuple[Tensor, Tensor] | 
 |                 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True) | 
 |             self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             # return_indices omitted entirely. | 
 |             def use_default(x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return F.max_pool1d(x, 1, 1) | 
 |             self.checkScript(use_default, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             # All arguments positional; presumably the last positional value | 
 |             # is return_indices (verify against the F.max_pool1d signature). | 
 |             def arg_false(x): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 return F.max_pool1d(x, 1, 1, 0, 1, False, False) | 
 |             self.checkScript(arg_false, (torch.randn(3, 3, 3),)) | 
 |  | 
 |             def arg_true(x): | 
 |                 # type: (Tensor) -> Tuple[Tensor, Tensor] | 
 |                 return F.max_pool1d(x, 1, 1, 0, 1, False, True) | 
 |             self.checkScript(arg_true, (torch.randn(3, 3, 3),)) | 
 |  | 
 |     def test_infer_size(self): | 
 |         # Runs checkScript on a function that calls torch._C._infer_size on | 
 |         # the sizes of two tensors; both inputs here have the same shape. | 
 |         from torch._C import _infer_size | 
 |  | 
 |         def fn(x, y): | 
 |             # type: (Tensor, Tensor) -> List[int] | 
 |             return _infer_size(x.size(), y.size()) | 
 |  | 
 |         self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2))) | 
 |  | 
 |     def test_hash(self): | 
 |         def tester(fn, inputs): | 
 |             for x in inputs: | 
 |                 for y in inputs: | 
 |                     if x == y: | 
 |                         self.assertEqual(fn(x), fn(y)) | 
 |                     else: | 
 |                         self.assertNotEqual(fn(x), fn(y)) | 
 |  | 
 |         @torch.jit.script | 
 |         def int_hash(x): | 
 |             # type: (int) -> int | 
 |             return hash(x) | 
 |  | 
 |         @torch.jit.script | 
 |         def float_hash(x): | 
 |             # type: (float) -> int | 
 |             return hash(x) | 
 |  | 
 |         @torch.jit.script | 
 |         def str_hash(x): | 
 |             # type: (str) -> int | 
 |             return hash(x) | 
 |  | 
 |         tester(int_hash, (20, 21, 22)) | 
 |         tester(float_hash, (20.0, 21.00001, 22.443)) | 
 |         tester(str_hash, ("", "hello", "a")) | 
 |  | 
 |     def test_id(self): | 
 |         # id() applied to scalar/None literals must fail to script with a | 
 |         # RuntimeError matching "Expected a value". | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected a value"): | 
 |             @torch.jit.script | 
 |             def test_id_scalars(): | 
 |                 return id(2) == id(None) | 
 |  | 
 |         @torch.jit.script | 
 |         class FooTest(object): | 
 |             def __init__(self, x): | 
 |                 self.foo = x | 
 |  | 
 |             def getFooTest(self): | 
 |                 return self.foo | 
 |  | 
 |         # For two distinct instances of a scripted class, the ids must | 
 |         # differ from each other and from id(None); the function returns | 
 |         # True only if all of its internal asserts pass. | 
 |         @torch.jit.script | 
 |         def test_id_class_types(): | 
 |             obj1 = FooTest(torch.tensor(3)) | 
 |             obj2 = FooTest(torch.tensor(2)) | 
 |             assert obj1 is not obj2 | 
 |             assert id(obj1) != id(obj2) | 
 |             assert id(obj1) != id(None) | 
 |             return True | 
 |  | 
 |         self.assertTrue(test_id_class_types()) | 
 |  | 
 |     def test_mutable_dce(self): | 
 |         # `a` and `b` are each built from two rand calls and one in-place | 
 |         # add, but only `a` is returned. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             a = torch.rand(2, 3) | 
 |             a += torch.rand(2, 3) | 
 |             b = torch.rand(2, 3) | 
 |             b += torch.rand(2, 3) | 
 |             # b should be cleaned up but not a | 
 |             return a | 
 |  | 
 |         # The graph text must contain exactly two aten::rand and exactly one | 
 |         # aten::add, i.e. only the operations that produce `a`. | 
 |         FileCheck().check_count("aten::rand", 2, exactly=True) \ | 
 |             .check_count("aten::add", 1, exactly=True).run(str(foo.graph)) | 
 |  | 
 |     def test_mutable_dce_block(self): | 
 |         # Both `a` and `b` are mutated inside the if block, but only `b` is | 
 |         # returned; `a` is still read by the if condition. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             a = torch.rand(2, 3) | 
 |             a += torch.rand(2, 3) | 
 |             b = torch.rand(2, 3) | 
 |             if bool(a > torch.zeros(2, 3)): | 
 |                 b += torch.rand(2, 3) | 
 |                 a += torch.rand(2, 3) | 
 |             # a should be cleaned up but not b | 
 |             return b | 
 |  | 
 |         # Chains a check for prim::If with an exact count of one aten::rand. | 
 |         FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \ | 
 |             .run(str(foo.graph)) | 
 |  | 
 |     def test_mutable_dce_graph_input(self): | 
 |         # The function mutates its argument in place and returns nothing. | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             a += torch.rand(2, 3) | 
 |             # shouldn't clean up `a` even though it's not used in the output | 
 |  | 
 |         # Both the rand and the add must still be present in the graph text. | 
 |         FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph)) | 
 |  | 
 |     def test_mutable_dce_list(self): | 
 |         # `c` is read back out of a list that holds the argument and is then | 
 |         # mutated in place; only the unrelated tensor `b` is returned. | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             l = [] | 
 |             l.append(a) | 
 |             c = l[0] | 
 |             b = torch.rand(2, 3) | 
 |             c += torch.rand(2, 3) | 
 |             return b | 
 |  | 
 |         # c does not get cleaned up because there is a wildcard + mutation | 
 |         FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph)) | 
 |  | 
 |     def test_mutable_dce_loop(self): | 
 |         # Inside the loop, `dead` is never used, while `c` is read from the | 
 |         # list that holds the argument and mutated in place. | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             l = [] | 
 |             l.append(a) | 
 |             i = 0 | 
 |             b = torch.rand(2, 3) | 
 |             while i < 1: | 
 |                 dead = torch.rand(2, 3) | 
 |                 c = l[0] | 
 |                 c += torch.rand(2, 3) | 
 |                 i += 1 | 
 |             return b | 
 |  | 
 |         # Check chain: prim::Loop, then no aten::rand before | 
 |         # aten::__getitem__, then exactly one aten::rand. | 
 |         FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \ | 
 |             .check_count("aten::rand", 1, exactly=True).run(str(foo.graph)) | 
 |  | 
 |     def test_mutable_dce_indirect_wildcards(self): | 
 |         # Runs checkScript on a function that returns a view of `x` obtained | 
 |         # through a list, after `x` itself has been mutated with add_. | 
 |         def fn(): | 
 |             x = torch.ones(2, 3) | 
 |             x_1 = x.view(-1) | 
 |             l = [] | 
 |             l.append(x_1) | 
 |             x_view = l[0] | 
 |             x.add_(torch.ones(2, 3)) | 
 |             return x_view | 
 |         self.checkScript(fn, ()) | 
 |  | 
 |     def test_mutable_dce_indirect_wildcard_write(self): | 
 |         # Runs checkScript on a function that writes into a tensor by index, | 
 |         # appends it to a list that is otherwise unused, and returns the | 
 |         # tensor. | 
 |         def fn(): | 
 |             indexes = torch.jit.annotate(List[Tensor], []) | 
 |             word_ids = torch.zeros(10, dtype=torch.int32) | 
 |             word_ids[1] = 1 | 
 |             indexes.append(word_ids) | 
 |  | 
 |             return word_ids | 
 |         self.checkScript(fn, ()) | 
 |  | 
 |     def test_mutable_dce_wildcards(self): | 
 |         # Runs checkScript (with profiling=ProfilingMode.SIMPLE) on a | 
 |         # function that returns `x` as read back from a list, after `x` has | 
 |         # been mutated with add_. | 
 |         def fn(): | 
 |             x = torch.ones(2, 3) | 
 |             l = [] | 
 |             l.append(x) | 
 |             x_view = l[0] | 
 |             x.add_(torch.ones(2, 3)) | 
 |             return x_view | 
 |  | 
 |         self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE) | 
 |  | 
 |     def test_cpp_function_tensor_str(self): | 
 |         # Prints the output of a scripted function whose inputs include | 
 |         # tensors that require grad. The captured output is not inspected; | 
 |         # the test only requires that printing completes without raising. | 
 |         x = torch.randn(2, 2) | 
 |         scale = torch.randn(2, 2, requires_grad=True) | 
 |         shift = torch.randn(2, 2, requires_grad=True) | 
 |  | 
 |         @torch.jit.script | 
 |         def fn(x, scale, shift): | 
 |             return scale * x + shift | 
 |  | 
 |         with self.capture_stdout() as captured: | 
 |             print(fn(x, scale, shift)) | 
 |  | 
 |     def test_string_index(self): | 
 |         # Runs checkScript on a function that indexes a string with one | 
 |         # positive and one negative index. | 
 |         def fn(x): | 
 |             # type: (str) | 
 |             return x[2], x[-1] | 
 |  | 
 |         self.checkScript(fn, ("abcde",)) | 
 |  | 
 |     def test_ord(self): | 
 |         def fn(x): | 
 |             # type: (str) -> int | 
 |             return ord(x) | 
 |  | 
 |         self.checkScript(fn, ("h")) | 
 |         self.checkScript(fn, ("y")) | 
 |  | 
 |         def index_str_to_tensor(s): | 
 |             # type: (str) -> Tensor | 
 |             return torch.tensor(ord(s))  # noqa: T484 | 
 |  | 
 |         s = u'\u00a3'.encode('utf8')[:1] | 
 |         self.checkScript(index_str_to_tensor, (s,)) | 
 |  | 
 |     def test_chr(self): | 
 |         def fn(x): | 
 |             # type: (int) -> str | 
 |             return chr(x) | 
 |  | 
 |         self.checkScript(fn, (1,)) | 
 |         self.checkScript(fn, (97,)) | 
 |  | 
 |     def test_round(self): | 
 |         # Runs checkScript on round() with a float argument and with an int | 
 |         # argument; both functions declare a float return type. | 
 |         def round_float(x): | 
 |             # type: (float) -> float | 
 |             return round(x) | 
 |  | 
 |         def round_int(x): | 
 |             # type: (int) -> float | 
 |             return round(x) | 
 |  | 
 |         self.checkScript(round_float, (1.5,)) | 
 |         self.checkScript(round_int, (2,)) | 
 |  | 
 |     def test_convert_base(self): | 
 |         def test_hex(x): | 
 |             # type: (int) -> str | 
 |             return hex(x) | 
 |  | 
 |         def test_oct(x): | 
 |             # type: (int) -> str | 
 |             return oct(x) | 
 |  | 
 |         def test_bin(x): | 
 |             # type: (int) -> str | 
 |             return bin(x) | 
 |  | 
 |         numbers = [-1000, -10, 0, 1, 10, 2343] | 
 |         for n in numbers: | 
 |             self.checkScript(test_bin, (n,)) | 
 |             self.checkScript(test_oct, (n,)) | 
 |             self.checkScript(test_hex, (n,)) | 
 |  | 
 |     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") | 
 |     def test_get_set_state(self): | 
 |         # Saves and reloads modules that define __getstate__/__setstate__ | 
 |         # and checks the buffers of the loaded copies. Every __setstate__ | 
 |         # here stores a modified value, so a loaded buffer differs from the | 
 |         # original one. | 
 |  | 
 |         # Submodule: its state tuple is (buffer1, buffer2, 74, training); | 
 |         # element 2 is never read back by __setstate__. | 
 |         class Root(torch.jit.ScriptModule): | 
 |             __constants__ = ['number'] | 
 |  | 
 |             def __init__(self, number): | 
 |                 super(Root, self).__init__() | 
 |                 self.register_buffer('buffer1', torch.ones(2, 2)) | 
 |                 self.register_buffer('buffer2', torch.ones(2, 2)) | 
 |                 self.number = number | 
 |  | 
 |             @torch.jit.script_method | 
 |             def __getstate__(self): | 
 |                 return (self.buffer1, self.buffer2, 74, self.training) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def __setstate__(self, state): | 
 |                 self.buffer1 = state[0] + 10 | 
 |                 self.buffer2 = state[1] + 10 | 
 |                 self.training = state[3] | 
 |  | 
 |         # Top-level module: same layout as Root, with the submodule carried | 
 |         # as element 3 of the state tuple. | 
 |         class M(torch.jit.ScriptModule): | 
 |             __constants__ = ['number'] | 
 |  | 
 |             def __init__(self, number, submodule): | 
 |                 super(M, self).__init__() | 
 |                 self.register_buffer('buffer1', torch.ones(2, 2)) | 
 |                 self.register_buffer('buffer2', torch.ones(2, 2)) | 
 |                 self.number = number | 
 |                 self.submodule = submodule | 
 |  | 
 |             @torch.jit.script_method | 
 |             def __getstate__(self): | 
 |                 return (self.buffer1, self.buffer2, 74, self.submodule, self.training) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def __setstate__(self, state): | 
 |                 self.buffer1 = state[0] + 10 | 
 |                 self.buffer2 = state[1] + 10 | 
 |                 self.submodule = state[3] | 
 |                 self.training = state[4] | 
 |  | 
 |         with TemporaryFileName() as fname: | 
 |             m = M(23, submodule=Root(99)) | 
 |             m.save(fname) | 
 |             loaded = torch.jit.load(fname) | 
 |  | 
 |         # Check original module | 
 |         self.assertEqual(m.buffer1, torch.ones(2, 2)) | 
 |         self.assertEqual(m.buffer2, torch.ones(2, 2)) | 
 |  | 
 |         # Check top level module | 
 |         self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10) | 
 |         self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) | 
 |  | 
 |         # Check submodule | 
 |         self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10) | 
 |         self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10) | 
 |  | 
 |         # Check simpler module | 
 |         # Plain nn.Module scripted with torch.jit.script; its state is | 
 |         # (5, training) and __setstate__ rebuilds both buffers from scratch, | 
 |         # using state[0] for buffer1 and a fixed 10 for buffer2. | 
 |         class NoArgState(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(NoArgState, self).__init__() | 
 |                 self.register_buffer('buffer1', torch.ones(2, 2)) | 
 |                 self.register_buffer('buffer2', torch.ones(2, 2)) | 
 |  | 
 |             def forward(self): | 
 |                 pass | 
 |  | 
 |             @torch.jit.export | 
 |             def __getstate__(self): | 
 |                 return 5, self.training | 
 |  | 
 |             @torch.jit.export | 
 |             def __setstate__(self, state): | 
 |                 self.buffer1 = torch.ones(2, 2) + state[0] | 
 |                 self.buffer2 = torch.ones(2, 2) + 10 | 
 |                 self.training = state[1] | 
 |  | 
 |         with TemporaryFileName() as fname: | 
 |             m = torch.jit.script(NoArgState()) | 
 |             m.save(fname) | 
 |             loaded = torch.jit.load(fname) | 
 |             self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5) | 
 |             self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) | 
 |  | 
 |  | 
 |  | 
 |     def test_string_slicing(self): | 
 |         def fn1(x): | 
 |             # type: (str) -> str | 
 |             return x[1:3] | 
 |  | 
 |         def fn2(x): | 
 |             # type: (str) -> str | 
 |             return x[-1:3] | 
 |  | 
 |         def fn3(x): | 
 |             # type: (str) -> str | 
 |             return x[3:1] | 
 |  | 
 |         def fn4(x): | 
 |             # type: (str) -> str | 
 |             return x[3:100] | 
 |  | 
 |         self.checkScript(fn1, ("abcdefghi",)) | 
 |         self.checkScript(fn2, ("abcdefghi",)) | 
 |         self.checkScript(fn3, ("abcdefghi",)) | 
 |         self.checkScript(fn4, ("abcdefghi",)) | 
 |  | 
 |     def test_early_return_closure(self): | 
 |         # Compiles several functions that define a nested `backward` | 
 |         # closure and inspects the resulting graphs with FileCheck. | 
 |  | 
 |         # Case 1: the closure body is just `pass`. | 
 |         code = dedent(''' | 
 |             def tanh(self): | 
 |                 output = torch.tanh(self) | 
 |                 def backward(grad_output): | 
 |                     pass | 
 |                 return output, backward | 
 |         ''') | 
 |         cu = torch.jit.CompilationUnit(code) | 
 |         g = cu.tanh.graph | 
 |         FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \ | 
 |                    .check_next("return").run(g) | 
 |  | 
 |         # Case 2: the closure returns early from one branch of an if and | 
 |         # falls through to a final return from the other. | 
 |         code = dedent(''' | 
 |             def tanh(self): | 
 |                 output = torch.tanh(self) | 
 |                 def backward(grad_output): | 
 |                     a = 1 | 
 |                     if output: | 
 |                         return 1 | 
 |                     else: | 
 |                         a = 2 | 
 |                     return a | 
 |                 return output, backward | 
 |         ''') | 
 |         cu = torch.jit.CompilationUnit(code) | 
 |         g = cu.tanh.graph | 
 |         FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ | 
 |                    .run(g) | 
 |  | 
 |         # Case 3: the closure returns from inside a for loop. | 
 |         code = dedent(''' | 
 |             def loop_in_closure(self): | 
 |                 output = torch.tanh(self) | 
 |                 def backward(grad_output): | 
 |                     for i in range(3): | 
 |                         return 1 | 
 |                     return 4 | 
 |                 return output, backward | 
 |         ''') | 
 |         cu = torch.jit.CompilationUnit(code) | 
 |         fc = FileCheck() | 
 |         fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct") | 
 |         # Loop then two if's added in exit transform | 
 |         fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) | 
 |         fc.run(cu.loop_in_closure.graph) | 
 |  | 
 |         # Case 4: the two branches of the closure return an int and a float, | 
 |         # which must be rejected at compile time. | 
 |         code = dedent(''' | 
 |             def tanh(self): | 
 |                 output = torch.tanh(self) | 
 |                 def backward(grad_output): | 
 |                     if 1 == 1: | 
 |                         return 1 | 
 |                     else: | 
 |                         return 1. | 
 |                 return output, backward | 
 |         ''') | 
 |         with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"): | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |  | 
 |     @_inline_everything | 
 |     def test_early_return_fork_join(self): | 
 |         # `foo` returns from both branches of an `if`. `wait_script` both | 
 |         # forks `foo` and calls it directly; only its graph is inspected | 
 |         # (the scripted function is never executed here). | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             if x.dim() == 2: | 
 |                 return torch.neg(x), x | 
 |             else: | 
 |                 return torch.neg(x), x + 1 | 
 |  | 
 |         # NOTE(review): this tensor is not passed to anything below. | 
 |         x = torch.rand(3, 4) | 
 |  | 
 |         @torch.jit.script | 
 |         def wait_script(x): | 
 |             fut = torch.jit._fork(foo, x) | 
 |             y_hat = foo(x) | 
 |             y = torch.jit._wait(fut) | 
 |             return y, y_hat | 
 |  | 
 |         # The graph text must contain a fork, then a prim::If, then a return, | 
 |         # in that order. | 
 |         FileCheck().check("with prim::fork").check("prim::If").check("return")\ | 
 |                    .run(wait_script.graph) | 
 |  | 
 |     def test_early_return_type_refinement(self): | 
 |         # An Optional[int] argument is returned as a plain int from the | 
 |         # `else` branch of an `is None` check, with both branches returning. | 
 |         @torch.jit.script | 
 |         def test(x): | 
 |             # type: (Optional[int]) -> int | 
 |             if x is None: | 
 |                 return 1 | 
 |             else: | 
 |                 return x | 
 |         # Exercise both branches. | 
 |         self.assertEqual(test(None), 1) | 
 |         self.assertEqual(test(2), 2) | 
 |  | 
 |     def test_exceptions_with_control_flow(self): | 
 |         def test_num_ifs(func, num_ifs): | 
 |             g = torch.jit.script(func).graph | 
 |             FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g) | 
 |  | 
 |         def no_guard_ifs_added(x): | 
 |             # type: (int) -> int | 
 |             if x == 1: | 
 |                 return 1 | 
 |             else: | 
 |                 if x == 2: | 
 |                     raise RuntimeError("hi") | 
 |                 else: | 
 |                     raise RuntimeError("hi") | 
 |  | 
 |         self.checkScript(no_guard_ifs_added, (1,)) | 
 |         self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "") | 
 |         test_num_ifs(no_guard_ifs_added, 2) | 
 |  | 
 |         # FUNCTION LOOKS LIKE: | 
 |         # graph(%x.1 : int): | 
 |         #   %7 : str = prim::Constant[value="Exception"]() | 
 |         #   %2 : int = prim::Constant[value=1]() | 
 |         #   %5 : int = prim::Constant[value=2]() | 
 |         #   %19 : int = prim::Uninitialized() | 
 |         #   %3 : bool = aten::eq(%x.1, %2) | 
 |         #   %20 : int = prim::If(%3) | 
 |         #     block0(): | 
 |         #       -> (%2) | 
 |         #     block1(): | 
 |         #       %6 : bool = aten::eq(%x.1, %5) | 
 |         #        = prim::If(%6) | 
 |         #         block0(): | 
 |         #            = prim::RaiseException(%7) | 
 |         #           -> () | 
 |         #         block1(): | 
 |         #            = prim::RaiseException(%7) | 
 |         #           -> () | 
 |         #       -> (%19) | 
 |         #   return (%20) | 
 |  | 
 |         def no_ifs_added(x): | 
 |             # type: (int) -> int | 
 |             if x < 0: | 
 |                 raise RuntimeError("hi") | 
 |             return x | 
 |  | 
 |         self.checkScript(no_ifs_added, (1,)) | 
 |         self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") | 
 |         test_num_ifs(no_ifs_added, 1) | 
 |  | 
 |         def test_if_might(x): | 
 |             # type: (int) | 
 |             if x > 0: | 
 |                 if x == 1: | 
 |                     return 1 | 
 |                 else: | 
 |                     a = 2 | 
 |             else: | 
 |                 raise RuntimeError("hi") | 
 |             return a + 2 | 
 |  | 
 |         self.checkScript(test_if_might, (1,)) | 
 |         self.checkScript(test_if_might, (3,)) | 
 |         self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") | 
 |         test_num_ifs(test_if_might, 3)  # one if added to guard a + 2 | 
 |  | 
 |         def test_loop_no_escape(x): | 
 |             # type: (int) | 
 |             if x >= 0: | 
 |                 for i in range(x): | 
 |                     raise RuntimeError("hi") | 
 |             else: | 
 |                 return 5 | 
 |             return x + 3 | 
 |  | 
 |         self.checkScript(test_loop_no_escape, (0,)) | 
 |         self.checkScript(test_loop_no_escape, (-1,)) | 
 |         self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "") | 
 |  | 
 |         # if guard gets optimized away | 
 |         test_num_ifs(test_loop_no_escape, 1) | 
 |  | 
 |         def test_loop_exception_with_continue(x): | 
 |             # type: (int) | 
 |             i = 0 | 
 |             for i in range(5): | 
 |                 if i == x: | 
 |                     raise RuntimeError("hi") | 
 |                 else: | 
 |                     continue | 
 |                 print(i) | 
 |             return i + 5 | 
 |  | 
 |         self.checkScript(test_loop_exception_with_continue, (-1,)) | 
 |         self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "") | 
 |         test_num_ifs(test_loop_exception_with_continue, 1)  # no ifs added to guard print | 
 |  | 
 |  | 
 |     def test_exception_exits_closure(self): | 
 |         # A `raise` inside a nested closure must not be treated as an exit | 
 |         # of the enclosing function. | 
 |  | 
 |         # The outer function is annotated to return a Tensor but has no | 
 |         # `return`; only the closure raises. Compilation is expected to fail. | 
 |         code = dedent(''' | 
 |             def no_return_func(self): | 
 |                 # type: (Tensor) -> Tensor | 
 |                 output = torch.tanh(self) | 
 |                 def backward(grad_output): | 
 |                     raise RuntimeError("Hi") | 
 |         ''') | 
 |         with self.assertRaisesRegex(RuntimeError, "does not return along all"): | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |  | 
 |         # Statements after the closure definition still execute: a positive | 
 |         # input yields 2 (0 + 1 + 1), a negative input is returned unchanged. | 
 |         code = dedent(''' | 
 |             def test_exit_pair_reset(x): | 
 |                 # type: (int) -> int | 
 |                 if x > 0: | 
 |                     a = 0 | 
 |                     def backward(grad_output): | 
 |                         raise RuntimeError("Hi") | 
 |                     a = a + 1 | 
 |                 else: | 
 |                     return x | 
 |                 return a + 1 | 
 |         ''') | 
 |         func = torch.jit.CompilationUnit(code).test_exit_pair_reset | 
 |         self.assertEqual(func(1,), 2) | 
 |         self.assertEqual(func(-1,), -1) | 
 |         # final a + 1 gets inlined into the first branch and optimized away | 
 |         FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph) | 
 |  | 
 |     def test_non_final_return(self): | 
 |         # Functions whose `return` statements are not the last statement, | 
 |         # each compared between eager Python and TorchScript via checkScript. | 
 |  | 
 |         # Both branches return; the trailing `raise` is unreachable. | 
 |         def simple(x): | 
 |             if bool(x > 3): | 
 |                 return x + 1 | 
 |             else: | 
 |                 return x + 2 | 
 |             raise RuntimeError("nope") | 
 |  | 
 |         # Returns in both branches, with a nested non-returning `if`. | 
 |         def nest(x): | 
 |             x = x + 1 | 
 |             if bool(x > 3): | 
 |                 if bool(x > 4): | 
 |                     x += 1 | 
 |                 return x + 1 | 
 |             else: | 
 |                 return x + 2 | 
 |  | 
 |         # Early return followed by more statements. | 
 |         def early_ret(x): | 
 |             x = x + 1 | 
 |             if bool(x > 3): | 
 |                 return x + 1 | 
 |             x = x + 1 | 
 |             return x + 2 | 
 |  | 
 |         # Early returns at two nesting levels. | 
 |         def nest_early_ret(x): | 
 |             x = x + 1 | 
 |             if bool(x > 3): | 
 |                 if bool(x > 4): | 
 |                     return x + 2 | 
 |                 return x + 1 | 
 |             x = x + 1 | 
 |             return x + 2 | 
 |  | 
 |         # Only the innermost branch returns; other paths keep mutating `s`. | 
 |         def not_early_ret(x): | 
 |             s = "" | 
 |             if bool(x > 3): | 
 |                 if bool(x > 4): | 
 |                     return 1, s | 
 |                 s += "foo" | 
 |             else: | 
 |                 s += "5" | 
 |             s += "hi" | 
 |             return 7, s | 
 |  | 
 |         # The nested if/else returns on both sides; the outer else does not. | 
 |         def not_total_ret(x): | 
 |             s = "" | 
 |             if bool(x > 3): | 
 |                 if bool(x > 4): | 
 |                     return 1, s | 
 |                 else: | 
 |                     return 2, s | 
 |             else: | 
 |                 s += "5" | 
 |             return 7, s | 
 |  | 
 |         # Inputs 2.5, 3.5 and 4.5 fall on different sides of the 3 and 4 | 
 |         # thresholds used above. | 
 |         for i in range(3): | 
 |             for func in [simple, nest, early_ret, nest_early_ret, not_early_ret, | 
 |                          not_total_ret]: | 
 |                 self.checkScript(func, (torch.tensor(2.5 + i),)) | 
 |  | 
 |         # `y` and `z` are only assigned on the non-returning branch but are | 
 |         # used after the `if`. | 
 |         def vars_used_after_ret(x): | 
 |             # type: (int) -> int | 
 |             if x == 0: | 
 |                 return x | 
 |             else: | 
 |                 y = 2 | 
 |                 z = 3 | 
 |             return x + y * z | 
 |  | 
 |         self.checkScript(vars_used_after_ret, (1,)) | 
 |         self.checkScript(vars_used_after_ret, (0,)) | 
 |  | 
 |         # Nested returns each followed by an unreachable failing assert. | 
 |         def complicated(x): | 
 |             # type: (int) -> int | 
 |             if x: | 
 |                 if x == 2: | 
 |                     return 1 | 
 |                     assert 1 == 2 | 
 |                 else: | 
 |                     if x == 3: | 
 |                         return 2 | 
 |                         assert 1 == 2 | 
 |                     else: | 
 |                         a = 2 | 
 |                         b = 3 | 
 |             else: | 
 |                 a = 4 | 
 |                 b = 1 | 
 |             return a + b | 
 |             assert 1 == 2 | 
 |  | 
 |         # Inputs 0..3 cover every path through `complicated`. | 
 |         for i in range(4): | 
 |             self.checkScript(complicated, (i,)) | 
 |  | 
 |     def test_partial_returns(self): | 
 |         # Functions with a declared non-None return type that can fall off | 
 |         # the end are expected to be rejected at script time. | 
 |  | 
 |         # Declared `int`, body never returns. | 
 |         with self.assertRaisesRegex(RuntimeError, "does not return along all"): | 
 |             @torch.jit.script | 
 |             def no_ret(): | 
 |                 # type: () -> int | 
 |                 pass | 
 |  | 
 |         # Declared `int`, returns only when the condition holds. | 
 |         with self.assertRaisesRegex(RuntimeError, "does not return along all"): | 
 |             @torch.jit.script | 
 |             def partial(x): | 
 |                 # type: (Tensor) -> int | 
 |                 if x: | 
 |                     return 1 | 
 |  | 
 |         # Declared `Optional[int]` is rejected as well when nothing is returned. | 
 |         with self.assertRaisesRegex(RuntimeError, "does not return along all"): | 
 |             @torch.jit.script | 
 |             def typed_none(): | 
 |                 # type: () -> Optional[int] | 
 |                 pass | 
 |  | 
 |         # With no annotation at all, an empty body compiles and yields None. | 
 |         @torch.jit.script | 
 |         def none_ret(): | 
 |             pass | 
 |  | 
 |         self.assertIs(none_ret(), None) | 
 |         FileCheck().check(": None").run(none_ret.graph) | 
 |  | 
 |     def test_early_returns_loops(self): | 
 |         # Early returns from inside `while` and `for` loops, compared between | 
 |         # eager Python and TorchScript via checkScript. | 
 |  | 
 |         # `while` body either returns or breaks; `y = y + 2` is unreachable. | 
 |         def nest_while_ret(x): | 
 |             # type: (int) -> int | 
 |             y = 4 | 
 |             while x < 4: | 
 |                 if x < 3: | 
 |                     return y | 
 |                 else: | 
 |                     y = y + 1 | 
 |                     break | 
 |                 y = y + 2 | 
 |             y = y + 1 | 
 |             return y | 
 |  | 
 |         # 2 -> return inside the loop, 3 -> break, 4 -> loop never entered. | 
 |         self.checkScript(nest_while_ret, (2,)) | 
 |         self.checkScript(nest_while_ret, (3,)) | 
 |         self.checkScript(nest_while_ret, (4,)) | 
 |  | 
 |         # Conditional return inside a `for` loop that also rebinds its | 
 |         # loop variable. | 
 |         def loop_ret(x, y): | 
 |             # type: (int, int) -> (int) | 
 |             i = 0 | 
 |             for i in range(x): | 
 |                 if x == y: | 
 |                     return x + y | 
 |                 i = i + y | 
 |             i = i - 1 | 
 |             return i | 
 |  | 
 |         self.checkScript(loop_ret, (3, 3)) | 
 |         self.checkScript(loop_ret, (2, 3)) | 
 |         self.checkScript(loop_ret, (3, 1)) | 
 |  | 
 |         # Unconditional return on the first iteration, if there is one. | 
 |         def test_will_ret(y): | 
 |             # type: (int) -> int | 
 |             for i in range(y): | 
 |                 return 2 | 
 |             return 1 | 
 |  | 
 |         self.checkScript(test_will_ret, (0,)) | 
 |         self.checkScript(test_will_ret, (1,)) | 
 |  | 
 |         # Returns at three levels: inner loop, outer loop, after both loops. | 
 |         def test_loop_nest_ret(y): | 
 |             # type: (int) -> int | 
 |             for i in range(y): | 
 |                 for i in range(y - 2): | 
 |                     return 10 | 
 |                 return 5 | 
 |             return 0 | 
 |  | 
 |         self.checkScript(test_loop_nest_ret, (0,)) | 
 |         self.checkScript(test_loop_nest_ret, (1,)) | 
 |         self.checkScript(test_loop_nest_ret, (2,)) | 
 |  | 
 |     def test_nn_init(self): | 
 |         tests = ( | 
 |             ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"), | 
 |             ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |             ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |             ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |             ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |             ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |             ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), | 
 |         ) | 
 |  | 
 |         for name, args_fn, type_str in tests: | 
 |             # Build test code | 
 |             arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))]) | 
 |  | 
 |             code = dedent(''' | 
 |                 def test({arg_str}): | 
 |                     # type: ({type_str}) | 
 |                     return torch.nn.init.{name}({arg_str}) | 
 |             ''').format(arg_str=arg_str, type_str=type_str, name=name) | 
 |             cu = torch.jit.CompilationUnit(code) | 
 |  | 
 |             # Compare functions | 
 |             init_fn = getattr(torch.nn.init, name) | 
 |             script_out = self.runAndSaveRNG(cu.test, args_fn()) | 
 |             eager_out = self.runAndSaveRNG(init_fn, args_fn()) | 
 |             self.assertEqual(script_out, eager_out) | 
 |  | 
 |             FileCheck().check_not("prim::PythonOp").run(cu.test.graph) | 
 |  | 
 |     def test_early_return_rewrite(self): | 
 |         # Checks results and the exact prim::If count for functions with | 
 |         # early returns. | 
 |  | 
 |         # Single early return: exactly one prim::If is expected. | 
 |         def test_foo(x: bool): | 
 |             if x: | 
 |                 return 1 | 
 |             return 2 | 
 |  | 
 |         self.checkScript(test_foo, (True,)) | 
 |         self.checkScript(test_foo, (False,)) | 
 |         FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph) | 
 |  | 
 |         # Three source-level `if` statements, each with one returning branch, | 
 |         # and variables (`y`, `abc`) defined only on the non-returning side. | 
 |         def test_multiple(x: int): | 
 |             if x == 5: | 
 |                 return x * x | 
 |             else: | 
 |                 y = 2 * x | 
 |  | 
 |             z = y * 2 | 
 |             if z == 8: | 
 |                 return 1 | 
 |  | 
 |             if z != 16: | 
 |                 z = z - 2 | 
 |                 abc = 4 | 
 |             else: | 
 |                 return 3 | 
 |  | 
 |             z = z * abc | 
 |             return z * z * z | 
 |  | 
 |         # 5, 2 and 4 hit the three early returns in order; 3 and 10 fall | 
 |         # through to the final return. | 
 |         self.checkScript(test_multiple, (5,)) | 
 |         self.checkScript(test_multiple, (2,)) | 
 |         self.checkScript(test_multiple, (4,)) | 
 |         self.checkScript(test_multiple, (3,)) | 
 |         self.checkScript(test_multiple, (10,)) | 
 |  | 
 |         # The graph must contain exactly three prim::If nodes. | 
 |         graph = torch.jit.script(test_multiple).graph | 
 |         FileCheck().check_count("prim::If", 3, exactly=True).run(graph) | 
 |  | 
 |     def test_is_scripting_metacompile(self): | 
 |         # The `else` branch adds an int to the result of print(); the test | 
 |         # expects scripting to succeed anyway and the call to return 1. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             if torch.jit.is_scripting(): | 
 |                 return 1 | 
 |             else: | 
 |                 print("hello") + 2  # will not be compiled | 
 |  | 
 |         self.assertEqual(foo(), 1) | 
 |  | 
 |     def test_boolean_literal_constant_metacompile(self): | 
 |         # The two branches return different types (int vs. str). The | 
 |         # condition is either a module attribute listed in __constants__ | 
 |         # or the literal `True`. | 
 |         class Mod(torch.nn.Module): | 
 |             __constants__ = ['val'] | 
 |  | 
 |             def __init__(self, val): | 
 |                 super(Mod, self).__init__() | 
 |                 self.val = val | 
 |  | 
 |             def forward(self): | 
 |                 if self.val: | 
 |                     return 1 | 
 |                 else: | 
 |                     return "2" | 
 |  | 
 |         # Both constant values are exercised. | 
 |         self.checkModule(Mod(True), ()) | 
 |         self.checkModule(Mod(False), ()) | 
 |  | 
 |         # Same shape with a boolean literal as the condition. | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             if True: | 
 |                 return 1 | 
 |             else: | 
 |                 return "2" | 
 |  | 
 |         self.assertEqual(foo(), 1) | 
 |  | 
 |     def test_assert_is_scripting_metacompile(self): | 
 |         # Scripting `foo` is expected to succeed even though the statement | 
 |         # after the assert adds an int to the result of print(); calling the | 
 |         # scripted function must raise torch.jit.Error with the assert message. | 
 |         def foo(): | 
 |             assert not torch.jit.is_scripting(), "TestErrorMsg" | 
 |             print("hello") + 2  # will not be compiled | 
 |  | 
 |         f = torch.jit.script(foo) | 
 |         with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"): | 
 |             f() | 
 |  | 
 |     def test_isinstance_metacompile(self): | 
 |         # isinstance() checks against a primitive type and a namedtuple type | 
 |         # inside scripted functions. | 
 |         @torch.jit.script | 
 |         def test_primitive_type(x): | 
 |             # type: (int) -> int | 
 |             if isinstance(x, int): | 
 |                 return x + 1 | 
 |             else: | 
 |                 return x - 1 | 
 |  | 
 |         self.assertEqual(test_primitive_type(1), 2) | 
 |         # A float argument is rejected because the parameter is declared int. | 
 |         with self.assertRaisesRegex(Exception, "Expected a value of type"): | 
 |             test_primitive_type(1.5) | 
 |  | 
 |         _MyNamedTuple = namedtuple('_MyNamedTuple', ['value']) | 
 |  | 
 |         # The branches return an int literal or a Tensor expression while the | 
 |         # declared return type is Tensor; only the `x.value + 1` path is | 
 |         # asserted below. | 
 |         @torch.jit.script | 
 |         def test_non_primitive_types(x): | 
 |             # type: (_MyNamedTuple) -> Tensor | 
 |             if isinstance(1, _MyNamedTuple): | 
 |                 return 10 | 
 |  | 
 |             if isinstance(x, _MyNamedTuple): | 
 |                 return x.value + 1 | 
 |             else: | 
 |                 return 1 | 
 |  | 
 |         out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) | 
 |         self.assertEqual(out, torch.tensor(6.0)) | 
 |  | 
 |     def test_namedtuple_type_inference(self): | 
 |         # Compares a NamedTuple with an annotated int field against a | 
 |         # namedtuple whose field has no annotation. | 
 |         _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) | 
 |         _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) | 
 |  | 
 |         # The annotated field accepts an int. | 
 |         def test_check_named_tuple_value(): | 
 |             named_tuple = _AnnotatedNamedTuple(1) | 
 |             return named_tuple.value | 
 |  | 
 |         self.checkScript(test_check_named_tuple_value, ()) | 
 |  | 
 |         def test_error(): | 
 |             return _UnannotatedNamedTuple(1) | 
 |  | 
 |         # Per the expected message, the unannotated field is inferred as | 
 |         # Tensor, so constructing it with an int fails at script time. | 
 |         with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " | 
 |                                                   r"for argument \'value\' but instead found type \'int\'."): | 
 |             torch.jit.script(test_error) | 
 |  | 
 |     def test_namedtuple_default_values_simple_type(self): | 
 |         # NamedTuple whose fields all have simple default values, passed | 
 |         # through a scripted module's forward. | 
 |  | 
 |         class Point(NamedTuple): | 
 |             x: Optional[int] = None | 
 |             y: int = 2 | 
 |  | 
 |         make_global(Point) | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             def forward(self, point: Point): | 
 |                 return point | 
 |  | 
 |         p = Point(x=3, y=2) | 
 |  | 
 |         # Once with explicit values, once relying on the defaults only. | 
 |         self.checkModule(M(), (p,)) | 
 |         self.checkModule(M(), (Point(),)) | 
 |  | 
 |         m = torch.jit.script(M()) | 
 |  | 
 |         # The defaults are expected to appear in the printed graph type. | 
 |         FileCheck().check(r"NamedTuple(x : int? = None, y : int = 2))")   \ | 
 |                    .run(m.graph) | 
 |  | 
 |     def test_namedtuple_default_values_missing(self): | 
 |         # NamedTuple where only the last field (`z`) has a default value. | 
 |  | 
 |         class Point(NamedTuple): | 
 |             x: Optional[int] | 
 |             y: int | 
 |             z: int = 3 | 
 |  | 
 |         make_global(Point) | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             def forward(self, point: Point): | 
 |                 return point | 
 |  | 
 |         # p1 relies on the default for `z`; p2 overrides it. | 
 |         p1 = Point(x=3, y=2) | 
 |         p2 = Point(x=3, y=2, z=1) | 
 |  | 
 |         self.checkModule(M(), (p1,)) | 
 |         self.checkModule(M(), (p2,)) | 
 |  | 
 |         m = torch.jit.script(M()) | 
 |  | 
 |         # Only `z` is expected to carry a default in the printed graph type. | 
 |         FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))")   \ | 
 |                    .run(m.graph) | 
 |  | 
 |     def test_namedtuple_default_values_container_type(self): | 
 |         # NamedTuple whose defaults are containers (a list and a dict) or None. | 
 |  | 
 |         class Point(NamedTuple): | 
 |             x: Optional[List[int]] = None | 
 |             y: List[int] = [1, 2, 3] | 
 |             z: Optional[Dict[str, int]] = {"a": 1} | 
 |  | 
 |         make_global(Point) | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             def forward(self, point: Point): | 
 |                 return point | 
 |  | 
 |         p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2}) | 
 |  | 
 |         # Once with explicit values, once relying on the defaults only. | 
 |         self.checkModule(M(), (p,)) | 
 |         self.checkModule(M(), (Point(),)) | 
 |  | 
 |         m = torch.jit.script(M()) | 
 |  | 
 |         # Expected printed form of the tuple type, including its defaults. | 
 |         first_line = r"NamedTuple(x : int[]? = None, y : int[] = "    \ | 
 |                      r"[1, 2, 3], z : Dict(str, int)? = {a: 1}))" | 
 |  | 
 |         FileCheck().check(first_line)   \ | 
 |                    .run(m.graph) | 
 |  | 
 |     def test_namedtuple_default_values_Tensor_type(self): | 
 |         # A NamedTuple field with a Tensor default is expected to be rejected. | 
 |  | 
 |         class Point(NamedTuple): | 
 |             x: torch.Tensor = torch.rand(2, 3) | 
 |  | 
 |         make_global(Point) | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |  | 
 |             def forward(self, point: Point): | 
 |                 return point | 
 |  | 
 |         p = Point(x=torch.rand(2, 3)) | 
 |  | 
 |         # The error may come from scripting or from the call; both are | 
 |         # inside the assertion context. | 
 |         with self.assertRaisesRegex(RuntimeError, "Tensors are not " | 
 |                                     "supported as default NamedTuple " | 
 |                                     "fields"): | 
 |             m = torch.jit.script(M()) | 
 |             m(p) | 
 |  | 
 |     @unittest.skipIf(sys.version_info < (3, 7, 0), "defaults keyword added in Python 3.8") | 
 |     def test_namedtuple_default_values_using_factory_constructor(self): | 
 |         Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2)) | 
 |  | 
 |         make_global(Pair) | 
 |  | 
 |         @torch.jit.script | 
 |         def fn(x: Pair) -> Pair: | 
 |             return x | 
 |  | 
 |         # TODO: We can't use `checkScript` with the NamedTuple factory | 
 |         # constructor. Using the factory constructor with TorchScript | 
 |         # TorchScript creates an anonymous `NamedTuple` class instead of | 
 |         # preserving the actual name. For example, the actual generated | 
 |         # signature in this case is: | 
 |         #   graph(%x.1 : NamedTuple(x : Tensor, y : Tensor)) | 
 |         # It looks like similar test cases have had this issue as well | 
 |         # (see: `test_namedtuple_python`). | 
 |         FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))")   \ | 
 |                    .check_next(r"return (%x.1)")    \ | 
 |                    .run(fn.graph) | 
 |  | 
 |     def test_isinstance_dynamic(self): | 
 |         # isinstance() on an Optional[List[int]] argument, using nested | 
 |         # tuples of types and a subscripted List[int] as the second argument. | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             # type: (Optional[List[int]]) -> int | 
 |             b = 0 | 
 |             if isinstance(a, (int, (float,), list, str)): | 
 |                 b += 1 | 
 |             if isinstance(a, (int, str)): | 
 |                 b += 1 | 
 |             if isinstance(a, List[int]): | 
 |                 b += 1 | 
 |             return b | 
 |         # Two of the three checks are expected to match for a list argument, | 
 |         # and none of them for None. | 
 |         self.assertEqual(foo([3, 4]), 2) | 
 |         self.assertEqual(foo(None), 0) | 
 |  | 
 |     def test_function_overloads(self): | 
 |         # Exercises `torch.jit._overload` on free functions: declaring | 
 |         # overloads, caching of compiled overloads, default arguments, and | 
 |         # the errors raised for mismatched or unusable declarations. | 
 |  | 
 |         # TODO: pyflakes currently does not compose @overload annotation with other | 
 |         # decorators. This is fixed on master but not on version 2.1.1. | 
 |         # Next version update remove noqa and add @typing.overload annotation | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def test_simple(x1):  # noqa: F811 | 
 |             # type: (int) -> int | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def test_simple(x1):  # noqa: F811 | 
 |             # type: (float) -> float | 
 |             pass | 
 |  | 
 |         def test_simple(x1):  # noqa: F811 | 
 |             return x1 | 
 |  | 
 |         def invoke_function(): | 
 |             return test_simple(1.0), test_simple(.5) | 
 |  | 
 |         self.checkScript(invoke_function, ()) | 
 |  | 
 |         # testing that the functions are cached | 
 |         compiled_fns_1 = torch.jit._script._get_overloads(test_simple) | 
 |         compiled_fns_2 = torch.jit._script._get_overloads(test_simple) | 
 |         for a, b in zip(compiled_fns_1, compiled_fns_2): | 
 |             self.assertIs(a.graph, b.graph) | 
 |  | 
 |         # Keep a handle on the current implementation before the name is | 
 |         # rebound by the declarations below. | 
 |         old_func = test_simple | 
 |  | 
 |         # testing that new functions added work with caching | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def test_simple(x1):  # noqa: F811 | 
 |             # type: (str) -> str | 
 |             pass | 
 |  | 
 |         @torch.jit.script | 
 |         def my_func(): | 
 |             return old_func("hi") | 
 |  | 
 |         # testing new function same qualified name | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def test_simple(a, b):  # noqa: F811 | 
 |             # type: (int, int) -> int | 
 |             pass | 
 |  | 
 |         def test_simple(a, b): | 
 |             return a + b | 
 |  | 
 |         @torch.jit.script | 
 |         def fn(): | 
 |             return test_simple(3, 4) | 
 |  | 
 |         self.assertEqual(fn(), 7) | 
 |  | 
 |         # currently we take the default values have to be specified in the | 
 |         # overload as well - TODO take them from implementation and apply | 
 |         # where the type is valid. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def identity(x1):  # noqa: F811 | 
 |             # type: (str) -> str | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def identity(x1):  # noqa: F811 | 
 |             # type: (float) -> float | 
 |             pass | 
 |  | 
 |         def identity(x1=1.0):  # noqa: F811 | 
 |             return x1 | 
 |  | 
 |         def invoke(): | 
 |             return identity(), identity(.5), identity("hi") | 
 |  | 
 |         self.checkScript(invoke, ()) | 
 |  | 
 |         # A tuple argument matches neither overload; the error text must | 
 |         # mention both candidate types. | 
 |         def schema_match_failure(): | 
 |             return identity((1, 2)) | 
 |  | 
 |         thrown = False | 
 |         try: | 
 |             torch.jit.script(schema_match_failure) | 
 |         except Exception as e: | 
 |             thrown = True | 
 |             self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e)) | 
 |         self.assertTrue(thrown) | 
 |  | 
 |         # The implementation of an overloaded function cannot itself be scripted. | 
 |         with self.assertRaisesRegex(Exception, "cannot be directly compiled"): | 
 |             torch.jit.script(identity) | 
 |  | 
 |         # The shared implementation uses `x - y`; the call below uses the | 
 |         # (str, str) overload and is expected to fail to compile. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def impl_compile_failure(x, y):  # noqa: F811 | 
 |             # type: (str, str) -> (str) | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def impl_compile_failure(x, y):  # noqa: F811 | 
 |             # type: (int, int) -> (int) | 
 |             pass | 
 |  | 
 |         def impl_compile_failure(x, y):  # noqa: F811 | 
 |             return x - y | 
 |  | 
 |         def test(): | 
 |             impl_compile_failure("one", "two") | 
 |  | 
 |  | 
 |         with self.assertRaisesRegex(Exception, "Arguments for call are not valid"): | 
 |             torch.jit.script(test) | 
 |  | 
 |         # Overload and implementation agree on the default value. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def good_overload(x=1):  # noqa: F811 | 
 |             # type: (int) -> (int) | 
 |             pass | 
 |  | 
 |         def good_overload(x=1):  # noqa: F811 | 
 |             return x | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             return good_overload() | 
 |  | 
 |         self.assertEqual(foo(), 1) | 
 |  | 
 |  | 
 |         # Overload default (y=2) differs from the implementation default (y=1). | 
 |         with self.assertRaisesRegex(Exception, "must equal to the default parameter"): | 
 |             @torch.jit._overload  # noqa: F811 | 
 |             def bad_default_on_overload(x, y=2):  # noqa: F811 | 
 |                 # type: (int, int) -> (int) | 
 |                 pass | 
 |  | 
 |             def bad_default_on_overload(x, y=1):  # noqa: F811 | 
 |                 # type: (int, int) -> (int) | 
 |                 pass | 
 |  | 
 |             @torch.jit.script | 
 |             def test(): | 
 |                 return bad_default_on_overload(1, 2) | 
 |  | 
 |         # The implementation default is a str while one overload takes int. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def diff_default(x):  # noqa: F811 | 
 |             # type: (int) -> int | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def diff_default(x):  # noqa: F811 | 
 |             # type: (str) -> str | 
 |             pass | 
 |  | 
 |         def diff_default(x="hi"):  # noqa: F811 | 
 |             return x | 
 |  | 
 |         def test(): | 
 |             return diff_default(), diff_default(2), diff_default("abc") | 
 |  | 
 |         self.assertEqual(test(), torch.jit.script(test)()) | 
 |  | 
 |         # Overloads declare fewer parameters than the implementation, which | 
 |         # supplies defaults for the rest. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def diff_num_params(x):  # noqa: F811 | 
 |             # type: (float) -> float | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def diff_num_params(x, y):  # noqa: F811 | 
 |             # type: (int, int) -> int | 
 |             pass | 
 |  | 
 |         def diff_num_params(x, y=2, z=3):  # noqa: F811 | 
 |             # type: (Union[float, int], int, int) | 
 |             return x + y + z | 
 |  | 
 |         def test(): | 
 |             return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3) | 
 |  | 
 |         self.assertEqual(test(), torch.jit.script(test)()) | 
 |  | 
 |         # The overload declares no parameters while the implementation has an | 
 |         # unannotated one; scripting a call with an argument is expected to fail. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def diff_num_params_no_annot(): | 
 |             # type: () -> int | 
 |             pass | 
 |  | 
 |         def diff_num_params_no_annot(x=1):    # noqa: F811 | 
 |             return x | 
 |  | 
 |         def test(): | 
 |             return diff_num_params_no_annot(1.0) | 
 |  | 
 |         with self.assertRaisesRegex(Exception, "Parameters not specified"): | 
 |             torch.jit.script(test) | 
 |  | 
 |     def test_function_overload_misuse(self): | 
 |         # Error cases for the overload decorators: declarations with a real | 
 |         # body, and overloads that never get an implementation. | 
 |  | 
 |         # An overload declaration on a free function may not have a real body. | 
 |         with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): | 
 |             @torch.jit._overload | 
 |             def wrong_decl_body(x: str) -> str: | 
 |                 return x + "0" | 
 |  | 
 |         # The same restriction applies to method overloads. | 
 |         with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): | 
 |             class MyClass: | 
 |                 @torch.jit._overload_method | 
 |                 def method(self): | 
 |                     return 0 | 
 |  | 
 |         # Both accepted body spellings (`...` and `pass`), but no | 
 |         # implementation follows the declarations. | 
 |         @torch.jit._overload | 
 |         def null_overload(x: int) -> int: ...  # noqa: E704 | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def null_overload(x: str) -> str:  # noqa: F811 | 
 |             pass | 
 |  | 
 |         def null_overload_driver(): | 
 |             return null_overload(0) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'): | 
 |             torch.jit.script(null_overload_driver) | 
 |  | 
 |         # Module whose `forward` has only overload declarations. | 
 |         class OverloadMisuse(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             @torch.jit._overload_method | 
 |             def forward(self, x: int): | 
 |                 pass | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def forward(self, x: Tensor):  # noqa: F811 | 
 |                 pass | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'): | 
 |             m = torch.jit.script(OverloadMisuse()) | 
 |  | 
 |  | 
 |     def test_script_method_torch_function_overload(self): | 
 |         class MyCustomTensor(torch.Tensor): | 
 |             pass | 
 |  | 
 |         class MyCustomModule(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.relu(x) | 
 |  | 
 |         scripted_mod = torch.jit.script(MyCustomModule()) | 
 |         t = torch.tensor([3.0]) | 
 |         ref_out = scripted_mod(t) | 
 |  | 
 |         t_custom = MyCustomTensor([3.0]) | 
 |         out1 = scripted_mod(t_custom) | 
 |         self.assertEqual(out1, ref_out) | 
 |  | 
 |         out2 = scripted_mod.forward(t_custom) | 
 |         self.assertEqual(out2, ref_out) | 
 |  | 
 |     def test_function_overloading_isinstance(self): | 
 |         # Two overload declarations of my_conv differ only in the declared kind | 
 |         # of `y` (str vs. float). The single implementation below branches on | 
 |         # isinstance(y, str) and gives `y` a default of 2.0. | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def my_conv(x, y):  # noqa: F811 | 
 |             # type: (float, str) -> (float) | 
 |             pass | 
 |  | 
 |         @torch.jit._overload  # noqa: F811 | 
 |         def my_conv(x, y):  # noqa: F811 | 
 |             # type: (float, float) -> (float) | 
 |             pass | 
 |  | 
 |         def my_conv(x, y=2.0):  # noqa: F811 | 
 |             if isinstance(y, str): | 
 |                 if y == "hi": | 
 |                     return 4.0 - x | 
 |                 else: | 
 |                     return 5.0 - x | 
 |             else: | 
 |                 return 2.0 + x | 
 |  | 
 |         # Calls my_conv with the default y, with a str y and with a float y. | 
 |         def test_uses(): | 
 |             return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0) | 
 |  | 
 |         # test_uses takes no arguments, hence the empty inputs tuple. | 
 |         self.checkScript(test_uses, ()) | 
 |  | 
 |     def test_method_overloading(self): | 
 |         # Covers torch.jit._overload_method on nn.Module methods: a working | 
 |         # overloaded forward, several expected scripting failures, and modules | 
 |         # redeclared under the same class name with and without overloads. | 
 |         class Over(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Over, self).__init__() | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def forward(self, x):  # noqa: F811 | 
 |                 # type: (Tuple[Tensor, Tensor]) -> Tensor | 
 |                 pass | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def forward(self, x):  # noqa: F811 | 
 |                 # type: (Tensor) -> Tensor | 
 |                 pass | 
 |  | 
 |             def forward(self, x):  # noqa: F811 | 
 |                 if isinstance(x, Tensor): | 
 |                     return x + 20 | 
 |                 else: | 
 |                     return x[0] + 5 | 
 |  | 
 |         # S calls the overloaded submodule with a tensor and with a pair of | 
 |         # tensors inside one script method. | 
 |         class S(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(S, self).__init__() | 
 |                 self.weak = Over() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.weak(x) + self.weak((x, x)) | 
 |  | 
 |         s_mod = S() | 
 |         x = torch.ones(1) | 
 |         # (x + 20) from the tensor call plus (x + 5) from the tuple call. | 
 |         self.assertEqual(s_mod(x), x + 20 + 5 + x) | 
 |  | 
 |         # The same module called eagerly, without scripting. | 
 |         over = Over() | 
 |         self.assertEqual(over((x, x)), x + 5) | 
 |         self.assertEqual(over((x)), x + 20) | 
 |  | 
 |         # The first overload declaration of hello carries no annotation at all. | 
 |         class Unannotated(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Unannotated, self).__init__() | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 pass | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 # type: (int) -> (int) | 
 |                 pass | 
 |  | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 return x + 3 | 
 |  | 
 |             def forward(self): | 
 |                 return self.hello(1), self.hello(.5) | 
 |  | 
 |         w = Unannotated() | 
 |         with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"): | 
 |             torch.jit.script(w) | 
 |  | 
 |         # Overloads are declared for str and int; forward calls hello with a | 
 |         # str and with a float, and scripting is expected to fail. | 
 |         class CompileOverloadError(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(CompileOverloadError, self).__init__() | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 # type: (str) -> (int) | 
 |                 pass | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 # type: (int) -> (int) | 
 |                 pass | 
 |  | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 return x + 1 | 
 |  | 
 |             def forward(self): | 
 |                 return self.hello("hi"), self.hello(.5) | 
 |  | 
 |         w = CompileOverloadError() | 
 |         with self.assertRaisesRegex(Exception, "but instead found type \'str\'"): | 
 |             torch.jit.script(w) | 
 |  | 
 |         # testing overload declared first, then non-overload | 
 |         # NOTE(review): the expected error may come from any statement in this | 
 |         # block; statements after the raising one never execute. | 
 |         with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): | 
 |             class W3(torch.nn.Module): | 
 |                 def __init__(self): | 
 |                     super(W3, self).__init__() | 
 |  | 
 |                 @torch.jit._overload_method  # noqa: F811 | 
 |                 def forward(self, x):  # noqa: F811 | 
 |                     # type: (int) -> int | 
 |                     pass | 
 |  | 
 |                 @torch.jit._overload_method  # noqa: F811 | 
 |                 def forward(self, x):  # noqa: F811 | 
 |                     # type: (Tensor) -> Tensor | 
 |                     pass | 
 |  | 
 |                 def forward(self, x):  # noqa: F811 | 
 |                     return x + 5 | 
 |  | 
 |             a = W3() | 
 |             b = torch.jit.script(a) | 
 |  | 
 |             # Same class name again, this time without any overloads. | 
 |             class W3(torch.nn.Module): | 
 |                 def __init__(self): | 
 |                     super(W3, self).__init__() | 
 |  | 
 |                 def forward(self, x):  # noqa: F811 | 
 |                     return x + 5 + 10 | 
 |  | 
 |             a = W3() | 
 |             b = torch.jit.script(a) | 
 |  | 
 |         # testing non-overload declared first, then overload | 
 |         class W2(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(W2, self).__init__() | 
 |  | 
 |             def hello(self, x1, x2): | 
 |                 return x1 + x2 | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.hello(x, x) | 
 |  | 
 |         # The overload-free W2 scripts and runs normally. | 
 |         a = torch.jit.script(W2()) | 
 |         self.assertEqual(a(torch.tensor(1)), torch.tensor(2)) | 
 |  | 
 |         # W2 is redeclared under the same name, now with overloaded hello; | 
 |         # scripting this version is expected to fail. | 
 |         class W2(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(W2, self).__init__() | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 pass | 
 |  | 
 |             @torch.jit._overload_method  # noqa: F811 | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 # type: (int) -> (int) | 
 |                 pass | 
 |  | 
 |             def hello(self, x):  # noqa: F811 | 
 |                 return x + 5 + 10 | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.hello(1), self.hello(x) | 
 |  | 
 |         with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): | 
 |             a = torch.jit.script(W2()) | 
 |  | 
 |     def test_narrow_copy(self): | 
 |         def foo(a): | 
 |             return a.narrow_copy(0, 0, 5) | 
 |  | 
 |         self.checkScript(foo, [torch.rand(10)]) | 
 |  | 
 |     def test_select_after_chunk(self): | 
 |         def foo(x): | 
 |             chunked = torch.chunk(x, 1) | 
 |             foo = chunked[0] | 
 |             foo.add_(5) | 
 |             return x | 
 |  | 
 |         self.checkScript(foo, [torch.rand(2, 3)]) | 
 |  | 
 |     def test_nn_LSTM_with_layers(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.rnn = nn.LSTM(2, 3, 2, dropout=0) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x, lengths, h0, c0): | 
 |                 return self.rnn(x, (h0, c0))[0] | 
 |  | 
 |         class Eager(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Eager, self).__init__() | 
 |                 self.rnn = nn.LSTM(2, 3, 2, dropout=0) | 
 |  | 
 |             def forward(self, x, lengths, h0, c0): | 
 |                 return self.rnn(x, (h0, c0))[0] | 
 |  | 
 |         inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3)) | 
 |         eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0] | 
 |         script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0] | 
 |  | 
 |         self.assertEqual(eager_out, script_out) | 
 |  | 
 |     def test_nn_LSTM(self): | 
 |         input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) | 
 |  | 
 |         class S(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(S, self).__init__() | 
 |                 self.x = torch.nn.LSTM(5, 5) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: | 
 |                 return self.x(input) | 
 |  | 
 |         eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0] | 
 |         script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0] | 
 |  | 
 |         self.assertEqual(eager_out, script_out) | 
 |  | 
 |     def test_nn_GRU(self): | 
 |         seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) | 
 |         tensor_input = torch.randn(5, 5, 5) | 
 |  | 
 |         class SeqLengthGRU(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(SeqLengthGRU, self).__init__() | 
 |                 self.x = torch.nn.GRU(5, 5) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: | 
 |                 return self.x(input) | 
 |  | 
 |         class TensorGRU(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(TensorGRU, self).__init__() | 
 |                 self.x = torch.nn.GRU(5, 5) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | 
 |                 return self.x(input) | 
 |  | 
 |         seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0] | 
 |         seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0] | 
 |         tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0] | 
 |         tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0] | 
 |  | 
 |         self.assertEqual(seq_eager_out, seq_script_out) | 
 |         self.assertEqual(tensor_eager_out, tensor_script_out) | 
 |  | 
 |     def test_torchscript_memoryformat(self): | 
 |         @torch.jit.script | 
 |         def fn(x): | 
 |             return x.contiguous(memory_format=torch.channels_last) | 
 |         x = torch.randn(4, 3, 6, 6) | 
 |         y = fn(x) | 
 |         self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) | 
 |  | 
 |     def test_torchscript_multi_head_attn(self): | 
 |         # Wraps torch.nn.functional.multi_head_attention_forward in a scripted | 
 |         # function (annotated with per-argument comments) and checks that it | 
 |         # matches a direct eager call on the same inputs and weights. | 
 |         @torch.jit.script | 
 |         def jit_multihead_attn_forward(query,                   # type: Tensor | 
 |                                        key,                     # type: Tensor | 
 |                                        value,                   # type: Tensor | 
 |                                        embed_dim_to_check,      # type: int | 
 |                                        num_heads,               # type: int | 
 |                                        in_proj_weight,          # type: Tensor | 
 |                                        in_proj_bias,            # type: Tensor | 
 |                                        bias_k,                  # type: Optional[Tensor] | 
 |                                        bias_v,                  # type: Optional[Tensor] | 
 |                                        add_zero_attn,           # type: bool | 
 |                                        dropout,                 # type: float | 
 |                                        out_proj_weight,         # type: Tensor | 
 |                                        out_proj_bias,           # type: Tensor | 
 |                                        training=True,           # type: bool | 
 |                                        key_padding_mask=None,   # type: Optional[Tensor] | 
 |                                        need_weights=True,       # type: bool | 
 |                                        attn_mask=None           # type: Optional[Tensor] | 
 |                                        ): | 
 |             # type: (...) -> Tuple[Tensor, Optional[Tensor]] | 
 |             return torch.nn.functional.multi_head_attention_forward(query, key, value, | 
 |                                                                     embed_dim_to_check, num_heads, | 
 |                                                                     in_proj_weight, in_proj_bias, | 
 |                                                                     bias_k, bias_v, | 
 |                                                                     add_zero_attn, dropout, | 
 |                                                                     out_proj_weight, out_proj_bias, | 
 |                                                                     training, key_padding_mask, | 
 |                                                                     need_weights, attn_mask) | 
 |  | 
 |         src_l = 3 | 
 |         bsz = 5 | 
 |         embed_size = 8 | 
 |         nhead = 2 | 
 |         # The module is only used as a source of weights and settings; both | 
 |         # calls below go through the functional API. | 
 |         multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead) | 
 |         query = torch.rand((src_l, bsz, embed_size)) | 
 |         key = torch.rand((src_l, bsz, embed_size)) | 
 |         value = torch.rand((src_l, bsz, embed_size)) | 
 |  | 
 |         # Additive float mask built from a transposed upper-triangular boolean | 
 |         # matrix: 0.0 where the boolean is True, -inf elsewhere, cast to double. | 
 |         mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1) | 
 |         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).double() | 
 |  | 
 |         jit_out = jit_multihead_attn_forward(query, key, value, | 
 |                                              embed_size, nhead, | 
 |                                              multi_head_attn.in_proj_weight, | 
 |                                              multi_head_attn.in_proj_bias, | 
 |                                              multi_head_attn.bias_k, multi_head_attn.bias_v, | 
 |                                              multi_head_attn.add_zero_attn, multi_head_attn.dropout, | 
 |                                              multi_head_attn.out_proj.weight, | 
 |                                              multi_head_attn.out_proj.bias, attn_mask=mask)[0] | 
 |  | 
 |         py_out = torch.nn.functional.multi_head_attention_forward(query, key, value, | 
 |                                                                   embed_size, nhead, | 
 |                                                                   multi_head_attn.in_proj_weight, | 
 |                                                                   multi_head_attn.in_proj_bias, | 
 |                                                                   multi_head_attn.bias_k, | 
 |                                                                   multi_head_attn.bias_v, | 
 |                                                                   multi_head_attn.add_zero_attn, | 
 |                                                                   multi_head_attn.dropout, | 
 |                                                                   multi_head_attn.out_proj.weight, | 
 |                                                                   multi_head_attn.out_proj.bias, | 
 |                                                                   attn_mask=mask)[0] | 
 |         # Only element [0] of each result (the attention output) is compared, | 
 |         # with explicit absolute and relative tolerances. | 
 |         self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) | 
 |  | 
 |     def test_torchscript_multi_head_attn_fast_path(self): | 
 |         src_l = 3 | 
 |         bsz = 5 | 
 |         embed_size = 8 | 
 |         nhead = 2 | 
 |         multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) | 
 |         multi_head_attn = multi_head_attn.eval() | 
 |  | 
 |         query = key = value = torch.rand((bsz, src_l, embed_size)) | 
 |  | 
 |         with torch.no_grad(): | 
 |             py_out = multi_head_attn(query, key, value) | 
 |             mha = torch.jit.script(multi_head_attn) | 
 |             jit_out = mha(query, key, value) | 
 |         torch.testing.assert_close(jit_out, py_out) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "no CUDA") | 
 |     def test_scriptmodule_multi_head_attn_cuda(self): | 
 |  | 
 |         class MyModule(torch.jit.ScriptModule): | 
 |             def __init__(self, embed_dim, num_heads): | 
 |                 super(MyModule, self).__init__() | 
 |                 sample_q = torch.randn(3, 2, embed_dim) | 
 |                 sample_kv = torch.randn(3, 2, embed_dim) | 
 |                 attention = nn.MultiheadAttention(embed_dim, num_heads) | 
 |                 attention.eval() | 
 |  | 
 |                 self.mod = torch.jit.trace(attention, | 
 |                                            (sample_q, sample_kv, sample_kv)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, q, k, v): | 
 |                 return self.mod(q, k, v) | 
 |  | 
 |         embed_dim = 8 | 
 |         num_heads = 2 | 
 |         sl = 3 | 
 |         bs = 2 | 
 |         model = MyModule(embed_dim, num_heads).cuda() | 
 |         q = torch.randn(sl, bs, embed_dim, device="cuda") | 
 |         kv = torch.randn(sl, bs, embed_dim, device="cuda") | 
 |  | 
 |         jit_out = model(q, kv, kv)[0] | 
 |         py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv, | 
 |                                                                   embed_dim, num_heads, | 
 |                                                                   model.mod.in_proj_weight, | 
 |                                                                   model.mod.in_proj_bias, | 
 |                                                                   None, None, None, 0.0, | 
 |                                                                   model.mod.out_proj.weight, | 
 |                                                                   model.mod.out_proj.bias)[0] | 
 |         self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "no CUDA") | 
 |     def test_scriptmodule_transformer_cuda(self): | 
 |  | 
 |         class MyModule(torch.jit.ScriptModule): | 
 |             def __init__(self, transformer, sample_q, sample_kv): | 
 |                 super(MyModule, self).__init__() | 
 |                 transformer.eval() | 
 |  | 
 |                 self.mod = torch.jit.trace(transformer, | 
 |                                            (sample_q, sample_kv)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, q, k): | 
 |                 return self.mod(q, k) | 
 |  | 
 |         d_model = 8 | 
 |         nhead = 2 | 
 |         num_encoder_layers = 2 | 
 |         num_decoder_layers = 2 | 
 |         dim_feedforward = 16 | 
 |         bsz = 2 | 
 |         seq_length = 5 | 
 |         tgt_length = 3 | 
 |  | 
 |         src = torch.randn(seq_length, bsz, d_model) | 
 |         tgt = torch.randn(tgt_length, bsz, d_model) | 
 |         transformer = nn.Transformer(d_model, nhead, num_encoder_layers, | 
 |                                      num_decoder_layers, dim_feedforward, dropout=0.0) | 
 |         model = MyModule(transformer, tgt, src) | 
 |  | 
 |         src = torch.randn(seq_length, bsz, d_model) | 
 |         tgt = torch.randn(tgt_length, bsz, d_model) | 
 |         jit_out = model(tgt, src) | 
 |         py_out = transformer(tgt, src) | 
 |  | 
 |         # print(jit_out/py_out-1) | 
 |         # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) | 
 |         self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) | 
 |  | 
 |     def test_list_python_op(self): | 
 |         def python_list_op(lst): | 
 |             # type: (List[Tensor]) -> Tensor | 
 |             return lst[0] | 
 |  | 
 |         def fn(lst): | 
 |             # type: (List[Tensor]) -> Tensor | 
 |             return python_list_op(lst) | 
 |  | 
 |         self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],)) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "no CUDA") | 
 |     def test_weak_cuda(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.lstm = torch.nn.LSTM(5, 5) | 
 |                 self.lstm.cuda() | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 return self.lstm(x) | 
 |  | 
 |         m = M() | 
 |         m.cuda() | 
 |         out = m(torch.ones(5, 5, 5).cuda()) | 
 |         self.assertTrue(out[0].is_cuda) | 
 |  | 
 |     def test_ignore_decorator(self): | 
 |         with warnings.catch_warnings(record=True) as warns: | 
 |             class M(torch.jit.ScriptModule): | 
 |                 def __init__(self): | 
 |                     super(M, self).__init__() | 
 |                     tensor = torch.zeros(1, requires_grad=False) | 
 |                     self.register_buffer('some_state', torch.nn.Parameter(tensor)) | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     self.ignored_code(x) | 
 |                     return x | 
 |  | 
 |                 @torch.jit.ignore(drop_on_export=True) | 
 |                 def ignored_code(self, x): | 
 |                     self.some_state = torch.tensor((100,)) | 
 |  | 
 |         FileCheck().check("TorchScript will now drop the function").run(str(warns[0])) | 
 |  | 
 |         # Assert ignored code is run | 
 |         m = M() | 
 |  | 
 |         m2 = self.getExportImportCopy(m) | 
 |         pp = str(m2.forward.code) | 
 |         self.assertNotIn('ignored_code', pp) | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): | 
 |             m2.forward(torch.ones(1)) | 
 |  | 
 |     def test_ignored_as_value(self): | 
 |         class Model(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Model, self).__init__() | 
 |  | 
 |             @torch.jit.unused | 
 |             def tuple_ignored(self, x): | 
 |                 # type: (Tensor) -> Tuple[Tensor, Tensor] | 
 |                 return x, x | 
 |  | 
 |             @torch.jit.unused | 
 |             def single_val_ignored(self, x, y): | 
 |                 # type: (Tensor, Tensor) -> Tensor | 
 |                 return x | 
 |  | 
 |             def forward(self, x, use_ignore_path): | 
 |                 # type: (Tensor, bool) -> Tuple[Tensor, Tensor] | 
 |                 if 1 == 2: | 
 |                     return self.tuple_ignored(x) | 
 |                 if use_ignore_path: | 
 |                     return self.single_val_ignored(x, x), self.single_val_ignored(x, x) | 
 |                 return x, x | 
 |  | 
 |         original = Model() | 
 |         scripted = torch.jit.script(original) | 
 |         self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5))) | 
 |  | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(scripted, buffer) | 
 |         buffer.seek(0) | 
 |         loaded = torch.jit.load(buffer) | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): | 
 |             loaded(torch.tensor(.5), True) | 
 |  | 
 |     def test_module_error(self): | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |  | 
 |             def forward(self, foo): | 
 |                 return foo | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"): | 
 |             torch.jit.script(MyModule) | 
 |  | 
 |     def test_view_write(self): | 
 |         # fn puts x into a list, reads it back as x_view, and adds y to x_view | 
 |         # in place between two evaluations of x + x. It returns the elementwise | 
 |         # comparison of the two sums; checkScript compares Python and scripted | 
 |         # results. The statement order inside fn is what is being exercised. | 
 |         def fn(x, y): | 
 |             l = [] | 
 |             l.append(x) | 
 |             x_view = l[0] | 
 |             a = x + x | 
 |             x_view.add_(y) | 
 |             b = x + x | 
 |             return a == b | 
 |         self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) | 
 |  | 
 |     def test_module_attrs(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self, table): | 
 |                 super(M, self).__init__() | 
 |                 self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor]) | 
 |                 self.x = torch.nn.Parameter(torch.tensor([100.0])) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, key): | 
 |                 # type: (str) -> Tensor | 
 |                 return self.table[key] + self.x | 
 |  | 
 |         with torch._jit_internal._disable_emit_hooks(): | 
 |             # TODO: re-enable module hook when Python printing of attributes is | 
 |             # supported | 
 |             m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) | 
 |             self.assertEqual(m("c"), torch.tensor([103.])) | 
 |  | 
 |     def test_module_none_attrs(self): | 
 |         class MyMod(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(MyMod, self).__init__() | 
 |                 self.optional_value = None | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return self.optional_value | 
 |  | 
 |         graph = MyMod().forward.graph | 
 |         FileCheck().check("prim::GetAttr").run(graph) | 
 |         self.run_pass('peephole', graph) | 
 |         FileCheck().check_not("prim::GetAttr").run(graph) | 
 |  | 
 |     def test_tensor_import_export(self): | 
 |         # Scripts a function that builds two tensors from literals and returns | 
 |         # them in a list, runs the 'constant_propagation' pass on its graph, | 
 |         # wraps the graph in a function and sends it through an export/import | 
 |         # round trip. There is no assertion on the result, so the test passes | 
 |         # as long as none of these steps raises. | 
 |         @torch.jit.script | 
 |         def foo(x): | 
 |             a = torch.tensor(1) | 
 |             b = torch.tensor([1, 2]) | 
 |             c = [a, b] | 
 |             return c | 
 |  | 
 |         self.run_pass('constant_propagation', foo.graph) | 
 |         m = self.createFunctionFromGraph(foo.graph) | 
 |         self.getExportImportCopy(m) | 
 |  | 
 |     def get_pickle_values(self): | 
 |         return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]), | 
 |                 ('float', 2.3, float), | 
 |                 ('int', 99, int), | 
 |                 ('bool', False, bool), | 
 |                 ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]), | 
 |                 ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]), | 
 |                 ('tensor', torch.randn(2, 2), torch.Tensor), | 
 |                 ('int_list', [1, 2, 3, 4], List[int]), | 
 |                 ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]), | 
 |                 ('bool_list', [True, True, False, True], List[bool]), | 
 |                 ('float_list', [1., 2., 3., 4.], List[float]), | 
 |                 ('str_list', ['hello', 'bye'], List[str]), | 
 |                 ('none', None, Optional[int]), | 
 |                 ('a_device', torch.device('cpu'), torch.device), | 
 |                 ('another_device', torch.device('cuda:1'), torch.device)) | 
 |  | 
 |     def test_attribute_serialization(self): | 
 |         tester = self | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 for name, value, the_type in tester.get_pickle_values(): | 
 |                     setattr(self, name, torch.jit.Attribute(value, the_type)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return (self.dict, self.float, self.int, self.bool, self.tuple, | 
 |                         self.list, self.int_list, self.tensor_list, self.bool_list, | 
 |                         self.float_list, self.str_list, self.none) | 
 |  | 
 |         m = M() | 
 |         imported_m = self.getExportImportCopy(m) | 
 |         self.assertEqual(m(), imported_m()) | 
 |  | 
 |     def test_string_len(self): | 
 |         def fn(x): | 
 |             # type: (str) -> int | 
 |             return len(x) | 
 |  | 
 |         self.checkScript(fn, ("",)) | 
 |         self.checkScript(fn, ("h",)) | 
 |         self.checkScript(fn, ("hello",)) | 
 |  | 
 |     def test_multiline_optional_future_refinement(self): | 
 |         # Scripts a function whose annotated local assignment has its annotation | 
 |         # spread over three source lines, then checks the returned value. | 
 |         # NOTE(review): the multi-line layout is presumably the point of the | 
 |         # test - do not reflow the annotation onto one line. | 
 |         @torch.jit.script | 
 |         def fun() -> int: | 
 |             future: Optional[ | 
 |                 torch.jit.Future[Tuple[torch.Tensor]] | 
 |             ] = None | 
 |  | 
 |             return 1 | 
 |         self.assertEqual(fun(), 1) | 
 |  | 
 |     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") | 
 |     def test_attribute_unpickling(self): | 
 |         tensor = torch.randn(2, 2) | 
 |         tester = self | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 for name, value, the_type in tester.get_pickle_values(): | 
 |                     setattr(self, "_" + name, torch.jit.Attribute(value, the_type)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return (self._dict, self._float, self._int, self._bool, self._tuple, | 
 |                         self._list, self._int_list, self._tensor_list, self._bool_list, | 
 |                         self._float_list, self._str_list, self._none) | 
 |  | 
 |         with TemporaryFileName() as fname: | 
 |             M().save(fname) | 
 |             loaded = torch.jit.load(fname) | 
 |  | 
 |             def is_tensor_value(item): | 
 |                 if isinstance(item, torch.Tensor): | 
 |                     return True | 
 |                 if isinstance(item, list): | 
 |                     return is_tensor_value(item[0]) | 
 |                 return False | 
 |             for name, value, the_type in self.get_pickle_values(): | 
 |                 if is_tensor_value(value): | 
 |                     continue | 
 |                 self.assertEqual(value, getattr(loaded, "_" + name)) | 
 |  | 
 |     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") | 
 |     @unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support") | 
 |     def test_old_models_bc(self): | 
 |         model = { | 
 |             'archive/version': b'1', | 
 |             'archive/code/archive.py': | 
 |                 b''' | 
 |                 op_version_set = 0 | 
 |                 def forward(self, | 
 |                     _0: Tensor) -> Tensor: | 
 |                   _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu")) | 
 |                   result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"), | 
 |                                     non_blocking=False, copy=False) | 
 |                   result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu")) | 
 |                   result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu")) | 
 |                   _2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1) | 
 |                   return _2 | 
 |                 ''', | 
 |             'archive/attributes.pkl': b'\x80\x02](e.', | 
 |             'archive/libs.py': b'op_version_set = 0\n', | 
 |             'archive/model.json': | 
 |                 b''' | 
 |                 { | 
 |                    "protoVersion":"2", | 
 |                    "mainModule":{ | 
 |                       "torchscriptArena":{ | 
 |                          "key":"code/archive.py" | 
 |                       }, | 
 |                       "name":"archive", | 
 |                       "optimize":true | 
 |                    }, | 
 |                    "producerName":"pytorch", | 
 |                    "producerVersion":"1.0", | 
 |                    "libs":{ | 
 |                       "torchscriptArena":{ | 
 |                          "key":"libs.py" | 
 |                       } | 
 |                    } | 
 |                 }'''} | 
 |         with TemporaryFileName() as fname: | 
 |             archive_name = os.path.basename(os.path.normpath(fname)) | 
 |             with zipfile.ZipFile(fname, 'w') as archive: | 
 |                 for k, v in model.items(): | 
 |                     archive.writestr(k, v) | 
 |  | 
 |             with open(fname, "rb") as f: | 
 |                 fn = torch.jit.load(f) | 
 |  | 
 |         x = torch.zeros(10) | 
 |         fn(x) | 
 |  | 
 |     def test_submodule_attribute_serialization(self): | 
 |         class S(torch.jit.ScriptModule): | 
 |             def __init__(self, list_data): | 
 |                 super(S, self).__init__() | 
 |                 self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str]) | 
 |                 self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return (self.table, self.list) | 
 |  | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str]) | 
 |                 self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor) | 
 |                 self.s1 = S([(1, 2)]) | 
 |                 self.s2 = S([(4, 5)]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self): | 
 |                 return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list) | 
 |  | 
 |         m = M() | 
 |         imported_m = self.getExportImportCopy(m) | 
 |         self.assertEqual(m(), imported_m()) | 
 |  | 
 |     def test_serialization_big_ints(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.int32_max = torch.jit.Attribute(2**31 - 1, int) | 
 |                 self.int32_min = torch.jit.Attribute(-2**31, int) | 
 |                 self.uint32_max = torch.jit.Attribute(2**32, int) | 
 |  | 
 |                 self.int64_max = torch.jit.Attribute(2**63 - 1, int) | 
 |                 self.int64_min = torch.jit.Attribute(-2**63, int) | 
 |  | 
 |                 self.tensor = torch.nn.Parameter(torch.ones(2, 2)) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, x): | 
 |                 # type: (int) -> (int) | 
 |                 return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min) | 
 |  | 
 |         m = M() | 
 |         imported = self.getExportImportCopy(m) | 
 |         self.assertEqual(m(10), imported(10)) | 
 |  | 
 |         self.assertEqual(m.int32_max, imported.int32_max) | 
 |         self.assertEqual(m.int32_min, imported.int32_min) | 
 |         self.assertEqual(m.uint32_max, imported.uint32_max) | 
 |         self.assertEqual(m.int64_max, imported.int64_max) | 
 |         self.assertEqual(m.int64_min, imported.int64_min) | 
 |  | 
 |     def test_script_scope(self): | 
 |         scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss) | 
 |  | 
 |     @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows") | 
 |     def test_serialization_sharing(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.list = torch.jit.Attribute([], List[str]) | 
 |  | 
 |             @torch.jit.script_method | 
 |             def forward(self, key): | 
 |                 # type: (str) -> List[str] | 
 |                 self.list.append(key) | 
 |                 self.list.append(key) | 
 |                 self.list.append(key) | 
 |                 return self.list | 
 |  | 
 |         # the text of the string should only appear once in the pickling | 
 |         m = M() | 
 |         s1 = "a long string" | 
 |         s2 = "a different, even longer string" | 
 |         self.assertEqual(m(s1), [s1] * 3) | 
 |         self.assertEqual(m(s2), [s1] * 3 + [s2] * 3) | 
 |         with TemporaryFileName() as fname: | 
 |             m.save(fname) | 
 |             archive_name = os.path.basename(os.path.normpath(fname)) | 
 |             archive = zipfile.ZipFile(fname, 'r') | 
 |             pickled_data = archive.read(os.path.join(archive_name, 'data.pkl')) | 
 |  | 
 |             out = io.StringIO() | 
 |             pickletools.dis(pickled_data, out=out) | 
 |             disassembled = out.getvalue() | 
 |  | 
 |             FileCheck().check_count(s1, 1, exactly=True) \ | 
 |                 .check_count("BINGET", 2, exactly=True) \ | 
 |                 .check_count(s2, 1, exactly=True) \ | 
 |                 .check_count("BINGET", 2, exactly=True).run(out.getvalue()) | 
 |  | 
 |     def test_sys_stdout_override(self): | 
 |         @torch.jit.script | 
 |         def foo(): | 
 |             print('foo') | 
 |  | 
 |         class Redirect(object): | 
 |             def __init__(self): | 
 |                 self.s = '' | 
 |  | 
 |             def write(self, s): | 
 |                 self.s += s | 
 |  | 
 |         old_stdout = sys.stdout | 
 |         redirect = Redirect() | 
 |         try: | 
 |             sys.stdout = redirect | 
 |             foo() | 
 |         finally: | 
 |             sys.stdout = old_stdout | 
 |  | 
 |         FileCheck().check('foo').run(redirect.s) | 
 |  | 
 |     def test_dtype_attr(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Foo, self).__init__() | 
 |                 self.dtype = torch.zeros([]).dtype | 
 |  | 
 |             def forward(self): | 
 |                 return torch.zeros(3, 4, dtype=self.dtype) | 
 |  | 
 |         f = Foo() | 
 |         torch.jit.script(f) | 
 |  | 
 |  | 
 |     def test_named_buffers_are_iterable(self): | 
 |         class MyMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyMod, self).__init__() | 
 |                 self.mod = (torch.nn.ReLU()) | 
 |                 self.mod2 = (torch.nn.ReLU()) | 
 |                 self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU())) | 
 |                 self.register_buffer('x', torch.zeros(3)) | 
 |                 self.register_buffer('y', torch.zeros(3)) | 
 |                 self.z = torch.zeros(3) | 
 |  | 
 |             def bleh(self): | 
 |                 return self.z + 4 | 
 |  | 
 |             @torch.jit.export | 
 |             def method(self): | 
 |                 names = [""] | 
 |                 vals = [] | 
 |                 for name, buffer in self.named_buffers(): | 
 |                     names.append(name) | 
 |                     vals.append(buffer + 2) | 
 |  | 
 |                 return names, vals | 
 |  | 
 |             def forward(self, x): | 
 |                 return x | 
 |  | 
 |         model = MyMod() | 
 |         x = torch.jit.script(model) | 
 |         z = self.getExportImportCopy(x) | 
 |  | 
 |         self.assertEqual(z.method(), x.method()) | 
 |         self.assertEqual(z.method(), model.method()) | 
 |         self.assertEqual(x.method(), model.method()) | 
 |         names = x.method() | 
 |         for name in names: | 
 |             self.assertNotEqual('z', name) | 
 |  | 
 |  | 
    def test_static_if_prop(self):
        # Exercises hasattr() inside scripted code: as an `if` condition that
        # guards access to an attribute which may not exist, and as a plain
        # boolean expression on modules and on a script class.
        class MaybeHasAttr(torch.nn.Module):
            def __init__(self, add_attr):
                super(MaybeHasAttr, self).__init__()
                if add_attr:
                    self.maybe_attr = 1

            def forward(self):
                # hasattr combined with a constant through `and`.
                if hasattr(self, "maybe_attr") and True:
                    return self.maybe_attr
                else:
                    return 0

        class MaybeHasAttr2(torch.nn.Module):
            def __init__(self, add_attr):
                super(MaybeHasAttr2, self).__init__()
                if add_attr:
                    self.maybe_attr = 1

            def forward(self):
                # Negated hasattr combined with a constant through `or`.
                if not hasattr(self, "maybe_attr") or False:
                    return 0
                else:
                    return self.maybe_attr

        # Each variant is scripted both with and without the attribute present;
        # scripting only has to succeed.
        torch.jit.script(MaybeHasAttr(True))
        torch.jit.script(MaybeHasAttr(False))
        torch.jit.script(MaybeHasAttr2(True))
        torch.jit.script(MaybeHasAttr2(False))

        class MyMod(torch.nn.Module):
            def forward(self):
                # "foo" is never defined on this module.
                if hasattr(self, "foo"):
                    return 1
                else:
                    return 0

            @torch.jit.export
            def fee(self):
                return 1

        self.checkModule(MyMod(), ())

        class HasAttrMod(torch.nn.Module):
            __constants__ = ["fee"]

            def __init__(self):
                super().__init__()
                self.fee = 3

            def forward(self):
                # Queries, in order: a constant attribute, a plain method, an
                # overloaded method, and a name that is not defined at all.
                a = hasattr(self, "fee")
                b = hasattr(self, "foo")
                c = hasattr(self, "hi")
                d = hasattr(self, "nonexistant")
                return (a, b, c, d)

            def foo(self):
                return 1

            @torch.jit._overload_method
            def hi(self, x: Tensor): ...  # noqa: E704

            def hi(self, x):  # noqa: F811
                return 2

        self.checkModule(HasAttrMod(), ())

        @torch.jit.script
        class FooTest(object):
            def __init__(self):
                self.x = 1

            def foo(self, y):
                return self.x + y

        def foo():
            # hasattr on an instance (val1) and on the class itself (val2).
            a = FooTest()
            val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla")
            val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a")
            return val1, val2

        # The eager result and the scripted result must agree.
        self.assertEqual(foo(), torch.jit.script(foo)())
 |  | 
 |     def _test_pickle_checkpoint(self, device): | 
 |         with TemporaryFileName() as fname: | 
 |             class M(torch.jit.ScriptModule): | 
 |                 __constants__ = ['fname'] | 
 |  | 
 |                 def __init__(self, tensor): | 
 |                     super(M, self).__init__() | 
 |                     self.fname = fname | 
 |                     self.tensor = torch.nn.Parameter(tensor) | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     y = self.tensor + x | 
 |                     torch.save(y, self.fname) | 
 |                     return y | 
 |  | 
 |             param = torch.randn(2, 2).to(device) | 
 |             input = torch.randn(2, 2).to(device) | 
 |             m = M(param) | 
 |             m(input) | 
 |             with open(fname, "rb") as handle: | 
 |                 loaded_tensor = torch.load(fname) | 
 |                 self.assertEqual(loaded_tensor, input + param) | 
 |  | 
 |     def _test_pickle_checkpoint_views(self, device): | 
 |         with TemporaryFileName() as fname: | 
 |             class M(torch.jit.ScriptModule): | 
 |                 __constants__ = ['fname'] | 
 |  | 
 |                 def __init__(self, tensor): | 
 |                     super(M, self).__init__() | 
 |                     self.fname = fname | 
 |                     self.tensor = torch.nn.Parameter(tensor) | 
 |  | 
 |                 @torch.jit.script_method | 
 |                 def forward(self, x): | 
 |                     y = self.tensor + x | 
 |                     y_view = y.view(4) | 
 |                     torch.save((y, y_view, y), self.fname) | 
 |                     return y | 
 |  | 
 |             param = torch.randn(2, 2).to(device) | 
 |             input = torch.randn(2, 2).to(device) | 
 |             m = M(param) | 
 |             m(input) | 
 |             with open(fname, "rb") as handle: | 
 |                 loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname) | 
 |                 self.assertEqual(loaded_y, input + param) | 
 |                 with torch.no_grad(): | 
 |                     loaded_y_view[1] += 20 | 
 |                     # assert that loaded_y changed as well | 
 |                     self.assertEqual(loaded_y.view(4), loaded_y_view) | 
 |                     self.assertEqual(loaded_y_2.view(4), loaded_y_view) | 
 |  | 
 |     @unittest.skipIf(not RUN_CUDA, "no CUDA") | 
 |     def test_pickle_checkpoint_cuda(self): | 
 |         self._test_pickle_checkpoint('cuda') | 
 |         self._test_pickle_checkpoint_views('cuda') | 
 |  | 
 |     def test_pickle_checkpoint(self): | 
 |         self._test_pickle_checkpoint('cpu') | 
 |         self._test_pickle_checkpoint_views('cpu') | 
 |  | 
 |     def test_pickle_checkpoint_tup(self): | 
 |         @torch.jit.script | 
 |         def foo(fname): | 
 |             # type: (str) -> None | 
 |             torch.save((3, 4), fname) | 
 |         with TemporaryFileName() as name: | 
 |             foo(name) | 
 |             self.assertEqual(torch.load(name), (3, 4)) | 
 |  | 
    def test_string_list(self):
        # list() applied to a str argument, checked via checkScript with the
        # sample string "abcdefgh".
        def fn(string):
            # type: (str) -> List[str]
            return list(string)

        self.checkScript(fn, ("abcdefgh",))
 |  | 
    def test_unicode_comments(self):
        # The scripted function's body contains a comment made of non-ASCII
        # (emoji) characters. The test has no assertion: applying
        # torch.jit.script to it just has to complete without raising.
        @torch.jit.script
        def test(self, a):
            # 🤷🤷🤷🤷
            return torch.nn.functional.relu(a)
 |  | 
 |     def test_get_set_state_with_tensors(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(M, self).__init__() | 
 |                 self.tensor = torch.randn(2, 2) | 
 |  | 
 |             @torch.jit.export | 
 |             def __getstate__(self): | 
 |                 return (self.tensor, self.training) | 
 |  | 
 |             @torch.jit.export | 
 |             def __setstate__(self, state): | 
 |                 self.tensor = state[0] | 
 |                 self.training = state[1] | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + self.tensor | 
 |  | 
 |         with TemporaryFileName() as fname: | 
 |             m = torch.jit.script(M()) | 
 |             m.save(fname) | 
 |             loaded = torch.jit.load(fname) | 
 |             self.assertEqual(loaded.tensor, m.tensor) | 
 |  | 
 |     def test_in_for_and_comp_expr(self): | 
 |         def fn(d): | 
 |             # type: (Dict[str, int]) -> List[int] | 
 |             out = [1] | 
 |             for i in range(d["hi"] if "hi" in d else 6): | 
 |                 out.append(i) | 
 |             return out | 
 |  | 
 |         self.checkScript(fn, ({'hi': 2, 'bye': 3},)) | 
 |         self.checkScript(fn, ({'bye': 3},)) | 
 |  | 
 |     def test_for_else(self): | 
 |         def fn(): | 
 |             c = 0 | 
 |             for i in range(4): | 
 |                 c += 10 | 
 |             else: | 
 |                 print("In else block of for...else") | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"): | 
 |             torch.jit.script(fn) | 
 |  | 
 |     def test_split(self): | 
 |         def split_two(tensor): | 
 |             a, b, c = torch.split(tensor, 2, dim=1) | 
 |             return a, b, c | 
 |         x = torch.randn(3, 6) | 
 |         y = torch.randn(3, 6) | 
 |         self.checkScript(split_two, [(x + y)]) | 
 |  | 
 |     def test_conv_error(self): | 
 |         @torch.jit.script | 
 |         def fn(x, y): | 
 |             return F.conv2d(x, y) | 
 |  | 
 |         try: | 
 |             fn(torch.ones(2, 2), torch.ones(4, 4)) | 
 |         except RuntimeError as e: | 
 |             self.assertFalse('frame' in str(e)) | 
 |  | 
    def test_python_op_name(self):
        # Scripting a function that calls random.randint() is expected to raise
        # a RuntimeError whose message mentions "randint".
        import random

        with self.assertRaisesRegex(RuntimeError, "randint"):
            @torch.jit.script
            def fn():
                return random.randint()
 |  | 
 |     def test_dir(self): | 
 |         class M(torch.jit.ScriptModule): | 
 |             def forward(self, t): | 
 |                 return t | 
 |  | 
 |         self.assertTrue('forward' in dir(M())) | 
 |  | 
 |     def test_kwarg_expansion_error(self): | 
 |         @torch.jit.ignore | 
 |         def something_else(h, i): | 
 |             pass | 
 |  | 
 |         def fn(x): | 
 |             something_else(**x) | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"): | 
 |             torch.jit.script(fn) | 
 |  | 
 |     def test_kwargs_error_msg(self): | 
 |         def other(**kwargs): | 
 |             print(kwargs) | 
 |  | 
 |         def fn(): | 
 |             return other() | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): | 
 |             torch.jit.script(fn) | 
 |  | 
 |         def another_other(*args): | 
 |             print(args) | 
 |  | 
 |         def another_fn(): | 
 |             return another_other() | 
 |  | 
 |         with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): | 
 |             torch.jit.script(another_fn) | 
 |  | 
 |     def test_inferred_error_msg(self): | 
 |         """ | 
 |         Test that when we get a type mismatch on a function where we inferred | 
 |         the type to be tensor, a good error message is given. | 
 |         """ | 
 |         @torch.jit.script | 
 |         def foo(a): | 
 |             return a | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'" | 
 |                                                    r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")): | 
 |             foo("1") | 
 |  | 
    def test_type_comments_in_body(self):
        # Per-argument `# type:` comments, a `(...) -> T` signature comment and
        # further `# type:` comments inside the body must all be accepted when
        # scripting a function and a module. Nothing is asserted beyond that.
        @torch.jit.script
        def foo(a,  # type: int
                b,  # type: int
                ):
            # type: (...) -> int
            # type: int
            return a + b

        class M(torch.nn.Module):
            def __init__(self,
                         a,  # type: int
                         b   # type: int
                         ):
                # type: (...) -> None
                super(M, self).__init__()
                self.a = a  # type: int
                self.b = b  # type: int

        torch.jit.script(M(2, 3))
 |  | 
 |     def test_input_keyword_in_schema(self): | 
 |         def f(x): | 
 |             return torch.ceil(input=x) | 
 |  | 
 |         inp = torch.randn(10) | 
 |         self.checkScript(f, (inp, )) | 
 |  | 
 |     def test_module_method_reassignment(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def _forward(self, x): | 
 |                 return x | 
 |  | 
 |             forward = _forward | 
 |  | 
 |         sm = torch.jit.script(Foo()) | 
 |         input = torch.ones(2, 2) | 
 |         self.assertEqual(input, sm(input)) | 
 |  | 
 |     # Tests the case where a torch.Tensor subclass (like Parameter) is used as | 
 |     # input. | 
 |     def test_script_module_tensor_subclass_argument(self): | 
 |         @torch.jit.script | 
 |         def parameter_script(x: torch.nn.Parameter): | 
 |             return x | 
 |  | 
 |         input = torch.ones(2, 2) | 
 |         self.assertEqual(input, parameter_script(input)) | 
 |  | 
 |     def test_save_load_attr_error(self): | 
 |         class Inner(nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return x | 
 |  | 
 |         class Wrapper(nn.Module): | 
 |             def __init__(self, inner): | 
 |                 super().__init__() | 
 |                 self.inner = inner | 
 |  | 
 |             def forward(self, x): | 
 |                 # this attribute doesn't exist on `Inner` | 
 |                 return self.inner.b(x) | 
 |  | 
 |         inner_module = torch.jit.script(Inner()) | 
 |         inner_module = self.getExportImportCopy(inner_module) | 
 |         wrapped = Wrapper(inner_module) | 
 |         # This should properly complain that `self.inner` doesn't have the attribute `b` | 
 |         with self.assertRaisesRegex(RuntimeError, 'has no attribute'): | 
 |             torch.jit.script(wrapped) | 
 |  | 
 |     def test_rescripting_loaded_modules(self): | 
 |         class InnerSubmod(nn.Module): | 
 |             __constants__ = ['my_constant'] | 
 |  | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.register_buffer("foo", torch.ones(1)) | 
 |                 self.register_parameter("bar", torch.nn.Parameter(torch.ones(1))) | 
 |                 self.baz = torch.ones(1) | 
 |                 self.my_constant = 1 | 
 |  | 
 |             def forward(self, x): | 
 |                 return x + x | 
 |  | 
 |         class Inner(nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.submod = InnerSubmod() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.submod(x) | 
 |  | 
 |         class Wrapper(nn.Module): | 
 |             def __init__(self, inner): | 
 |                 super().__init__() | 
 |                 self.inner = inner | 
 |  | 
 |             def forward(self, x): | 
 |                 # access inner elements | 
 |                 ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz | 
 |                 ret = ret + self.inner.submod.my_constant | 
 |                 return ret | 
 |  | 
 |         inner_module = torch.jit.script(Inner()) | 
 |         wrapped = Wrapper(inner_module) | 
 |         self.checkModule(wrapped, torch.ones(1)) | 
 |  | 
 |         inner_module_loaded = self.getExportImportCopy(inner_module) | 
 |         wrapped_loaded = Wrapper(inner_module_loaded) | 
 |         self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1))) | 
 |  | 
 |     def test_interpret_graph(self): | 
 |         def fn(x): | 
 |             return x.unfold(0, 1, 1) | 
 |  | 
 |         graph_str = """ | 
 |         graph(%a : Tensor, %b : Tensor): | 
 |           %c : Tensor = aten::mul(%a, %b) | 
 |           return (%c) | 
 |         """ | 
 |         graph = parse_ir(graph_str) | 
 |         a = torch.rand(10) | 
 |         b = torch.rand(10) | 
 |         test = torch._C._jit_interpret_graph(graph, (a, b)) | 
 |         ref = a * b | 
 |         self.assertEqual(test, ref) | 
 |  | 
 |     def test_signed_float_zero(self): | 
 |  | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.div(x, -0.) | 
 |  | 
 |         inp = torch.ones(1) | 
 |         self.checkModule(MyModule(), inp) | 
 |  | 
 |     def test_index_with_tuple(self): | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |  | 
 |             def forward(self, x): | 
 |                 return x[(1,)] | 
 |  | 
 |         self.checkModule(MyModule(), (torch.ones(2, 3),)) | 
 |  | 
 |     def test_context_manager(self): | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(MyModule, self).__init__() | 
 |  | 
 |             def forward(self, x, y): | 
 |                 p = x + y | 
 |                 q = p + 2.0 | 
 |                 return q | 
 |  | 
 |         x = torch.randn(3, 2, dtype=torch.float) | 
 |         y = torch.randn(3, 2, dtype=torch.float) | 
 |         for fuser_name in ['fuser0', 'fuser1', 'none']: | 
 |             with torch.jit.fuser(fuser_name): | 
 |                 self.checkModule(MyModule(), (x, y)) | 
 |  | 
 | # known to be failing in tracer | 
 | EXCLUDE_TRACED = { | 
 |     # The following fail due to #12024. | 
 |     # A prim::ListConstruct is involved and the indices get traced as TensorType, | 
 |     # which always require_grad. This causes a crash in autodiff. | 
 |     'test___getitem___adv_index', | 
 |     'test___getitem___adv_index_beg', | 
 |     'test___getitem___adv_index_comb', | 
 |     'test___getitem___adv_index_dup', | 
 |     'test___getitem___adv_index_sub', | 
 |     'test___getitem___adv_index_sub_2', | 
 |     'test___getitem___adv_index_sub_3', | 
 |     'test___getitem___adv_index_var', | 
 |  | 
 |     # jit doesn't support sparse tensors. | 
 |     'test_to_sparse', | 
 |     'test_to_sparse_dim', | 
 | } | 
 |  | 
 | EXCLUDE_TYPE_CHECK = { | 
 |     # slogdet tests use itemgetter to select its only differentiable output, | 
 |     # but this happens outside of the graph we handle, so there are fewer | 
 |     # reference outputs than graph outputs. | 
 |     'test_slogdet_1x1_neg_det', | 
 |     'test_slogdet_1x1_pos_det', | 
 |     'test_slogdet_distinct_singular_values', | 
 |     'test_slogdet_neg_det', | 
 |     'test_slogdet_pos_det', | 
 |     'test_slogdet_symmetric', | 
 |     'test_slogdet_symmetric_pd', | 
 |     'test_slogdet_batched_1x1_neg_det', | 
 |     'test_slogdet_batched_pos_det', | 
 |     'test_slogdet_batched_symmetric', | 
 |     'test_slogdet_batched_symmetric_pd', | 
 |     'test_slogdet_batched_distinct_singular_values' | 
 | } | 
 |  | 
 | # chunk returns a list in scripting and we don't unpack the list, | 
 | # Thus it won't be replaced by ConstantChunk and run AD. | 
 | # It's explicitly checked in test_chunk_constant_script_ad | 
 | # Similary for split, it's replaced by split_with_sizes in tracing, | 
 | # but we don't have AD formula for aten::split(Tensor, int[], int), | 
 | # an op registered in JIT so AD is not triggered in scripting. | 
 | EXCLUDE_SCRIPT_AD_CHECK = { | 
 |     'test_chunk', | 
 |     'test_chunk_dim', | 
 |     'test_chunk_dim_neg0', | 
 |     'test_split_size_list', | 
 |     'test_split_size_list_dim', | 
 |     'test_split_size_list_dim_neg0', | 
 |     'test_tensor_indices_sections', | 
 |     'test_tensor_indices_sections_dim', | 
 |     'test_tensor_indices_sections_dim_neg0', | 
 |     'test_tensor_split_sections', | 
 |     'test_tensor_split_sections_dim', | 
 |     'test_tensor_split_sections_dim_neg0' | 
 | } | 
 |  | 
 | EXCLUDE_PYTHON_PRINT = { | 
 |     # no support for BroadcastingList in python printer | 
 |     'test_nn_max_unpool1d', | 
 |     'test_nn_max_unpool2d', | 
 |     'test_nn_max_unpool3d', | 
 |     'test_nn_max_pool1d', | 
 |     'test_nn_max_pool2d', | 
 |     'test_nn_max_pool3d', | 
 |     'test_nn_max_pool1d_with_indices', | 
 | } | 
 |  | 
 | EXCLUDE_ALIAS = { | 
 |     # aliases, which may appear in method_tests but are tested elsewhere | 
 |     'true_divide', | 
 |  | 
 |     # Disable tests for lu from common_methods_invocations.py | 
 |     # TODO(@nikitaved) Enable jit tests once autograd.Function does support scripting | 
 |     'lu' | 
 | } | 
 |  | 
 |  | 
 | class TestJitGeneratedModule(JitTestCase): | 
 |     pass | 
 |  | 
 |  | 
 | class TestJitGeneratedFunctional(JitTestCase): | 
 |     pass | 
 |  | 
 | # UBSAN per-function exclusions don't seem to work with OpenMP pragmas, | 
 | # and we have to disable the failing tests here instead. | 
 | UBSAN_DISABLED_TESTS = [ | 
 |     "test___rdiv___constant", | 
 |     "test___rdiv___scalar_constant", | 
 |     "test_addcdiv", | 
 |     "test_addcdiv_broadcast_all", | 
 |     "test_addcdiv_broadcast_rhs", | 
 |     "test_addcdiv_scalar", | 
 |     "test_addcdiv_scalar_broadcast_lhs", | 
 |     "test_addcdiv_scalar_broadcast_rhs", | 
 |     "test_addcdiv_scalar_scale", | 
 |     "test_addcdiv_scalar_scale_broadcast_lhs", | 
 |     "test_addcdiv_scalar_scale_broadcast_rhs", | 
 |     "test_addcdiv_scale", | 
 |     "test_addcdiv_scale_broadcast_all", | 
 |     "test_addcdiv_scale_broadcast_rhs", | 
 |     "test_add_broadcast_all", | 
 |     "test_add_broadcast_lhs", | 
 |     "test_add_broadcast_rhs", | 
 |     "test_add_constant", | 
 |     "test_add_scalar", | 
 |     "test_add_scalar_broadcast_lhs", | 
 |     "test_add_scalar_broadcast_rhs", | 
 |     "test_div", | 
 |     "test_div_broadcast_all", | 
 |     "test_div_broadcast_lhs", | 
 |     "test_div_broadcast_rhs", | 
 |     "test_div_scalar", | 
 |     "test_div_scalar_broadcast_lhs", | 
 |     "test_div_scalar_broadcast_rhs", | 
 |     "test_rsqrt", | 
 |     "test_rsqrt_scalar", | 
 |     "test_add", | 
 |     "test_reciprocal", | 
 |     "test_reciprocal_scalar", | 
 | ] | 
 |  | 
 | L = 20 | 
 | M = 10 | 
 | S = 5 | 
 |  | 
 | def add_nn_module_test(*args, **kwargs): | 
 |     no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad'] | 
 |  | 
 |     if 'desc' in kwargs and 'eval' in kwargs['desc']: | 
 |         # eval() is not supported, so skip these tests | 
 |         return | 
 |  | 
 |     test_name = get_nn_mod_test_name(**kwargs) | 
 |  | 
 |     @suppress_warnings | 
 |     def do_test(self): | 
 |         if test_name in EXCLUDE_SCRIPT_MODULES: | 
 |             return | 
 |         if not kwargs.get('check_jit', True): | 
 |             raise unittest.SkipTest('module test skipped on JIT') | 
 |  | 
 |         module_name = get_nn_module_name_from_kwargs(**kwargs) | 
 |  | 
 |         if 'constructor' in kwargs: | 
 |             nn_module = kwargs['constructor'] | 
 |         else: | 
 |             nn_module = getattr(torch.nn, module_name) | 
 |  | 
 |         if "FunctionalModule" in str(nn_module): | 
 |             return | 
 |  | 
 |         if 'constructor_args_fn' in kwargs: | 
 |             constructor_args = kwargs['constructor_args_fn']() | 
 |         else: | 
 |             constructor_args = kwargs.get('constructor_args', ()) | 
 |  | 
 |         def create_script_module(*args, **kwargs): | 
 |             """Construct a script module that passes arguments through to self.submodule""" | 
 |             formals, tensors, actuals = get_script_args(args) | 
 |  | 
 |             method_args = ', '.join(['self'] + actuals) | 
 |             call_args_str = ', '.join(actuals) | 
 |             call = "self.submodule({})".format(call_args_str) | 
 |             script = script_method_template.format(method_args, call) | 
 |  | 
 |             submodule_constants = [] | 
 |             if kwargs.get('is_constant'): | 
 |                 submodule_constants = ['submodule'] | 
 |  | 
 |             # Create module to use the script method | 
 |             class TheModule(torch.jit.ScriptModule): | 
 |                 __constants__ = submodule_constants | 
 |  | 
 |                 def __init__(self): | 
 |                     super(TheModule, self).__init__() | 
 |                     self.submodule = nn_module(*constructor_args) | 
 |  | 
 |             def make_module(script): | 
 |                 module = TheModule() | 
 |                 # check __repr__ | 
 |                 str(module) | 
 |                 module.define(script) | 
 |                 return module | 
 |  | 
 |             module = make_module(script) | 
 |             self.assertExportImportModule(module, tensors) | 
 |             create_script_module.last_graph = module.graph | 
 |             mod = module(*args) | 
 |             return mod | 
 |  | 
 |         # Construct a normal nn module to stay consistent with create_script_module | 
 |         # and make use of a single global rng_state in module initialization | 
 |         def create_nn_module(*args, **kwargs): | 
 |             module = nn_module(*constructor_args) | 
 |             return module(*args) | 
 |  | 
 |         # Set up inputs from tuple of sizes or constructor fn | 
 |         dtype = torch.double | 
 |         if 'input_fn' in kwargs: | 
 |             input = kwargs['input_fn']() | 
 |             if isinstance(input, Tensor): | 
 |                 input = (input,) | 
 |  | 
 |             if all(tensor.is_complex() for tensor in input): | 
 |                 dtype = torch.cdouble | 
 |         else: | 
 |             input = (kwargs['input_size'],) | 
 |  | 
 |         if 'target_size' in kwargs: | 
 |             input = input + (kwargs['target_size'],) | 
 |         elif 'target_fn' in kwargs: | 
 |             if torch.is_tensor(input): | 
 |                 input = (input,) | 
 |             input = input + (kwargs['target_fn'](),) | 
 |         elif 'target' in kwargs: | 
 |             input = input + (kwargs['target'],) | 
 |  | 
 |         # Extra parameters to forward() | 
 |         if 'extra_args' in kwargs: | 
 |             input = input + kwargs['extra_args'] | 
 |  | 
 |         args_variable, kwargs_variable = create_input(input, dtype=dtype) | 
 |         f_args_variable = deepcopy(unpack_variables(args_variable)) | 
 |  | 
 |         # TODO(issue#52052) Neither this nor no_grad should be required | 
 |         # if check_against_reference() is updated to check gradients | 
 |         # w.r.t. weights and then only check w.r.t. inputs if any | 
 |         # inputs require it. | 
 |         any_requires_grad = any(input.requires_grad for input in f_args_variable) | 
 |  | 
 |         # Check against Python module as reference | 
 |         check_against_reference(self, create_script_module, create_nn_module, | 
 |                                 lambda x: x, f_args_variable, | 
 |                                 no_grad=no_grad or not any_requires_grad) | 
 |  | 
 |     if 'slowTest' in kwargs: | 
 |         do_test = slowTest(do_test) | 
 |  | 
 |     post_add_test(test_name, (), do_test, TestJitGeneratedModule) | 
 |  | 
 |  | 
 | def post_add_test(test_name, skipTestIf, do_test, test_class): | 
 |     assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name | 
 |  | 
 |     for skip in skipTestIf: | 
 |         do_test = skip(do_test) | 
 |  | 
 |     if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS): | 
 |         setattr(test_class, test_name, do_test) | 
 |  | 
 |  | 
 | def normalize_check_ad(check_ad, name): | 
 |     # normalized check_ad is 3-element tuple: (bool, List[str], List[str]) | 
 |     if len(check_ad) == 0: | 
 |         check_ad = [False, ['aten::' + name], []] | 
 |     elif len(check_ad) == 1: | 
 |         check_ad = [check_ad[0], ['aten::' + name], []] | 
 |     elif len(check_ad) == 2: | 
 |         check_ad = [check_ad[0], check_ad[1], []] | 
 |     elif len(check_ad) == 3: | 
 |         check_ad = list(check_ad) | 
 |     else: | 
 |         raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])') | 
 |  | 
 |     check_ad = [[t] if isinstance(t, str) else t for t in check_ad] | 
 |  | 
 |     return check_ad | 
 |  | 
 |  | 
 | class TestProducerVersion(TestCase): | 
 |  | 
 |     def test_version(self): | 
 |         # issue gh-32561 | 
 |         self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version)) | 
 |  | 
# Generate a JIT test for every module test description. add_nn_module_test
# hands each generated test to post_add_test for TestJitGeneratedModule.
for test in module_tests + new_module_tests + additional_module_tests:
    add_nn_module_test(**test)

# Criterion test descriptions are always generated with no_grad set. Note that
# the flag is written into each description dict in place before forwarding.
for test in criterion_tests:
    test['no_grad'] = True
    add_nn_module_test(**test)
 |  | 
if __name__ == '__main__':
    run_tests()
    # The module-interface tests are imported only here, collected into a
    # suite, and run with a plain unittest text runner after the main run.
    # NOTE(review): if run_tests() terminates the process (unittest.main()
    # does so by default), nothing below executes -- confirm.
    import jit.test_module_interface
    suite = unittest.findTestCases(jit.test_module_interface)
    unittest.TextTestRunner().run(suite)