# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

import torch

# This is how we include tests located in test/jit/...
# They are included here so that they are invoked when you call `test_jit.py`;
# do not run these test files directly.
from jit.test_tracer import TestTracer, TestMixTracingScripting # noqa: F401
from jit.test_recursive_script import TestRecursiveScript # noqa: F401
from jit.test_type_sharing import TestTypeSharing # noqa: F401
from jit.test_logging import TestLogging # noqa: F401
from jit.test_backends import TestBackends, TestBackendsWithCompiler # noqa: F401
from jit.test_backend_nnapi import TestNnapiBackend # noqa: F401
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList # noqa: F401
from jit.test_async import TestAsync # noqa: F401
from jit.test_await import TestAwait # noqa: F401
from jit.test_data_parallel import TestDataParallel # noqa: F401
from jit.test_models import TestModels # noqa: F401
from jit.test_modules import TestModules # noqa: F401
from jit.test_autodiff import TestAutodiffJit # noqa: F401
from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing # noqa: F401
from jit.test_custom_operators import TestCustomOperators # noqa: F401
from jit.test_graph_rewrite_passes import TestGraphRewritePasses # noqa: F401
from jit.test_class_type import TestClassType # noqa: F401
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401
from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401
from jit.test_op_decompositions import TestOpDecompositions # noqa: F401
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
from jit.test_peephole import TestPeephole # noqa: F401
from jit.test_alias_analysis import TestAliasAnalysis # noqa: F401
from jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer # noqa: F401
from jit.test_save_load_for_op_version import TestSaveLoadForOpVersion # noqa: F401
from jit.test_module_containers import TestModuleContainers # noqa: F401
from jit.test_python_bindings import TestPythonBindings # noqa: F401
from jit.test_python_ir import TestPythonIr # noqa: F401
from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401
from jit.test_remove_mutation import TestRemoveMutation # noqa: F401
from jit.test_torchbind import TestTorchbind # noqa: F401
from jit.test_module_interface import TestModuleInterface # noqa: F401
from jit.test_with import TestWith # noqa: F401
from jit.test_enum import TestEnum # noqa: F401
from jit.test_string_formatting import TestStringFormatting # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_ignorable_args import TestIgnorableArgs # noqa: F401
from jit.test_hooks import TestHooks # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
from jit.test_isinstance import TestIsinstance # noqa: F401
from jit.test_cuda import TestCUDA # noqa: F401
from jit.test_python_builtins import TestPythonBuiltinOP # noqa: F401
from jit.test_typing import TestTyping # noqa: F401
from jit.test_hash import TestHash # noqa: F401
from jit.test_complex import TestComplex # noqa: F401
from jit.test_jit_utils import TestJitUtils # noqa: F401
from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401
from jit.test_types import TestTypesAndAnnotation # noqa: F401
from jit.test_misc import TestMisc # noqa: F401
from jit.test_upgraders import TestUpgraders # noqa: F401
from jit.test_pdt import TestPDT # noqa: F401
from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401
from jit.test_module_apis import TestModuleAPIs # noqa: F401
from jit.test_script_profile import TestScriptProfile # noqa: F401
from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401
from jit.test_parametrization import TestParametrization # noqa: F401
from jit.test_attr import TestGetDefaultAttr # noqa: F401
from jit.test_aten_pow import TestAtenPow # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
from jit.test_union import TestUnion # noqa: F401
from jit.test_batch_mm import TestBatchMM # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401
from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
from jit.test_dataclasses import TestDataclasses # noqa: F401

# Torch
from torch import Tensor
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
from torch.autograd import Variable
from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.testing import FileCheck, make_tensor
import torch.autograd.profiler
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.nn as nn
import torch.nn.functional as F

# Testing utils
from torch.testing._internal import jit_utils
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
    suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
    freeze_rng_state, slowTest, TemporaryFileName, \
    enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
    skipIfCrossRef, IS_MACOS, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
    _trace, do_input_map, get_execution_plan, make_global, \
    execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
    RUN_CUDA
from torch.testing._internal.jit_metaprogramming_utils import (
    get_script_args,
    create_input, unpack_variables,
    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
    get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)

from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests

# For testing truediv in python 2
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
from torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture

# Standard library
from collections import defaultdict, namedtuple, OrderedDict
from copy import deepcopy
from itertools import product
from textwrap import dedent
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import copy
import functools
import inspect
import io
import itertools
import math
import numpy as np
import os
import pickle
import pickletools
import random
import re
import shutil
import string
import sys
import tempfile
import types
import typing
import unittest
import warnings
import zipfile


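# Test helper: run the canonicalize pass on a graph and return its string representation.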
def canonical(graph):
    return torch._C._jit_pass_canonicalize(graph).str(False)

def LSTMCellF(input, hx, cx, *params):
    return LSTMCell(input, (hx, cx), *params)

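# Decide whether autodiff checks should run for a given test name under the current
# graph executor; the exception list below covers tests known to conflict with profiling.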
def doAutodiffCheck(testname):
    # TODO: setting false on test itself is not working
    if "test_t_" in testname or testname == "test_t":
        return False

    if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
        return False

    if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
        return True


    # these tests are disabled because BailOut nodes
    # inserted by ProfilingExecutor interfere with
    # subgraph slicing of Differentiable Graphs
    test_exceptions = [
        # functional
        'test_nn_dropout',
        'test_nn_log_softmax',
        'test_nn_relu',
        'test_nn_softmax',
        'test_nn_threshold',
        'test_nn_lp_pool2d',
        'test_nn_lp_pool1d',
        'test_nn_gumbel_softmax_hard',
        'test_nn_gumbel_softmax',
        'test_nn_multilabel_soft_margin_loss',
        'test_nn_batch_norm',
        'test_nn_max_pool2d_with_indices',
        # AutogradJitGenerated
        'test___rdiv___constant',
        'test___rdiv___scalar_constant',
        'test_split',
        'test_split_dim',
        'test_split_dim_neg0',
        'test_split_size_list',
        'test_split_size_list_dim',
        'test_split_size_list_dim_neg0',
        'test_split_with_sizes',
        'test_split_with_sizes_dim',
        'test_split_with_sizes_dim_neg0',
        'test_split_with_sizes_size_0',
        'test_nn_max_pool2d_with_indices',
    ]

    if testname in test_exceptions:
        return False
    return True


# TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

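# Reference LSTM/MiLSTM cell variants written with plain tensor ops so they can be
# traced and scripted directly.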
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


def LSTMCellC(*args, **kwargs):
    hy, cy = LSTMCellF(*args, **kwargs)
    return torch.cat((hy, cy))


def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())
    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
    # Same as LSTMCell after this point
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()
    return hy, cy


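# Build (input, hx, cx, weights...) tuples for the LSTM/MiLSTM cell helpers above,
# sized to match an nn.LSTMCell(10, 20) on the given device.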
def get_lstm_inputs(device, training=False, seq_length=None):
    input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10)
    input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
    hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    module = nn.LSTMCell(10, 20).to(device, torch.float) # Just to allocate weights with correct sizes
    if training:
        params = tuple(module.parameters())
    else:
        params = tuple(p.requires_grad_(False) for p in module.parameters())
    return (input, hx, cx) + params


def get_milstm_inputs(device, training=False):
    minibatch = 3
    input_size = 10
    hidden_size = 20
    x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
    hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
    cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)

    ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
    hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
    alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias


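# Import a module from an arbitrary file path and return its `fn` attribute.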
def get_fn(file_name, script_path):
    import importlib.util
    spec = importlib.util.spec_from_file_location(file_name, script_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    fn = module.fn
    return fn

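# Fetch the grad executor state recorded for a forward execution plan; raises if the
# graph is not differentiable (unless skip_check is set).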
def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False):
    if diff_graph_idx is None:
        nodes = list(plan_state.graph.nodes())

        if not skip_check:
            nodes = list(filter(lambda n: n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes))
            if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
                pass
            elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If":
                pass
            else:
                raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
    grad_executors = list(plan_state.code.grad_executor_states())
    return grad_executors[diff_graph_idx or 0]


def all_backward_graphs(script_module, diff_graph_idx=None):
    # Note: for Python 2 the order seems to be unstable
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
    bwd_plans = list(grad_executor_state.execution_plans.values())
    return [p.graph.copy() for p in bwd_plans]


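# Return a copy of the backward graph produced for a scripted module's forward execution plan.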
def backward_graph(script_module, diff_graph_idx=None, skip_check=False):
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check)
    bwd_plan = get_execution_plan(grad_executor_state)
    # Running JIT passes requires that we own the graph (with a shared_ptr).
    # The debug state struct does not own its graph so we make a copy of it.
    return bwd_plan.graph.copy()


# helper function to get sum of List[Tensor]
def _sum_of_list(tensorlist):
    s = 0
    for t in tensorlist:
        s += t.sum()
    return s


# has to be at top level or Pickle complains
class FooToPickle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = torch.jit.ScriptModule()


@skipIfTorchDynamo()
class TestJit(JitTestCase):
    @unittest.skip("Requires a lot of RAM")
    def test_big(self):
        m = torch.jit.ScriptModule()
        gig = int(1024 * 1024 * 1024 / 4)
        # a small tensor in the first 4GB
        m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
        # a large tensor in the first 4GB that ends outside of it
        m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
        # a small tensor in >4GB space
        m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
        # a large tensor in the >4GB space
        m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))

        m2 = self.getExportImportCopy(m)

        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))

    def test_inferred_as_tensor(self):
        with self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' "
                                                  "because it was not annotated with an explicit type"):
            @torch.jit.script
            def dot(points, query, dim):
                return (points * query).sum(dim)

    def test_constants_pkl(self):
        # This test asserts that the serialization archive includes a `constants.pkl`
        # file. This file is used by `torch.load` to determine whether a zip file
        # is a normal eager-mode serialization zip or a jit serialization zip. If
        # you are deleting `constants.pkl`, make sure to update `torch.serialization.load`
        # so it is still able to figure out which is which.
        @torch.jit.script
        def fn(x):
            return x

        buf = io.BytesIO()
        torch.jit.save(fn, buf)
        buf.seek(0)

        files = zipfile.ZipFile(buf).filelist
        self.assertTrue(any(['archive/constants.pkl' == f.filename for f in files]))

    def test_script_fn_pkl(self):
        with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"):

            @torch.jit.script
            def fn(x: torch.Tensor) -> torch.Tensor:
                return x

            pkl_fn = pickle.dumps(fn, protocol=0)

    def test_restore_device(self):
        class M(torch.jit.ScriptModule):
            def __init__(self, cpu_device_str):
                super().__init__()
                self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
                                                    device=cpu_device_str))
                self.b0 = torch.tensor([0.9], dtype=torch.float,
                                       device=cpu_device_str)

        # main purpose is checking map_location works
        m = M("cpu")
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertFalse(m2.p0.is_cuda)
        self.assertFalse(m2.b0.is_cuda)

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_device_cuda(self):
        class MyModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.register_buffer('b0', torch.randn(1, 3))
                self.p0 = nn.Parameter(torch.randn(2, 3))

            @torch.jit.script_method
            def forward(self, x):
                return x + self.b0 + self.p0

        m = MyModule()
        m.cuda(torch.cuda.device_count() - 1)
        cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)

        self.assertTrue(m.p0.is_cuda)
        self.assertTrue(m.b0.is_cuda)

        # restore to the saved devices
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertEqual(str(m2.p0.device), cuda_device_str)
        self.assertEqual(str(m2.b0.device), cuda_device_str)

        # restore all to cpu using string
        cpu_device_str = 'cpu'
        m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
        self.assertEqual(str(m3.p0.device), cpu_device_str)
        self.assertEqual(str(m3.b0.device), cpu_device_str)

        # restore all to first gpu using device
        m4 = self.getExportImportCopy(
            m3, map_location=torch.device('cuda:0'))
        self.assertEqual(str(m4.p0.device), 'cuda:0')
        self.assertEqual(str(m4.b0.device), 'cuda:0')

        # compute and compare the results
        input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
        origin_result = m(input)
        self.assertEqual(origin_result, m2(input))
        self.assertEqual(origin_result, m3(input.cpu()))
        self.assertEqual(origin_result, m4(input.cuda(0)))

    def test_trace_retains_train(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x
        m = M()
        m.eval()
        tm = torch.jit.trace(m, (torch.rand(3)))
        self.assertEqual(tm.training, m.training)

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_shared_storage_on_cuda(self):
        class Foo(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
                self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
                self.register_buffer('b0', whole_tensor.narrow(0, 3, 1))

        m = Foo()
        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertTrue(m2.p0.is_cuda)
        self.assertTrue(m2.b0.is_cuda)
        self.assertTrue(m2.p0.is_shared())
        self.assertTrue(m2.b0.is_shared())
        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())

    def test_add_relu_fusion(self):
        class M(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b, c):
                tmp = torch.add(a, b)
                x = self.relu_op(tmp)
                d = torch.add(a, c)
                return x + d
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        c = torch.rand((7, 11))
        m = torch.jit.script(M(torch.relu))
        orig_res = m(a, b, c)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a, b, c)
        FileCheck().check_not("aten::relu(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)

        # add, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        c = torch.rand((7, 11))
        m = torch.jit.script(M(torch.relu_))
        orig_res = m(a, b, c)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a, b, c)
        FileCheck().check_not("aten::relu_(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)

        class Madd_(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b):
                x = a.add_(b)
                x = self.relu_op(x)
                return x

        # add_, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        # Because in place add_ will overwrite a
        a_copy = a.clone()
        m = torch.jit.script(Madd_(torch.relu_))
        orig_res = m(a, b)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a_copy, b)
        FileCheck().check_not("aten::add_(") \
            .check_not("aten::relu_(") \
            .check("aten::_add_relu_(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)
        # Since _add_relu_ does inplace mutation ensure
        # a_copy is modified
        torch.testing.assert_close(orig_res, a_copy)

        class Madd_out(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b):
                x = torch.add(a, b, out=a)
                x = self.relu_op(x)
                return x
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))

        # add_out, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        # Because in place add_ will overwrite a
        a_copy = a.clone()
        m = torch.jit.script(Madd_out(torch.relu_))
        orig_res = m(a, b)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a_copy, b)
        FileCheck().check_not("aten::add(") \
            .check_not("aten::relu_(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)
        # Since _add_relu_ with out=a does inplace mutation ensure
        # a_copy is modified
        torch.testing.assert_close(orig_res, a_copy)

    def test_repeat_interleave_script(self):
        def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor:
            output = input.repeat_interleave(repeats)
            return output
        fn_scripted = torch.jit.script(fn)

        input = torch.tensor([5, 7], dtype=torch.int64)
        repeats = torch.tensor([3, 6], dtype=torch.int64)

        output = fn(input, repeats)
        output_scripted = fn_scripted(input, repeats)
        self.assertEqual(output_scripted, output)

    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information")
    def test_peephole_optimize_shape_ops(self):
        def test_input(func, input, result):
            # if result == 2 we will trigger a bailout and
            # the unprofiled graph should return the correct result
            self.assertEqual(func(input, profile_and_replay=True), result)
            gre = func.graph_for(input)
            FileCheck().check_not("prim::If").run(gre)

        def test_dim():
            @torch.jit.script
            def func(x):
                if x.dim() == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor([0.5]), 1)
            test_input(func, torch.tensor([[0.5]]), 2)
        test_dim()

        def test_size_index():
            @torch.jit.script
            def func(x):
                if x.size(0) == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.rand([1, 2]), 1)
            test_input(func, torch.rand([1, 3]), 1)

            @torch.jit.script
            def neg_index(x):
                if x.size(-2) == 1:
                    return 1
                else:
                    return 2

            test_input(neg_index, torch.rand([1, 2]), 1)
            test_input(neg_index, torch.rand([1, 3]), 1)

        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            test_size_index()

        def test_dtype():
            @torch.jit.script
            def func(x):
                if x.dtype == torch.float32:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_dtype()

        def test_is_floating_point():
            @torch.jit.script
            def func(x):
                if x.is_floating_point():
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_is_floating_point()

        def test_device():
            @torch.jit.script
            def func_1(x):
                if x.device == torch.device('cuda:0'):
                    a = 0
                else:
                    a = 1
                return a

            @torch.jit.script
            def func_2(x):
                if x.is_cuda:
                    a = 0
                else:
                    a = 1
                return a

            test_input(func_1, torch.tensor(0.5), 1)
            test_input(func_2, torch.tensor(0.5), 1)

            if RUN_CUDA:
                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)

        test_device()

    def test_attrs(self):
        def foo(x):
            return (
                # x.dtype, TODO: dtype long -> instance conversion
                x.device,
                x.shape,
                x.is_cuda,
                x.is_mkldnn,
                x.is_quantized,
                x.requires_grad,
                x.T,
                x.mT,
                x.H,
                x.mH
                # x.layout TODO: layout long -> instance conversion
            )

        scripted = torch.jit.script(foo)
        x = torch.rand(3, 4)
        self.assertEqual(scripted(x), foo(x))

    def test_layout(self):
        @torch.jit.script
        def check(x, y):
            return x.layout == y.layout

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        self.assertTrue(check(x, y))

    def test_matrix_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.mT, x.transpose(-2, -1))

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

    def test_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.T, x.t())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

    def test_matrix_conj_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.mH, x.transpose(-2, -1).conj())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
        self.assertTrue(check(x))

    def test_conj_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.H, x.t().conj())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
        self.assertTrue(check(x))

    def test_T_mT_H_mH(self):
        def T(x):
            return x.T

        def mT(x):
            return x.mT

        def H(x):
            return x.H

        def mH(x):
            return x.mH

        x = torch.rand(3, 4)
        y = make_tensor((3, 4), device="cpu", dtype=torch.complex64)

        self.checkScript(T, (x, ))
        self.checkScript(mT, (x, ))
        self.checkScript(H, (x, ))
        self.checkScript(mH, (x, ))
        self.checkScript(T, (y, ))
        self.checkScript(mT, (y, ))
        self.checkScript(H, (y, ))
        self.checkScript(mH, (y, ))

    def test_nn_conv(self):
        class Mod(nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, input):
                return self.conv(input)

        inputs = [
            # Conv
            (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
            (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
            (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
            # ConvTransposed
            (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
            (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
            (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy')
    def test_debug_flush_compilation_cache(self):
        def foo(x):
            return x + 2

        class Mod(nn.Module):
            def forward(self, t):
                return t + 2

        m = torch.jit.script(Mod())
        x = torch.rand(1, 10)

        with enable_profiling_mode_for_profiling_tests():
            jitted = self.checkScript(foo, (x,))
            # shouldn't throw
            states = jitted.get_debug_state()

            # after flushing there should be
            # no opt plan
            jitted._debug_flush_compilation_cache()
            with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
                states = jitted.get_debug_state()

            NUM_RUNS = 1
            with num_profiled_runs(NUM_RUNS):
                m(x)
                m(x)
                fwd = m._c._get_method("forward")
                states = m.get_debug_state()

                # after flushing there should be
                # no opt plan
                fwd._debug_flush_compilation_cache()
                with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
                    states = m.get_debug_state()

    def test_numel(self):
        @torch.jit.script
        def get_numel_script(x):
            return x.numel()

        x = torch.rand(3, 4)
        numel = get_numel_script(x)
        self.assertEqual(numel, x.numel())

    def test_element_size(self):
        @torch.jit.script
        def get_element_size_script(x):
            return x.element_size()

        x = torch.rand(3, 4)
        element_size = get_element_size_script(x)
        self.assertEqual(element_size, x.element_size())

    def test_Sequential(self):
        class Seq(nn.Module):
            def __init__(self):
                super().__init__()
                self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))

            @torch.jit.script_method
            def forward(self, x):
                for l in self.seq:
                    x = l(x)
                return x

        m = torch.jit.script(Seq())
        assert m.graph # ensure jit was able to compile

    def test_ModuleList(self):
        class Mod(nn.Module):
            def __init__(self):
                super().__init__()
                self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])
                self.model += (nn.Linear(10, 20),)
                self.model.append(nn.Linear(20, 30))
                self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)])

            def forward(self, v):
                for m in self.model:
                    v = m(v)
                return v

        m = torch.jit.script(Mod())
        assert m.graph # ensure jit was able to compile

    def test_disabled(self):
        torch.jit._state.disable()
        try:
            def f(x, y):
                return x + y

            self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
            self.assertIs(torch.jit.script(f), f)

            class MyModule(torch.jit.ScriptModule):
                @torch.jit.script_method
                def method(self, x):
                    return x

            # XXX: Unfortunately ScriptModule won't simply become Module now,
            # because that requires disabling the JIT at startup time, which
            # we can't do in here.
            # We need to or those two conditions to make it work with all versions of Python
            self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
        finally:
            torch.jit._state.enable()

    def test_train_eval(self):
        class Sub(nn.Module):
            def forward(self, input):
                if self.training:
                    return input
                else:
                    return -input

        class MyModule(torch.jit.ScriptModule):
            def __init__(self, module):
                super().__init__()
                self.module = module

            @torch.jit.script_method
            def forward(self, input):
                return self.module(input) + 1

        m = MyModule(Sub())
        input = torch.rand(3, 4)
        self.assertEqual(input + 1, m(input))
        m.eval()
        self.assertEqual(-input + 1, m(input))

        # test batchnorm and dropout train/eval
        input = torch.randn(6, 10)
        batchnorm = nn.BatchNorm1d(10)
        dropout = nn.Dropout(p=0.2)

        m_batchnorm = MyModule(batchnorm)
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
        batchnorm.eval()
        m_batchnorm.eval()
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))

        m_dropout = MyModule(dropout)
        dropout.eval()
        m_dropout.eval()
        self.assertEqual(dropout(input) + 1, m_dropout(input))

    def test_nn_lp_pool2d(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = torch.nn.LPPool2d(2, 3)
                self.n = torch.nn.LPPool2d(2, (7, 1))

            def forward(self, x):
                return (self.l(x),
                        self.n(x),
                        torch.nn.functional.lp_pool2d(x, float(2), 3),
                        torch.nn.functional.lp_pool2d(x, 2, 3),
                        torch.nn.functional.lp_pool2d(x, float(2), (7, 1)))

        self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),))

    def test_nn_lp_pool1d(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l = torch.nn.LPPool1d(2, 3)
                self.n = torch.nn.LPPool1d(2, 7)

            def forward(self, x):
                return (self.l(x),
                        self.n(x),
                        torch.nn.functional.lp_pool1d(x, float(2), 3),
                        torch.nn.functional.lp_pool1d(x, 2, 3),
                        torch.nn.functional.lp_pool1d(x, float(2), 7))

        self.checkModule(Mod(), (torch.rand(1, 3, 7),))

    def test_nn_padding_functional(self):
        class Mod(nn.Module):
            def __init__(self, *pad):
                super().__init__()
                self.pad = pad

            def forward(self, x):
                return F.pad(x, self.pad, mode='constant', value=3.5)

        inputs = [
            (Mod(1, 2), torch.randn(1, 3, 4)), # 1D
            (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)), # 2D
            (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)), # 3D
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    def test_nn_padding(self):
        class Mod(nn.Module):
            def __init__(self, padding):
                super().__init__()
                self.padding = padding

            def forward(self, input):
                return self.padding(input)

        inputs = [
            (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)),
            (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)),
            (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)),
            (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
            (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
            (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
            (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
            (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
            (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
            (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3))
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    def test_script_autograd_grad(self):
        def test_simple_grad(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = x + 2 * y + x * y
            return torch.autograd.grad((z.sum(), ), (x, y))

        def test_simple_grad_with_grad_outputs(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = x + 2 * y + x * y
            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
            return torch.autograd.grad((z, ), (x, y), grad_outputs)

        def test_one_output_not_requires_grad(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = 2 * y + y
            return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True)

        def test_retain_graph(x, y):
            # type: (Tensor, Tensor) -> None
            z = x + 2 * y + x * y
            torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True)
            torch.autograd.grad((z.sum(), ), (x, y))

        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2, requires_grad=True)
        self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True)
        self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True)
        self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True)
        self.checkScript(test_retain_graph, (x, y), inputs_requires_grad=True)

    def test_script_backward(self):
        def checkBackwardScript(fn, inputs):
            scripted_fn = torch.jit.script(fn)
            FileCheck().check("torch.autograd.backward").run(scripted_fn.code)
            recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)

            fn(*inputs)
            scripted_fn(*recording_inputs)

            for inp1, inp2 in zip(inputs, recording_inputs):
                self.assertEqual(inp1.grad, inp2.grad)

        def test_tensor_backward(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            sum_out = output.sum()
            sum_out.backward()

        def test_torch_autograd_backward(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            torch.autograd.backward(output.sum())

        def test_torch_autograd_backward_with_grad_tensors(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
            torch.autograd.backward((output,), grad_outputs)

        inp = torch.randn(2, 2, requires_grad=True)
        checkBackwardScript(test_tensor_backward, (inp,))
        checkBackwardScript(test_torch_autograd_backward, (inp,))
        checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,))

    def test_script_backward_twice(self):
        def checkBackwardTwiceScript(fn, inputs, retain_graph_=False):
            torch._C._jit_set_profiling_executor(False)

            with torch.jit.optimized_execution(True):
                scripted_fn = torch.jit.script(fn, inputs)
                FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs))

                result = scripted_fn(*inputs)
                result.sum().backward(retain_graph=retain_graph_)
                if not retain_graph_:
                    self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
                                           lambda: result.sum().backward())
                else:
                    result.sum().backward()

        def test_script_backward_twice_with_saved_values(input1, input2):
            # type: (Tensor, Tensor) -> Tensor
            tmp1 = torch.mul(input1, input2)
            tmp2 = torch.abs(tmp1)
            if torch.equal(input1, input2):
                tmp2 = torch.acos(tmp2)
            else:
                tmp2 = torch.atan(tmp2)
            result = torch.add(tmp2, input2)
            return result

        inp1 = torch.randn(2, 2, requires_grad=True)
        inp2 = torch.randn(2, 2, requires_grad=True)
        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False)
        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True)

    def test_diff_subgraph_clones_constants(self):
        @torch.jit.script
        def f(x, y):
            return x + x + y + x + y + x + y + x + y + x

        def count_constants(graph):
            return sum(node.kind() == 'prim::Constant' for node in graph.nodes())

        graph = f.graph.copy()
        self.run_pass('cse', graph)
        self.run_pass('create_autodiff_subgraphs', graph)
        nodes = list(graph.nodes())
        self.assertEqual(count_constants(graph), 1)
        self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)

    # TODO: adapt this test to check that GraphExecutor treats them differently
    @unittest.skip("Need to be adjusted to Graph Executor")
    def test_arg_configurations(self):
        """Different arg configurations should trigger different traces"""
        x = Variable(torch.FloatTensor(4, 4).uniform_())
        x_double = Variable(x.data.double())
        x_grad = Variable(x.data.clone(), requires_grad=True)
        y = Variable(torch.randn(4))

        configurations = [
            (x,),
            (x_double,),
            (x_grad,),
            (y,),
            ([x, x],),
            ([x, y],),
        ]
        if torch.cuda.is_available():
            x_cuda = Variable(x.data.cuda())
            configurations += [
                (x_cuda,),
                ([x, x_cuda],),
                ([x_cuda, x],),
                ([[x_cuda, x]],),
            ]
            if torch.cuda.device_count() > 1:
                x_cuda_1 = Variable(x.data.cuda(1))
                configurations += [
                    (x_cuda_1,),
                    ([x_cuda, x_cuda_1],),
                ]

        @torch.jit.compile(nderivs=0)
        def fn(*args):
            in_vars, _ = torch._C._jit_flatten(args)
            return in_vars[0] + 1

        for i, config in enumerate(configurations):
            self.assertFalse(fn.has_trace_for(*config))
            fn(*config)
            self.assertTrue(fn.has_trace_for(*config))
            for unk_config in configurations[i + 1:]:
                self.assertFalse(fn.has_trace_for(*unk_config))
        self.assertEqual(fn.hits, 0)

    def test_torch_sum(self):
        def fn(x):
            return torch.sum(x)

        def fn1(x, dim: int):
            return torch.sum(x, dim)

        x = torch.randn(3, 4)
        self.checkScript(fn, (x, ))
        self.checkScript(fn1, (x, 1, ))
        self.checkScript(fn1, (x, 0, ))

Lu Fang0a1ac8b2017-09-14 22:07:59 -07001226 def test_cse(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001227 x = torch.tensor([0.4, 0.3], requires_grad=True)
1228 y = torch.tensor([0.7, 0.5], requires_grad=True)
Adam Paszkee6cbe842017-12-21 10:54:49 -05001229
1230 def fn(x, y):
1231 w = (x + y) * (x + y) * (x + y)
1232 t = torch.tanh(w) + torch.tanh(w)
1233 z = (x + y) * (x + y) * (x + y) + t
1234 return z
Lu Fang0a1ac8b2017-09-14 22:07:59 -07001235
James Reed6e38c3b2019-11-05 17:02:40 -08001236 g, _ = torch.jit._get_trace_graph(fn, (x, y))
Shen Li10224432021-08-12 11:39:31 -07001237 self.run_pass('cse', g)
eellisond8d83712019-02-22 17:54:09 -08001238 do_exactly = True
Shen Li10224432021-08-12 11:39:31 -07001239 FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
1240 .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \
1241 .run(str(g))
eellisond8d83712019-02-22 17:54:09 -08001242
James Reedf7825002019-10-26 18:39:23 -07001243 self.assertExportImport(g, (x, y))
Lu Fang0a1ac8b2017-09-14 22:07:59 -07001244
eellisond9027742019-04-23 20:31:36 -07001245 def test_cse_not_introduce_aliasing(self):
1246 @torch.jit.script
1247 def tensor_alias_outputs(x):
1248 return x + x, x + x
1249
Shen Li10224432021-08-12 11:39:31 -07001250 self.run_pass('cse', tensor_alias_outputs.graph)
eellisond9027742019-04-23 20:31:36 -07001251 FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)
1252
1253 @torch.jit.script
1254 def ints_alias_outputs(x):
1255 # type: (int) -> Tuple[int, int]
1256 return x + x, x + x
1257
1258 # non-aliasing types can be CSEd
Shen Li10224432021-08-12 11:39:31 -07001259 self.run_pass('cse', ints_alias_outputs.graph)
1260 FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)
eellisond9027742019-04-23 20:31:36 -07001261
Owen Andersonabf85bf2018-08-16 00:10:34 -07001262 def test_recursive_cse(self):
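        # CSE should also look through nested blocks: the aten::add computed in
        # the outer graph can be reused for the identical add inside block0 of
        # the If node, which the CHECK-NOT below verifies.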
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001263 input_str = """
1264graph(%x : Tensor,
eellisond9027742019-04-23 20:31:36 -07001265 %y : Tensor,
1266 %20 : int):
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001267 %2 : int = prim::Constant[value=1]()
1268 %3 : Tensor = aten::add(%x, %y, %2)
eellisond9027742019-04-23 20:31:36 -07001269 %4 : int = aten::add(%2, %20)
Wanchao Liang799633e2019-07-03 22:14:14 -07001270 %5 : bool = aten::Bool(%4)
eellisond9027742019-04-23 20:31:36 -07001271 %z : int = prim::If(%5)
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001272 # CHECK: block
1273 block0():
1274 # CHECK-NOT: aten::add
eellisond9027742019-04-23 20:31:36 -07001275 %z.1 : int = aten::add(%2, %20)
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001276 -> (%z.1)
1277 block1():
eellisond9027742019-04-23 20:31:36 -07001278 -> (%2)
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001279 return (%z)
1280"""
1281 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001282 self.run_pass('cse', graph)
Mikhail Zolotukhin943f7122019-04-08 12:22:52 -07001283 FileCheck().run(input_str, graph)
Edward Z. Yang711e5a62018-06-15 17:52:21 -04001284
Mikhail Zolotukhin8a6072c2019-05-08 11:18:52 -07001285 def test_pattern_based_rewrite(self):
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001286 # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
1287 # --> mulmul(mulmul(x,y,z), x, y)
1288 input_str = """
1289graph(%x, %y, %z):
1290 # CHECK-NOT: aten::mul
1291 # CHECK: my::fused_mulmul
1292 %t = aten::mul(%x, %y)
1293 %p = aten::mul(%t, %z)
1294 # CHECK: my::fused_mulmul
1295 %u = aten::mul(%p, %x)
1296 %o = aten::mul(%u, %y)
1297 return (%o)"""
1298 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001299 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001300graph(%a, %b, %c):
1301 %q = aten::mul(%a, %b)
1302 %r = aten::mul(%q, %c)
Shen Li10224432021-08-12 11:39:31 -07001303 return (%r)""", """
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001304graph(%a, %b, %c):
1305 %r = my::fused_mulmul(%a, %b, %c)
Shen Li10224432021-08-12 11:39:31 -07001306 return (%r)""", graph)
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001307 FileCheck().run(input_str, graph)
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001308
1309 # Check that overlapping matches are handled correctly
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001310 # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x)
1311 input_str = """
1312graph(%x, %y, %z):
1313 # CHECK-NOT: aten::mul
1314 # CHECK: my::fused_mulmul
1315 %t = aten::mul(%x, %y)
1316 %p = aten::mul(%t, %z)
1317 # CHECK-NEXT: aten::mul
1318 %u = aten::mul(%p, %x)
1319 return (%u)"""
1320 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001321 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001322graph(%a, %b, %c):
1323 %q = aten::mul(%a, %b)
1324 %r = aten::mul(%q, %c)
Shen Li10224432021-08-12 11:39:31 -07001325 return (%r)""", """
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001326graph(%a, %b, %c):
1327 %r = my::fused_mulmul(%a, %b, %c)
Shen Li10224432021-08-12 11:39:31 -07001328 return (%r)""", graph)
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001329 FileCheck().run(input_str, graph)
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001330
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001331 # Check add(mul(x,y),z) --> muladd(x,y,z) replacement
1332 input_str = """
1333graph(%x, %y, %z):
1334 # CHECK-NOT: aten::mul
1335 # CHECK-NOT: aten::add
1336 %c = prim::Const[value=1]()
1337 %t = aten::mul(%x, %y)
1338 %p = aten::add(%t, %z, %c)
1339 # CHECK: my::muladd
1340 # CHECK-NEXT: return
1341 return (%p)"""
1342 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001343 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001344graph(%a, %b, %c, %d):
1345 %q = aten::mul(%a, %b)
1346 %r = aten::add(%q, %c, %d)
Shen Li10224432021-08-12 11:39:31 -07001347 return (%r)""", """
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001348graph(%a, %b, %c, %d):
1349 %r = my::muladd(%a, %b, %c, %d)
Shen Li10224432021-08-12 11:39:31 -07001350 return (%r)""", graph)
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001351 FileCheck().run(input_str, graph)
1352
1353 # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement
1354 input_str = """
1355graph(%x, %y, %z):
1356 # CHECK-NOT: aten::mul
1357 %c = prim::Const[value=1]()
1358 # CHECK: aten::add
1359 %t = aten::mul(%x, %y)
1360 # CHECK-NEXT: aten::sub
1361 %p = aten::add(%t, %z, %c)
1362 # CHECK-NOT: aten::add
1363 # CHECK-NEXT: return
1364 return (%p)"""
1365 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001366 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001367graph(%a, %b, %c, %d):
1368 %q = aten::mul(%a, %b)
1369 %r = aten::add(%q, %c, %d)
Shen Li10224432021-08-12 11:39:31 -07001370 return (%r)""", """
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001371graph(%a, %b, %c, %d):
1372 %q = aten::add(%a, %b, %d)
1373 %r = aten::sub(%q, %c, %d)
Shen Li10224432021-08-12 11:39:31 -07001374 return (%r)""", graph)
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001375 FileCheck().run(input_str, graph)
1376
1377 # Check mul(x,y) --> x replacement
1378 input_str = """
1379graph(%x, %y, %z):
1380 %c = prim::Const[value=1]()
1381 # CHECK-NOT: aten::mul
1382 %t = aten::mul(%x, %y)
1383 # CHECK: aten::add(%x, %z
1384 %p = aten::add(%t, %z, %c)
1385 # CHECK-NEXT: return
1386 return (%p)"""
1387 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001388 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001389graph(%Pa, %Pb):
1390 %Pq = aten::mul(%Pa, %Pb)
Shen Li10224432021-08-12 11:39:31 -07001391 return (%Pq)""", """
Mikhail Zolotukhinc931d7e2019-05-08 11:45:05 -07001392graph(%Ra, %Rb):
Shen Li10224432021-08-12 11:39:31 -07001393 return (%Ra)""", graph)
Mikhail Zolotukhinb3324d02019-05-08 11:45:05 -07001394 FileCheck().run(input_str, graph)
Mikhail Zolotukhin2a95cf62019-04-29 19:14:02 -07001395
Mikhail Zolotukhin85bca162019-08-23 21:13:34 -07001396 @_tmp_donotuse_dont_inline_everything
1397 def test_pattern_based_module_rewrite(self):
1398 # Check match::module behavior
1399 class Test(torch.nn.Module):
1400 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00001401 super().__init__()
Mikhail Zolotukhin85bca162019-08-23 21:13:34 -07001402 self.conv = torch.nn.Conv2d(1, 20, 5, 1)
1403 self.bn = torch.nn.BatchNorm2d(num_features=20)
1404
1405 def forward(self, x):
1406 x = self.conv(x)
1407 x = self.bn(x)
1408 return x
1409 m = torch.jit.script(Test())
Shen Li10224432021-08-12 11:39:31 -07001410 torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
Mikhail Zolotukhin85bca162019-08-23 21:13:34 -07001411 graph(%self, %x):
1412 %conv = match::module[name="Conv2d"](%self)
1413 %y = prim::CallMethod[name="forward"](%conv, %x)
1414 %bn = match::module[name="BatchNorm2d"](%self)
1415 %z = prim::CallMethod[name="forward"](%bn, %y)
Shen Li10224432021-08-12 11:39:31 -07001416 return (%z)""", """
Mikhail Zolotukhin85bca162019-08-23 21:13:34 -07001417 graph(%self, %x):
1418 %z = my::matched_conv_bn(%self, %x)
Shen Li10224432021-08-12 11:39:31 -07001419 return (%z)""", m._c._get_method("forward").graph)
Mikhail Zolotukhin85bca162019-08-23 21:13:34 -07001420
1421 FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph)
1422
Kimish Patele0676752021-05-25 09:17:27 -07001423 def test_pattern_based_rewrite_with_source_range_preserved(self):
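        # When a pattern-based rewrite replaces nodes, the replacement nodes
        # should inherit the source ranges of the original nodes according to
        # the supplied value_name_pairs mapping; without a mapping (last case
        # below) the source range is not expected to be preserved.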
1424 class TestModule1(torch.nn.Module):
Kimish Patele0676752021-05-25 09:17:27 -07001425 def forward(self, x, y, z, w):
1426 x = x + y
1427 x = x * z
1428 return w - x
1429
1430 input_pattern = """
1431 graph(%x, %y, %z, %const):
1432 %t = aten::add(%x, %y, %const)
1433 %o = aten::mul(%t, %z)
1434 return (%o)"""
1435 replacement_pattern = """
1436 graph(%x, %y, %z, %const):
1437 %o = my::add_mul(%x, %y, %z, %const)
1438 return (%o)"""
1439 scripted_model = torch.jit.script(TestModule1())
1440 graph = scripted_model.graph
1441 value_mappings = [("o", "t")]
1442 for node in graph.nodes():
1443 if node.kind() == "aten::add":
1444 source_range_1 = node.sourceRange()
1445 torch._C._jit_pass_custom_pattern_based_rewrite_graph(
Shen Li10224432021-08-12 11:39:31 -07001446 input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings)
Kimish Patele0676752021-05-25 09:17:27 -07001447 graph = scripted_model.graph
1448 for node in graph.nodes():
1449 if node.kind() == "my::add_mul":
1450 source_range_2 = node.sourceRange()
1451 self.assertTrue(source_range_1 == source_range_2)
1452
1453 class TestModule2(torch.nn.Module):
Kimish Patele0676752021-05-25 09:17:27 -07001454 def forward(self, x, y, z, w):
1455 x = x + y
1456 x = x + z
1457 x = x * z
1458 x = x * w
1459 return x - 2
1460
1461        # Check source range preservation for the add -> my::add transform applied to two nodes
1462 input_pattern = """
1463 graph(%x, %y, %const):
1464 %o = aten::add(%x, %y, %const)
1465 return (%o)"""
1466 replacement_pattern = """
1467 graph(%x, %y, %const):
1468 %o = my::add(%x, %y, %const)
1469 return (%o)"""
1470 scripted_model = copy.deepcopy(torch.jit.script(TestModule2()))
1471 graph_copy = scripted_model.graph.copy()
1472 value_mappings = [("o", "o")]
1473 source_range_add_1 = None
1474 for node in graph_copy.nodes():
1475 if source_range_add_1 is None and node.kind() == "aten::add":
1476 source_range_add_1 = node.sourceRange()
1477 if source_range_add_1 is not None and node.kind() == "aten::add":
1478 source_range_add_2 = node.sourceRange()
1479 torch._C._jit_pass_custom_pattern_based_rewrite_graph(
Shen Li10224432021-08-12 11:39:31 -07001480 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
Kimish Patele0676752021-05-25 09:17:27 -07001481 source_range_my_add_1 = None
1482 for node in graph_copy.nodes():
1483 if source_range_my_add_1 is None and node.kind() == "my::add":
1484 source_range_my_add_1 = node.sourceRange()
1485 if source_range_my_add_1 is not None and node.kind() == "my::add":
1486 source_range_my_add_2 = node.sourceRange()
1487 self.assertTrue(source_range_add_1 == source_range_my_add_1)
1488 self.assertTrue(source_range_add_2 == source_range_my_add_2)
1489
1490 # Check source range preservation for add-add -> double_add transform
1491 # fuse nodes
1492 input_pattern = """
1493 graph(%x, %y, %z, %const):
1494 %t = aten::add(%x, %y, %const)
1495 %o = aten::add(%t, %z, %const)
1496 return (%o)"""
1497 replacement_pattern = """
1498 graph(%x, %y, %z, %const):
1499 %o = my::double_add(%x, %y, %z, %const)
1500 return (%o)"""
1501 scripted_model = torch.jit.script(TestModule2())
1502 graph_copy = scripted_model.graph.copy()
1503 value_mappings = [("o", "t")]
1504 source_range_1 = None
1505 source_range_2 = None
1506 for node in graph_copy.nodes():
1507 if node.kind() == "aten::add":
1508 source_range_1 = node.sourceRange()
1509 break
1510 torch._C._jit_pass_custom_pattern_based_rewrite_graph(
Shen Li10224432021-08-12 11:39:31 -07001511 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
Kimish Patele0676752021-05-25 09:17:27 -07001512 for node in graph_copy.nodes():
1513 if node.kind() == "my::double_add":
1514 source_range_2 = node.sourceRange()
1515 self.assertTrue(source_range_1 == source_range_2)
1516
1517 # Check source range preservation for mul -> add + add transform
1518 # split node
1519 input_pattern = """
1520 graph(%x, %y):
1521 %t = aten::mul(%x, %y)
1522 return (%t)"""
1523 replacement_pattern = """
1524 graph(%x, %y):
1525 %t = my::add(%x, %y)
1526 %o = my::add(%t, %y)
1527 return (%o)"""
1528 scripted_model = torch.jit.script(TestModule2())
1529 graph_copy = scripted_model.graph.copy()
1530 value_mappings = [("t", "t"), ("o", "t")]
1531 source_range_mul_1 = None
1532 for node in graph_copy.nodes():
1533 if source_range_mul_1 is None and node.kind() == "aten::mul":
1534 source_range_mul_1 = node.sourceRange()
1535 if source_range_mul_1 is not None and node.kind() == "aten::mul":
1536 source_range_mul_2 = node.sourceRange()
1537 torch._C._jit_pass_custom_pattern_based_rewrite_graph(
Shen Li10224432021-08-12 11:39:31 -07001538 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
Kimish Patele0676752021-05-25 09:17:27 -07001539 source_range_add_1 = None
1540 for node in graph_copy.nodes():
1541 if source_range_add_1 is None and node.kind() == "my::add":
1542 source_range_add_1 = node.sourceRange()
1543 if source_range_add_1 is not None and node.kind() == "my::add":
1544 source_range_add_2 = node.sourceRange()
1545 self.assertTrue(source_range_mul_1 == source_range_add_1)
1546 self.assertTrue(source_range_mul_2 == source_range_add_2)
1547
1548        # Check lack of source range preservation for the mul-mul -> double_mul transform
1549 input_pattern = """
1550 graph(%x, %y, %z):
1551 %t = aten::mul(%x, %y)
1552 %o = aten::mul(%t, %z)
1553 return (%o)"""
1554 replacement_pattern = """
1555 graph(%x, %y, %z):
1556 %o = my::double_mul(%x, %y, %z)
1557 return (%o)"""
1558 scripted_model = torch.jit.script(TestModule2())
1559 graph_copy = scripted_model.graph.copy()
1560 for node in graph_copy.nodes():
1561 if node.kind() == "aten::mul":
1562 source_range_1 = node.sourceRange()
Shen Li10224432021-08-12 11:39:31 -07001563 torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy)
Kimish Patele0676752021-05-25 09:17:27 -07001564 for node in graph_copy.nodes():
1565 if node.kind() == "my::double_mul":
1566 source_range_2 = node.sourceRange()
1567 self.assertFalse(source_range_1 == source_range_2)
1568
Mikhail Zolotukhin13b95ea2019-03-25 17:39:01 -07001569 def test_expand_quantlint(self):
1570 pass
1571
1572 def test_expand_fold_quant_inputs(self):
1573 pass
1574
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001575 def test_shape_analysis_broadcast(self):
1576 def broadcast(a, b):
1577 return a + b
Zach DeVitoe91966a2017-08-11 17:53:33 -07001578
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001579 x = torch.randn(3, 1, 5, requires_grad=True)
1580 y = torch.randn(4, 1, 8, 5, requires_grad=True)
Zach DeVitoe91966a2017-08-11 17:53:33 -07001581
James Reed0b16b032018-07-25 16:55:09 -07001582 graph = torch.jit.script(broadcast).graph
Adam Paszkec8b246a2018-08-26 09:40:58 -07001583 torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
Shen Li10224432021-08-12 11:39:31 -07001584 FileCheck().check("Double(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph))
Zachary DeVito8cc30e42017-10-31 10:44:13 -07001585
Mikhail Zolotukhin50d55b92020-06-30 10:59:30 -07001586 def test_shape_analysis_unsqueeze_in_loop(self):
1587 input_str = """graph(%x.1 : Tensor):
1588 %4 : bool = prim::Constant[value=1]()
1589 %1 : int = prim::Constant[value=2]()
1590 %7 : int = prim::Constant[value=0]()
Thomas Viehmannea087e22021-01-08 19:59:58 -08001591 # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop
Mikhail Zolotukhin50d55b92020-06-30 10:59:30 -07001592 %x : Tensor = prim::Loop(%1, %4, %x.1)
Thomas Viehmannea087e22021-01-08 19:59:58 -08001593 # CHECK: : FloatTensor(requires_grad=0, device=cpu)):
Mikhail Zolotukhin50d55b92020-06-30 10:59:30 -07001594 block0(%i : int, %x.6 : Tensor):
Thomas Viehmannea087e22021-01-08 19:59:58 -08001595 # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze
Mikhail Zolotukhin50d55b92020-06-30 10:59:30 -07001596 %x.3 : Tensor = aten::unsqueeze(%x.6, %7)
1597 -> (%4, %x.3)
1598 return (%x)"""
1599 graph = parse_ir(input_str)
Shen Li10224432021-08-12 11:39:31 -07001600 torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False)
Mikhail Zolotukhin50d55b92020-06-30 10:59:30 -07001601 FileCheck().run(input_str, graph)
1602
nikithamalgi141f6152021-02-09 11:35:12 -08001603 def test_script_tensor_type(self):
1604 def foo(x, t: torch.dtype):
1605 return x.type(t)
1606 scr = torch.jit.script(foo)
1607 x = torch.rand(3, 4)
Shen Li10224432021-08-12 11:39:31 -07001608 for t in [torch.int8, torch.float64, torch.float32,
1609 torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
nikithamalgi141f6152021-02-09 11:35:12 -08001610 self.assertEqual(scr(x, t), foo(x, t))
1611
Mikhail Zolotukhin871bfaa2020-06-30 10:59:30 -07001612 def test_shape_analysis_masked_select(self):
1613 input_str = """graph(%0 : Float(),
1614 %1 : Bool()):
Mikhail Zolotukhin5d704652020-07-17 10:20:46 -07001615 # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select
Mikhail Zolotukhin871bfaa2020-06-30 10:59:30 -07001616 %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0
1617 return (%2)"""
1618 graph = parse_ir(input_str)
1619 x = torch.ones(1, dtype=torch.float32)[0]
1620 mask = x.ge(0.5)
1621 torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False)
1622 FileCheck().run(input_str, graph)
1623
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001624 # TODO: update verify to work with GraphExecutors
1625 @unittest.skip("verify needs to be updated to work with GraphExecutors")
Edward Z. Yangf7f37302017-10-03 09:11:31 -07001626 def test_verify(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001627 x = torch.tensor([0.4], requires_grad=True)
1628 y = torch.tensor([0.7], requires_grad=True)
Edward Z. Yangf7f37302017-10-03 09:11:31 -07001629
1630 @torch.jit.compile
1631 def f(x, y):
1632 z = torch.sigmoid(x * (x + y))
1633 w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
1634 return z, w
1635
1636 torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
1637
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001638 # TODO: adapt to a GraphExecutor test
1639 @unittest.skip("Need to instrument GraphExecutors a bit more")
Adam Paszkeea888c12017-09-01 11:13:37 -07001640 def test_flags(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001641 x, y = torch.randn(2, 2)
Adam Paszkeea888c12017-09-01 11:13:37 -07001642 y = Variable(torch.randn(2, 2))
1643
Edward Z. Yang0c403052017-09-22 07:42:04 -07001644 @torch.jit.compile
Adam Paszkeea888c12017-09-01 11:13:37 -07001645 def fn(x, y):
1646 return (x * x + y * y + x * y).sum()
1647
1648 grads = {}
1649 for rx, ry in product((True, False), repeat=2):
1650 x.requires_grad = rx
1651 y.requires_grad = ry
1652
1653 self.assertFalse(fn.has_trace_for(x, y))
1654 out = fn(x, y)
1655
1656 self.assertFalse(fn.has_trace_for(x, y))
Shen Li10224432021-08-12 11:39:31 -07001657 for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
Adam Paszkeea888c12017-09-01 11:13:37 -07001658 if not compute:
1659 continue
Shen Li10224432021-08-12 11:39:31 -07001660 grad_v, = torch.autograd.grad(out, v, retain_graph=True)
Adam Paszkeea888c12017-09-01 11:13:37 -07001661 expected_grad = grads.setdefault(name, grad_v)
1662 self.assertEqual(grad_v, expected_grad)
1663 self.assertEqual(fn.has_trace_for(x, y), rx or ry)
1664
Zach DeVitoa3fdb282017-08-18 16:56:34 -07001665 def test_python_ir(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001666 x = torch.tensor([0.4], requires_grad=True)
1667 y = torch.tensor([0.7], requires_grad=True)
Zach DeVitoa3fdb282017-08-18 16:56:34 -07001668
1669 def doit(x, y):
1670 return torch.sigmoid(torch.tanh(x * (x + y)))
1671
James Reed6e38c3b2019-11-05 17:02:40 -08001672 g, _ = torch.jit._get_trace_graph(doit, (x, y))
Shen Li10224432021-08-12 11:39:31 -07001673 self.run_pass('dce', g)
1674 self.run_pass('canonicalize', g)
Zach DeVitoa60d9bd2017-08-19 17:41:27 -07001675 g2 = torch._C.Graph()
1676 g_to_g2 = {}
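        # Manually clone g into g2 node by node, mapping every Value in g to
        # its clone in g2 so that operand references stay consistent.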
1677 for node in g.inputs():
1678 g_to_g2[node] = g2.addInput()
1679 for node in g.nodes():
Zach DeVitoef4b19f2017-11-14 03:05:32 -08001680 n_ = g2.createClone(node, lambda x: g_to_g2[x])
1681 g2.appendNode(n_)
1682 for o, no in zip(node.outputs(), n_.outputs()):
1683 g_to_g2[o] = no
Zach DeVitoa60d9bd2017-08-19 17:41:27 -07001684
1685 for node in g.outputs():
1686 g2.registerOutput(g_to_g2[node])
1687
Edward Z. Yangacc40932018-03-16 13:36:11 -04001688 t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001689 self.assertEqual(t_node.attributeNames(), ["a"])
Zach DeVitoa60d9bd2017-08-19 17:41:27 -07001690 g2.appendNode(t_node)
PyTorch MergeBotcba96362022-12-02 21:36:13 +00001691 self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
eellisonbd7fcce2019-03-06 13:41:13 -08001692 for node in g.nodes():
1693 self.assertTrue(g2.findNode(node.kind()) is not None)
Zach DeVitoa3fdb282017-08-18 16:56:34 -07001694
max25a6aab2022-05-13 22:18:49 +00001695 def test_permute_inputs_binding(self):
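        # permuteInputs should reorder the graph inputs according to the given
        # index permutation while each input keeps its original debug name.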
1696 @torch.jit.script
1697 def foo(i, j, k):
1698 pass
1699
1700 g = foo.graph
1701
1702 idxs = []
1703 for i, inp in enumerate(g.inputs()):
1704 inp.setDebugName(f"inp{i}")
1705 idxs.append(i)
1706
1707 permuted_idxs = list(np.random.permutation(idxs))
1708 g.permuteInputs(permuted_idxs)
1709 for i, inp in enumerate(g.inputs()):
1710 self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())
1711
Mike Ruberrybb8baea2022-05-09 11:24:55 +00001712 @unittest.skipIf(IS_MACOS, "Failing on MacOS only")
Elias Ellison8bc28e92022-02-24 16:40:59 -08001713 def test_python_ir_utils(self):
1714 @torch.jit.script
1715 def foo(inp):
1716 x = inp + 1
1717 y = x / 2
1718 z = y * y
1719 return z
1720
1721 add_node = foo.graph.findNode("aten::add")
1722 div_node = foo.graph.findNode("aten::div")
1723
1724 with foo.graph.insert_point_guard(add_node):
1725 with foo.graph.insert_point_guard(div_node):
1726 foo.graph.insertConstant("goodbye")
1727 foo.graph.insertConstant("hello")
1728 with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
1729 foo.graph.insertConstant("hello")
1730 FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)
1731
1732 self.assertTrue(add_node.matches(add_node.schema()))
1733 self.assertFalse(add_node.matches(div_node.schema()))
1734
1735 def test_python_ir_utils_graph(self):
1736 @torch.jit.script
1737 def unrolled_mul(x: torch.Tensor, y: int):
1738 out = x
1739 for _ in range(y - 1):
1740 out = out + x
1741 return out
1742
1743 @torch.jit.script
1744 def foo(x):
1745 return x * 4
1746
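        # Replace every Tensor-Scalar mul whose scalar operand is a constant int
        # with an inlined copy of unrolled_mul's graph, then check that the
        # rewritten graph contains only adds and still computes x * 4.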
1747 g = foo.graph
1748 muls = g.findAllNodes("aten::mul")
1749 scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
1750 mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
1751 for mul in mul_constant_int:
1752 with g.insert_point_guard(mul):
1753 outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
1754 assert len(outputs) == len(list(mul.outputs()))
1755 for new_out, old_out in zip(outputs, g.outputs()):
1756 old_out.replaceAllUsesWith(new_out)
1757 mul.destroy()
1758
1759 FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
1760 self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
1761
Michael Suo0a4117a2019-04-03 22:18:09 -07001762 @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
1763 @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
Shen Li10224432021-08-12 11:39:31 -07001764 @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
Michael Suo0a4117a2019-04-03 22:18:09 -07001765 def test_cpp(self):
1766 from cpp.jit import tests_setup
1767 tests_setup.setup()
Michael Suo374e9372020-09-18 13:54:03 -07001768 torch._C._jit_run_cpp_tests()
Michael Suo0a4117a2019-04-03 22:18:09 -07001769 tests_setup.shutdown()
1770
Edward Z. Yangf062e062017-08-21 20:46:54 -07001771 def test_batchnorm(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001772 x = torch.ones(2, 2, 2, 2)
Shen Li10224432021-08-12 11:39:31 -07001773 g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x,
1774 _force_outplace=True, return_inputs=True)
James Reedf7825002019-10-26 18:39:23 -07001775 m = self.createFunctionFromGraph(g)
Elias Ellisone5b4baa2019-02-26 08:11:47 -08001776 self.assertEqual(outputs, m(*inputs))
Edward Z. Yangf062e062017-08-21 20:46:54 -07001777
Edward Z. Yang247d50e2017-10-29 20:24:58 -07001778 def test_dropout(self):
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02001779 x = torch.ones(2, 2)
Elias Ellisone5b4baa2019-02-26 08:11:47 -08001780 with torch.random.fork_rng(devices=[]):
Shen Li10224432021-08-12 11:39:31 -07001781 g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
Elias Ellisone5b4baa2019-02-26 08:11:47 -08001782 with torch.random.fork_rng(devices=[]):
James Reedf7825002019-10-26 18:39:23 -07001783 m = self.createFunctionFromGraph(g)
Elias Ellisone5b4baa2019-02-26 08:11:47 -08001784 self.assertEqual(outputs, m(*inputs))
Edward Z. Yangdb3349f2017-09-27 11:16:44 -07001785
jiejca921112021-11-18 19:39:53 -08001786 @unittest.skipIf(not RUN_CUDA, "test requires CUDA")
1787 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
1788 def test_native_dropout_corner_case(self):
1789 with disable_autodiff_subgraph_inlining():
1790 def t(x, p: float, t: bool):
1791 o = torch.dropout(x, p, t)
1792 return o
1793
1794 jit_t = torch.jit.script(t)
1795 x = torch.randn(5).requires_grad_()
1796 FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True))
1797
1798 for train in [True, False]:
1799 for p in [0.0, 1.0]:
1800 for device in ["cuda", "cpu"]:
1801 x = torch.randn(5).to(device=device).requires_grad_()
1802 x_ref = x.detach().requires_grad_()
1803 o = jit_t(x, p, train)
1804 o_ref = t(x_ref, p, train)
1805 o.sum().backward()
1806 o_ref.sum().backward()
1807 assert(o.equal(o_ref))
1808 assert(x.grad.equal(x_ref.grad))
1809
davidriazati60642232019-12-30 11:43:04 -08001810 @slowTest
Shen Li10224432021-08-12 11:39:31 -07001811 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
Ailing Zhang88751202019-11-12 16:29:35 -08001812 def test_dropout_module_requires_grad(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07001813 with enable_profiling_mode_for_profiling_tests():
Ailing Zhang88751202019-11-12 16:29:35 -08001814 class MyModule(torch.nn.Module):
1815 def __init__(self, M):
Xuehai Pan046e88a2023-02-12 22:20:50 +00001816 super().__init__()
Ailing Zhang88751202019-11-12 16:29:35 -08001817 self.dropout = torch.nn.Dropout(0.5)
1818 self.linear = torch.nn.Linear(M, M)
1819
1820 def forward(self, input):
1821 input = self.dropout(input)
1822 output = self.linear(input)
1823 return output
1824
1825 def profile(func, X):
1826 with torch.autograd.profiler.profile() as prof:
1827 func(X)
1828 return [e.name for e in prof.function_events]
1829
1830 M = 1000
1831 scripted = torch.jit.script(MyModule(M))
1832 # To reduce confusion about expected behaviors:
1833 # requires_grad controls whether dropout is symbolically differentiated.
1834 # training controls whether bernoulli_ is called inside symbolic differentiation of dropout.
1835 # * When requires_grad == training, the expected behaviors are obvious.
1836 # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph.
1837 # But it's in a branch that's not called. That's why we have separate checks for autograd
1838 # profiler to make sure it's not run.
1839 # * When requires_grad=False and training=True, bernoulli_ must be run since it's the expected
1840 # behavior for the dropout layer in training mode. It's independent of whether graph requires
1841 # gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case.
1842 for training in (True, False):
1843 if training:
1844 scripted.train()
1845 else:
1846 scripted.eval()
1847 for requires_grad in (True, False):
1848 X = torch.randn(M, M, requires_grad=requires_grad)
1849 if requires_grad:
jiejca921112021-11-18 19:39:53 -08001850 FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True))
Shen Li10224432021-08-12 11:39:31 -07001851 self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X))
Ailing Zhang88751202019-11-12 16:29:35 -08001852
Natalia Gimelshein7d9bbd32021-10-13 19:50:13 -07001853 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph')
Animesh Jain6a586032022-08-26 20:49:43 +00001854 @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
Ailing Zhang88751202019-11-12 16:29:35 -08001855 def test_dropout_func_requires_grad(self):
1856 def dropout_training(input):
1857 return F.dropout(input, 0.5, training=True)
1858
1859 def dropout_eval(input):
1860 return F.dropout(input, 0.5, training=False)
1861
1862 def profile(func, X):
1863 with torch.autograd.profiler.profile() as prof:
1864 func(X)
1865 return [e.name for e in prof.function_events]
1866
1867 M = 1000
1868 scripted_training = torch.jit.script(dropout_training)
1869 scripted_eval = torch.jit.script(dropout_eval)
1870 # See comments in test_dropout_module_requires_grad.
Elias Ellisonda3ff5e2020-07-23 14:46:28 -07001871 with disable_autodiff_subgraph_inlining():
1872 for requires_grad in (True, False):
1873 X = torch.randn(M, M, requires_grad=requires_grad)
1874 if requires_grad:
jiejca921112021-11-18 19:39:53 -08001875 FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True))
Natalia Gimelshein7d9bbd32021-10-13 19:50:13 -07001876 self.assertIn('aten::bernoulli_', profile(scripted_training, X))
1877 self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X))
Ailing Zhang88751202019-11-12 16:29:35 -08001878
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001879    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda requires CUDA")
1880 def test_dropout_cuda(self):
1881        # Dropout AD is dispatched to _fused_dropout in the CUDA case,
1882        # which is not included in TestJitGeneratedFunctional
Ailing Zhange75f12a2020-07-06 13:07:51 -07001883 def _zero_rate(t):
1884 return torch.true_divide((t == 0).sum(), t.numel())
1885
1886 x = torch.ones(1000, 1000).cuda().requires_grad_()
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001887
Elias Ellison0e3a05e2020-05-06 11:27:59 -07001888 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07001889 @torch.jit.script
1890 def func(x):
1891 return torch.nn.functional.dropout(x)
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001892
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07001893 with freeze_rng_state():
1894 out_ref = torch.nn.functional.dropout(x)
1895 grad_ref = torch.autograd.grad(out_ref.sum(), x)
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001896
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07001897 with freeze_rng_state():
1898 out = func(x)
1899 grad = torch.autograd.grad(out.sum(), x)
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001900
Ailing Zhange75f12a2020-07-06 13:07:51 -07001901            # TODO(#40882): previously we asserted exact matches between the eager and JIT results:
1902 # self.assertEqual(out, out_ref)
1903 # self.assertEqual(grad, grad_ref)
1904 # This test was disabled during legacy -> profiling executor transition.
1905            # Currently the JIT fused result doesn't match the eager result exactly due to some changes merged in between.
1906            # We temporarily only check the statistical difference, but this should be reverted once the issue is fixed.
1907 self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4)
Shen Li10224432021-08-12 11:39:31 -07001908 self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4)
Ailing Zhanga50ba7e2019-03-19 10:20:06 -07001909
Zachary DeVito52005b52020-03-13 18:43:41 -07001910 def test_torch_ops_overloaded(self):
1911 with self.assertRaisesRegex(RuntimeError, "failed to many any schema"):
1912 torch.ops.aten.add("a", 1)
1913 self.assertEqual("ab", torch.ops.aten.add("a", "b"))
1914 a, b = torch.rand(3, 4), torch.rand(3, 4)
1915 self.assertEqual(a + b, torch.ops.aten.add(a, b))
1916 self.assertEqual(a + 1, torch.ops.aten.add(a, 1))
1917
Edward Yangcdf702b2021-08-10 07:13:24 -07001918 def test_torch_ops_kwonly(self):
1919 a, b = torch.rand(3, 4), torch.rand(3, 4)
1920 with self.assertRaisesRegex(RuntimeError, "positional argument"):
1921 torch.ops.aten.add(a, b, 2)
1922 # h/t Chillee for this ambiguous case
1923 self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1))
1924
nikithamalgi5cd73df2021-04-12 20:33:51 -07001925 def test_torch_complex(self):
1926 def fn(real, img):
1927 return torch.complex(real, img)
1928
1929 def fn_out(real, img, out):
1930 return torch.complex(real, img, out=out)
Shen Li10224432021-08-12 11:39:31 -07001931 self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), ))
1932 self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), ))
1933 self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), ))
1934 self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), ))
1935 self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001936
1937 real = torch.tensor([1, 2], dtype=torch.float32)
1938 img = torch.tensor([3, 4], dtype=torch.float32)
1939 out = torch.empty([3, 4], dtype=torch.complex64)
Shen Li10224432021-08-12 11:39:31 -07001940 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001941
1942 real = torch.tensor([5, 2], dtype=torch.float64)
1943 img = torch.tensor([3, 4], dtype=torch.float64)
1944 out = torch.empty([5, 2], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001945 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001946
1947 real = torch.ones([1, 2])
1948 img = torch.ones([1, 2])
1949 out = torch.empty([1, 2], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001950 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001951
1952 real = torch.ones([3, 8, 7])
1953 img = torch.ones([3, 8, 7])
1954 out = torch.empty([3, 8, 7], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001955 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001956
1957 real = torch.empty([3, 2, 6])
1958 img = torch.empty([3, 2, 6])
1959 out = torch.empty([3, 2, 6], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001960 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001961
1962 real = torch.zeros([1, 3])
1963 img = torch.empty([3, 1])
1964 out = torch.empty([3, 3], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001965 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001966
1967 real = torch.ones([2, 5])
1968 img = torch.empty([2, 1])
1969 out = torch.empty([2, 5], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001970 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001971
1972 real = torch.ones([2, 5])
1973 img = torch.zeros([2, 1])
1974 out = torch.empty([2, 5], dtype=torch.complex128)
Shen Li10224432021-08-12 11:39:31 -07001975 self.checkScript(fn_out, (real, img, out, ))
nikithamalgi5cd73df2021-04-12 20:33:51 -07001976
Adam Paszke120d7692018-09-11 05:58:11 -07001977 def test_einsum(self):
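        # einsum should work under both scripting and tracing with each of its
        # calling conventions: equation plus a tuple of operands, equation plus
        # varargs operands, and the sublist (operand, subscript-list) format.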
Heitor Schueroff72ae9242021-05-21 08:35:25 -07001978 def check(fn, jitted, *args):
Shen Li10224432021-08-12 11:39:31 -07001979 self.assertGraphContains(jitted.graph, kind='aten::einsum')
Heitor Schueroff72ae9242021-05-21 08:35:25 -07001980 self.assertEqual(fn(*args), jitted(*args))
1981
1982 def equation_format(x, y):
Shen Li10224432021-08-12 11:39:31 -07001983 return torch.einsum('i,j->ij', (x, y))
Adam Paszke120d7692018-09-11 05:58:11 -07001984
Heitor Schueroff8f658d52021-06-29 13:59:02 -07001985 def equation_format_varargs(x, y):
Shen Li10224432021-08-12 11:39:31 -07001986 return torch.einsum('i,j->ij', x, y)
Heitor Schueroff8f658d52021-06-29 13:59:02 -07001987
Heitor Schueroff72ae9242021-05-21 08:35:25 -07001988 def sublist_format(x, y):
1989 return torch.einsum(x, [0], y, [1], [0, 1])
1990
Philip Meier0973c5a2022-02-24 21:47:38 -08001991 x = make_tensor((5,), dtype=torch.float32, device="cpu")
1992 y = make_tensor((10,), dtype=torch.float32, device="cpu")
Heitor Schueroff72ae9242021-05-21 08:35:25 -07001993
Heitor Schueroff8f658d52021-06-29 13:59:02 -07001994 for fn in [equation_format, equation_format_varargs, sublist_format]:
1995 check(fn, torch.jit.script(fn), x, y)
1996 check(fn, torch.jit.trace(fn, (x, y)), x, y)
Adam Paszke120d7692018-09-11 05:58:11 -07001997
Animesh Jain1d90d6e2022-07-07 18:57:31 +00001998 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Yanli Zhao193ac312020-01-22 15:46:32 -08001999 def test_python_ivalue(self):
2000        # Test that a pure Python object can be held as an IValue and that
2001        # conversions between IValue and PyObject are correct
2002 # test for numpy object
2003 py_array = np.arange(15)
2004 ret_py_obj = torch._C._ivalue_debug_python_object(py_array)
2005 self.assertEqual(py_array, ret_py_obj)
2006
2007 # test for function object
2008 ret_py_obj = torch._C._ivalue_debug_python_object(F.relu)
2009 self.assertEqual(F.relu, ret_py_obj)
2010
2011 # test for memory management
2012        # we need to ensure IValue correctly calls incref/decref to avoid
2013        # dangling pointers and potential memory leaks during conversions
2014 def test_func_scope_helper(inp):
2015 # create a scope and do the conversion -> ivalue -> pyobject
2016            # this func returns a new pyobject with its refcount incremented by 1
2017 inp_refcount = sys.getrefcount(inp)
2018 ivalue_holder = torch._C._ivalue_debug_python_object(inp)
2019 self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder))
2020 return ivalue_holder + 1
2021
2022 test_input = 2200
2023 before_count = sys.getrefcount(test_input)
2024 test_func_scope_helper(test_input)
2025 after_count = sys.getrefcount(test_input)
2026
2027        # after the call to test_func_scope_helper, the refcount of
2028        # test_input should be equal to the original refcount,
2029        # otherwise we have either a dangling pointer or a memory leak!
2030 self.assertEqual(before_count, after_count)
2031
James Reed1f94a6e2018-05-30 15:06:58 -07002032 def test_decompose_addmm(self):
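        # The 'decompose_ops' pass should rewrite addmm calls that use the
        # default alpha/beta of 1.0 into simpler ops (roughly
        # mat.addmm(mat1, mat2) ~ mat + torch.mm(mat1, mat2)), while leaving
        # calls with non-constant or non-default scalars untouched.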
eellisonbd7fcce2019-03-06 13:41:13 -08002033 def does_decompose():
2034 @torch.jit.script
Wanchao Liang4d676d52019-05-08 17:18:47 -07002035 def addmm(mat, mat1, mat2):
2036 a = mat.addmm(mat1, mat2)
2037 b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
eellisonbd7fcce2019-03-06 13:41:13 -08002038 return a + b
James Reed1f94a6e2018-05-30 15:06:58 -07002039
eellisonbd7fcce2019-03-06 13:41:13 -08002040 mat = torch.randn(2, 2)
2041 mat1 = torch.randn(2, 4)
2042 mat2 = torch.randn(4, 2)
James Reed1f94a6e2018-05-30 15:06:58 -07002043
Wanchao Liang4d676d52019-05-08 17:18:47 -07002044 out_ref = addmm(mat, mat1, mat2)
Shen Li10224432021-08-12 11:39:31 -07002045 self.run_pass('decompose_ops', addmm.graph)
Wanchao Liang4d676d52019-05-08 17:18:47 -07002046 out_test = addmm(mat, mat1, mat2)
eellisonbd7fcce2019-03-06 13:41:13 -08002047 self.assertEqual(out_ref, out_test)
2048 FileCheck().check_not("addmm").run(str(addmm.graph))
2049
2050 def doesnt_decompose():
2051 @torch.jit.script
2052 def addmm(mat, mat1, mat2, alpha, beta):
Wanchao Liang4d676d52019-05-08 17:18:47 -07002053 a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
2054 b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
eellisonbd7fcce2019-03-06 13:41:13 -08002055
Wanchao Liang4d676d52019-05-08 17:18:47 -07002056 return a + b
2057
2058 orig = str(addmm.graph)
Shen Li10224432021-08-12 11:39:31 -07002059 self.run_pass('decompose_ops', addmm.graph)
eellisonbd7fcce2019-03-06 13:41:13 -08002060 self.assertTrue(orig == str(addmm.graph))
James Reed1f94a6e2018-05-30 15:06:58 -07002061
Wanchao Liang4d676d52019-05-08 17:18:47 -07002062 does_decompose()
2063 doesnt_decompose()
2064
Wanchao Liang27904392019-08-09 16:42:34 -07002065 @suppress_warnings
2066 def test_sparse_tensors(self):
davidriazati7a370db2019-07-16 12:50:02 -07002067 @torch.jit.ignore
Elias Ellison77f69982018-09-13 08:43:38 -07002068 def get_sparse():
2069 return torch.sparse.FloatTensor(2, 3)
2070
2071 @torch.jit.script
Wanchao Liang27904392019-08-09 16:42:34 -07002072 def test_is_sparse(input):
2073 # type: (Tensor) -> bool
2074 return input.is_sparse
2075
2076 script_out_is_sparse = test_is_sparse(get_sparse())
2077 script_out_is_dense = test_is_sparse(torch.randn(2, 3))
2078 self.assertEqual(script_out_is_sparse, True)
2079 self.assertEqual(script_out_is_dense, False)
2080
2081 def test_basic_sparse(input):
Elias Ellison77f69982018-09-13 08:43:38 -07002082 output = get_sparse()
2083 return output, input
2084
Wanchao Liang27904392019-08-09 16:42:34 -07002085 self.checkScript(test_basic_sparse, (get_sparse(),))
2086 self.checkScript(test_basic_sparse, (torch.tensor([1]),))
Elias Ellison77f69982018-09-13 08:43:38 -07002087
Wanchao Liang5fd32512019-08-23 15:45:53 -07002088 def test_sparse_sum(input):
2089 return torch.sparse.sum(input)
2090
2091 self.checkScript(test_sparse_sum, (get_sparse(),))
2092
2093 def test_sparse_mm(input1, input2):
2094 return torch.sparse.mm(input1, input2)
2095
2096 self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4)))
2097
2098 def test_sparse_addmm(input, input1, input2):
2099 return torch.sparse.addmm(input, input1, input2)
2100
2101 def test_sparse_addmm_alpha_beta(input, input1, input2):
Christian Puhrsch75955e42021-11-19 19:45:55 -08002102 return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5)
Wanchao Liang5fd32512019-08-23 15:45:53 -07002103
Shen Li10224432021-08-12 11:39:31 -07002104 self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
2105 self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
Wanchao Liang5fd32512019-08-23 15:45:53 -07002106
Sameer Deshmukh5fb11422021-04-12 10:07:56 -07002107 @suppress_warnings
2108 def test_sparse_csr_tensors(self):
2109 @torch.jit.ignore
2110 def get_sparse_csr():
2111 return torch.randn(3, 3).to_sparse_csr()
2112
2113 @torch.jit.script
2114 def test_is_sparse_csr(input):
2115 # type: (Tensor) -> bool
2116 return input.is_sparse_csr
2117
2118 script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr())
2119 script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3))
2120
2121 self.assertEqual(script_out_is_sparse_csr, True)
2122 self.assertEqual(script_out_is_dense_csr, False)
2123
Nikitha Malgi7a60b7d2021-03-01 21:02:38 -08002124 @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2125 def test_device_not_equal(self):
Shen Li10224432021-08-12 11:39:31 -07002126
Nikitha Malgi7a60b7d2021-03-01 21:02:38 -08002127 def compare_device(x: torch.device):
2128 return x != torch.device("cuda:0")
2129
2130 def compare_two_device(x: torch.device, y: torch.device):
2131 return x != y
2132
2133 self.checkScript(compare_device, (torch.device("cuda:0"),))
Shen Li10224432021-08-12 11:39:31 -07002134 self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), ))
Nikitha Malgi7a60b7d2021-03-01 21:02:38 -08002135
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002136 def test_constant_prop_simple(self):
2137 @torch.jit.script
Elias Ellisona386c282019-01-31 15:37:52 -08002138 def constant_prop(input_int):
2139 # type: (int) -> int
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002140 a = 2 * 3
2141 b = a + 2
Elias Ellisona386c282019-01-31 15:37:52 -08002142 return b - input_int
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002143
Elias Ellisona386c282019-01-31 15:37:52 -08002144 out_ref = constant_prop(2)
Shen Li10224432021-08-12 11:39:31 -07002145 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellisona386c282019-01-31 15:37:52 -08002146 out_test = constant_prop(2)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002147 self.assertEqual(out_ref, out_test)
Elias Ellisona386c282019-01-31 15:37:52 -08002148 graph_str = str(constant_prop.graph)
2149 self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
2150 const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
2151 self.assertEqual(const, 8)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002152
2153 def test_constant_prop_nested(self):
2154 @torch.jit.script
2155 def constant_prop(a):
2156 b = 2 + 1
David Riazati6f53b4e2018-09-13 11:10:00 -07002157 if bool(a < 2):
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002158 c = b + 2
2159 else:
2160 c = b - 2
2161 return c
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002162 out_ref = constant_prop(torch.tensor(2))
Shen Li10224432021-08-12 11:39:31 -07002163 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002164 out_test = constant_prop(torch.tensor(2))
2165 self.assertEqual(out_ref, out_test)
Elias Ellisona386c282019-01-31 15:37:52 -08002166 if_node = constant_prop.graph.findNode("prim::If")
2167 for block in if_node.blocks():
2168 for node in block.nodes():
2169 self.assertTrue(node.kind() == "prim::Constant")
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002170
2171 def test_constant_prop_print(self):
2172 @torch.jit.script
2173 def constant_prop(input_tensor):
Richard Zou67f6f932018-08-27 08:53:56 -07002174 a = 2 * 3
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002175 print(a)
2176 b = a + 2
2177 return b + input_tensor
2178
Shen Li10224432021-08-12 11:39:31 -07002179 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellisona386c282019-01-31 15:37:52 -08002180 graph = constant_prop.graph
2181 print_node = graph.findNode("prim::Print")
2182 self.assertTrue(print_node.input().toIValue() == 6)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002183
2184 def test_constant_prop_rand(self):
2185 @torch.jit.script
2186 def constant_prop():
2187 a = torch.randn([3])
2188 b = a + 2
2189 return b
2190
Shen Li10224432021-08-12 11:39:31 -07002191 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellisona386c282019-01-31 15:37:52 -08002192 self.assertTrue("aten::randn" in str(constant_prop.graph))
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002193
Elias Ellison7d601712019-01-15 10:56:17 -08002194 def test_constant_prop_none(self):
2195 @torch.jit.script
2196 def typed_none():
2197 # type: () -> Optional[int]
2198 return None
2199
2200 @torch.jit.script
2201 def constant_prop():
2202 a = typed_none()
2203 b = typed_none()
Shen Li10224432021-08-12 11:39:31 -07002204 if (a is None and b is None):
Elias Ellison7d601712019-01-15 10:56:17 -08002205 a = 2
2206 else:
2207 a = 1
2208 return a
2209
Shen Li10224432021-08-12 11:39:31 -07002210 self.run_pass('constant_propagation', constant_prop.graph)
Michael Suo755f91b2019-08-19 18:41:08 -07002211 FileCheck().check("prim::Constant").run(constant_prop.graph)
Elias Ellison7d601712019-01-15 10:56:17 -08002212
Elias Ellison87101842019-01-23 17:47:29 -08002213 def test_constant_prop_if_inline(self):
2214 @torch.jit.script
2215 def constant_prop():
2216 cond = True
2217 a = 1
2218 if cond:
2219 a = 1 * 2
2220 else:
2221 a = 1 // 0
2222 return a
2223
2224        # testing that the 1 // 0 error is not thrown
Shen Li10224432021-08-12 11:39:31 -07002225 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellison87101842019-01-23 17:47:29 -08002226
Elias Ellison44bd63c2019-08-28 15:32:36 -07002227 def test_constant_prop_exception(self):
2228 # checking y = a[4] does not error in constant propagation
2229 def bad_index(x):
2230 # type: (bool)
2231 y = 0
2232 if x:
2233 a = [1, 2, 3]
2234 y = a[4]
2235 return y
2236
2237 self.checkScript(bad_index, (False,))
2238
Elias Ellison3eefc062019-12-09 14:17:58 -08002239 def test_constant_prop_aliasing_type(self):
2240 @torch.jit.script
2241 def foo():
2242 return len([1]), len(torch.tensor([2]))
2243
2244 FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph)
2245
2246 @torch.jit.script
2247 def fn():
Elias Ellisond1b8da72020-11-20 11:14:59 -08002248 if 1 == 1:
Elias Ellison3eefc062019-12-09 14:17:58 -08002249 return 1
2250 else:
2251 return 2
2252
2253 FileCheck().check_not("prim::If").run(fn.graph)
2254
Elias Ellisone7bc1662020-01-17 12:27:35 -08002255 def test_unchecked_cast(self):
2256 def test(cond):
2257 # type: (bool)
2258 a = torch.tensor([10])
2259 if cond:
2260 b = None
2261 else:
2262 b = a
2263 if b is not None:
2264 b[0] = 5
2265 return a.int()
2266
2267 self.checkScript(test, (True,))
2268 self.checkScript(test, (False,))
2269
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002270 def test_constant_prop_if_constant(self):
2271 @torch.jit.script
Elias Ellisone1905052018-08-17 09:44:22 -07002272 def constant_prop(a, b):
2273 c0 = 1
2274 c1 = 1
2275 c2 = 1
David Riazati6f53b4e2018-09-13 11:10:00 -07002276 if bool(a): # -> c0, c1
2277 if bool(b): # -> c0
Elias Ellisond1b8da72020-11-20 11:14:59 -08002278 if 1 == 1: # -> c0
Elias Ellisone1905052018-08-17 09:44:22 -07002279 c0 = c0 + 1
Elias Ellisond1b8da72020-11-20 11:14:59 -08002280 if 1 == 2:
Elias Ellisone1905052018-08-17 09:44:22 -07002281 c1 = c1 + 1
2282 c2 = c2 + 1
2283 else: # -> c0, c1
2284 c1 = c1 + 1
2285
Elias Ellisond1b8da72020-11-20 11:14:59 -08002286 if 1 == 1: # inlined
Elias Ellisone1905052018-08-17 09:44:22 -07002287 c0 = c0 + 1 # dynamic
2288 c2 = c2 + 4 # set to 5
2289 return a + c0 + c1 + c2
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002290
Elias Ellisona386c282019-01-31 15:37:52 -08002291 graph = constant_prop.graph
Shen Li10224432021-08-12 11:39:31 -07002292 self.run_pass('constant_propagation', graph)
Elias Ellisona386c282019-01-31 15:37:52 -08002293 ifs = graph.findAllNodes("prim::If", recurse=False)
2294 snd_if_inlined = len(ifs) == 1
2295 self.assertTrue(snd_if_inlined)
2296 first_if = ifs[0]
2297 self.assertTrue(first_if.outputsSize() == 2)
2298 second_if = first_if.findNode("prim::If", recurse=False)
2299 self.assertTrue(second_if.outputsSize() == 1)
2300 self.assertTrue(second_if.findNode("prim::If") is None)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002301
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002302 def test_constant_prop_loop_constant(self):
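        # Loops whose trip count is statically known to be zero (while False,
        # range(0), range(-4)) should be removed by constant propagation, while
        # loops that may run must be preserved.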
2303 @torch.jit.script
Elias Ellison87101842019-01-23 17:47:29 -08002304 def constant_prop(cond, iter):
2305 # type: (bool, int) -> int
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002306 b = 0
2307 while True:
Elias Ellison87101842019-01-23 17:47:29 -08002308 print("stays")
2309 for _ in range(2):
2310 print("stays")
2311 for _ in range(iter):
2312 print("stays")
2313 while cond:
2314 print("stays")
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002315 while False:
Elias Ellison87101842019-01-23 17:47:29 -08002316 print("removed")
2317 for _i in range(0):
2318 print("removed")
2319 for _i in range(-4):
2320 print("removed")
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002321 return b
2322
Shen Li10224432021-08-12 11:39:31 -07002323 self.run_pass('constant_propagation', constant_prop.graph)
Elias Ellison87101842019-01-23 17:47:29 -08002324 graph = canonical(constant_prop.graph)
2325 self.assertTrue(graph.count("removed") == 0)
2326 self.assertTrue(graph.count("stays") == 1) # constant gets pooled
2327 self.assertTrue(graph.count("prim::Print") == 4)
2328
2329 def test_constant_prop_remove_output(self):
2330 @torch.jit.script
2331 def constant_prop(iter):
2332 # type: (int) -> None
2333 a = 1
2334 b = 1
2335 c = 1
2336 for i in range(iter):
Elias Ellisond1b8da72020-11-20 11:14:59 -08002337 if 1 == 2:
Elias Ellison87101842019-01-23 17:47:29 -08002338 a = 10
2339 if i == 5:
2340 b = 2
2341 c = 3
2342 print(a, b, c)
2343
2344 graph = constant_prop.graph
Shen Li10224432021-08-12 11:39:31 -07002345 self.run_pass('constant_propagation', graph)
Elias Ellison87101842019-01-23 17:47:29 -08002346 self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
Elias Ellisone57cb4a2018-07-30 15:43:29 -07002347
Yanan Cao890b52e2020-07-24 11:37:11 -07002348 # TODO(gmagogsfm): Refactor this test to reduce complexity.
Elias Ellisonadf09162020-01-22 12:09:46 -08002349 def test_constant_insertion(self):
Shen Li10224432021-08-12 11:39:31 -07002350 funcs_template = dedent('''
Elias Ellison38d122e2020-01-22 12:09:46 -08002351 def func():
2352 return {constant_constructor}
Shen Li10224432021-08-12 11:39:31 -07002353 ''')
Elias Ellisonadf09162020-01-22 12:09:46 -08002354
Elias Ellison38d122e2020-01-22 12:09:46 -08002355 # constants: primitives: int, double, bool, str, lists of primitives,
2356 # and tuples
2357 def check_constant(constant_constructor):
2358 scope = {}
2359 funcs_str = funcs_template.format(constant_constructor=constant_constructor)
2360 execWrapper(funcs_str, globals(), scope)
2361 cu = torch.jit.CompilationUnit(funcs_str)
2362 f_script = cu.func
Shen Li10224432021-08-12 11:39:31 -07002363 self.run_pass('constant_propagation', f_script.graph)
2364 FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph)
2365 self.assertEqual(scope['func'](), f_script())
Elias Ellison38d122e2020-01-22 12:09:46 -08002366 imported = self.getExportImportCopy(f_script)
2367 self.assertEqual(imported(), f_script())
Elias Ellisonadf09162020-01-22 12:09:46 -08002368
Shen Li10224432021-08-12 11:39:31 -07002369 constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)",
2370 "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']",
2371 "[True, None]", "[.5, None, .2]"]
Elias Ellison38d122e2020-01-22 12:09:46 -08002372
2373 for type in ["Tensor", "str", "int", "float", "bool"]:
2374 constants.append("torch.jit.annotate(List[ " + type + "], [])")
2375
2376 for constant in constants:
2377 check_constant(constant)
2378
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002379 for key_type in ["str", "int", "float"]:
2380 for value_type in ["Tensor", "bool", "str", "int", "float"]:
Shen Li10224432021-08-12 11:39:31 -07002381 check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})")
2382 check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})")
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002383
Elias Ellison38d122e2020-01-22 12:09:46 -08002384 for i in range(len(constants)):
2385 for j in range(i + 1, len(constants)):
2386 tup_constant = constants[i] + ", " + constants[j]
2387 check_constant(tup_constant)
2388
Elias Ellisonc86655a2020-10-07 17:33:55 -07002389 dict_constants = []
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002390 for i in range(len(constants)):
2391 # check_constant constructs the second dict with another Tensor
2392 # which fails the comparison
Elias Ellisonc86655a2020-10-07 17:33:55 -07002393 if not isinstance(eval(constants[i]), (str, int, float)):
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002394 continue
2395 for j in range(len(constants)):
2396 dict_constant = "{ " + constants[i] + ": " + constants[j] + "}"
2397 check_constant(dict_constant)
Elias Ellisonc86655a2020-10-07 17:33:55 -07002398 dict_constants.append(dict_constant)
2399 constants = constants + dict_constants
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002400
Elias Ellison38d122e2020-01-22 12:09:46 -08002401 # testing node hashing
Shen Li10224432021-08-12 11:39:31 -07002402 funcs_template = dedent('''
Elias Ellison38d122e2020-01-22 12:09:46 -08002403 def func():
2404 print({constant_constructor})
Shen Li10224432021-08-12 11:39:31 -07002405 ''')
Alexander Grund93719442020-10-22 09:42:34 -07002406 single_elem_tuples = ("(" + x + ",)" for x in constants)
Elias Ellison38d122e2020-01-22 12:09:46 -08002407 input_arg = ", ".join(single_elem_tuples)
2408 scope = {}
2409 funcs_str = funcs_template.format(constant_constructor=input_arg)
2410 execWrapper(funcs_str, globals(), scope)
2411 cu = torch.jit.CompilationUnit(funcs_str)
2412 f_script = cu.func
Shen Li10224432021-08-12 11:39:31 -07002413 self.run_pass('constant_propagation', f_script.graph)
Elias Ellison38d122e2020-01-22 12:09:46 -08002414 # prim::None return adds one constant
Shen Li10224432021-08-12 11:39:31 -07002415 self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
2416 self.run_pass('cse', f_script.graph)
Elias Ellison38d122e2020-01-22 12:09:46 -08002417        # node hashing works correctly, so no CSE occurs (all constants are distinct)
Shen Li10224432021-08-12 11:39:31 -07002418 self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
Elias Ellison38d122e2020-01-22 12:09:46 -08002419
Shen Li10224432021-08-12 11:39:31 -07002420 funcs_template = dedent('''
Elias Ellison38d122e2020-01-22 12:09:46 -08002421 def func():
2422 a = {constant_constructor}
2423 print(a)
2424 b = {constant_constructor}
2425 print(b)
Shen Li10224432021-08-12 11:39:31 -07002426 ''')
Elias Ellison38d122e2020-01-22 12:09:46 -08002427
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002428 # generate dicts with built-in types (excluding torch.Tensor)
2429 xprod = itertools.product(constants, constants)
2430
Nikolay Korovaiko0a4a5582020-03-09 16:04:54 -07002431 # test that equal tuples and dicts correctly work with node hashing
Alexander Grund93719442020-10-22 09:42:34 -07002432 for tup in ("(" + x + ",)" for x in constants):
Elias Ellison38d122e2020-01-22 12:09:46 -08002433 funcs_str = funcs_template.format(constant_constructor=tup)
2434 scope = {}
2435 execWrapper(funcs_str, globals(), scope)
2436 cu = torch.jit.CompilationUnit(funcs_str)
2437 f_script = cu.func
Shen Li10224432021-08-12 11:39:31 -07002438 self.run_pass('constant_propagation_immutable_types', f_script.graph)
Yanan Cao890b52e2020-07-24 11:37:11 -07002439 num_constants = str(f_script.graph).count("prim::Constant")
Shen Li10224432021-08-12 11:39:31 -07002440 self.run_pass('cse', f_script.graph)
2441 FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph)
Elias Ellisonadf09162020-01-22 12:09:46 -08002442
Zachary DeVitoad7936e2018-09-12 12:21:20 -07002443 @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2444 def test_cuda_export_restore(self):
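        # Round-trip a CUDA ScriptModule (nested submodule + parameter) through
        # export/import and check that both copies produce the same output.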
2445 class Sub(torch.jit.ScriptModule):
2446 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002447 super().__init__()
Zachary DeVitoad7936e2018-09-12 12:21:20 -07002448 self.weight = nn.Parameter(torch.randn(3, 4))
2449
2450 @torch.jit.script_method
2451 def forward(self, thing):
2452 return self.weight + thing
2453
2454 class M(torch.jit.ScriptModule):
2455 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002456 super().__init__()
Zachary DeVitoad7936e2018-09-12 12:21:20 -07002457 self.mod = Sub()
2458
2459 @torch.jit.script_method
2460 def forward(self, v):
2461 return self.mod(v)
2462 m = M()
2463 m.cuda()
2464 m2 = self.getExportImportCopy(m)
2465 m2.cuda()
2466 input = torch.rand(3, 4).cuda()
2467 self.assertEqual(m(input), m2(input))
2468
davidriazati60642232019-12-30 11:43:04 -08002469 @slowTest
James Reed47c1de22018-09-07 19:38:44 -07002470 def test_export_batchnorm(self):
Shen Li10224432021-08-12 11:39:31 -07002471 for mode in ['eval', 'train']:
James Reed47c1de22018-09-07 19:38:44 -07002472 for clazz in [
Shen Li10224432021-08-12 11:39:31 -07002473 torch.nn.BatchNorm1d(100),
2474 torch.nn.BatchNorm1d(100, affine=False),
2475 torch.nn.BatchNorm2d(100),
2476 torch.nn.BatchNorm2d(100, affine=False)]:
James Reed47c1de22018-09-07 19:38:44 -07002477 getattr(clazz, mode)()
Shen Li10224432021-08-12 11:39:31 -07002478 input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2479 torch.randn(20, 100, 35, 45)
James Reed47c1de22018-09-07 19:38:44 -07002480 traced = torch.jit.trace(clazz, (input,))
2481 imported = self.getExportImportCopy(traced)
Shen Li10224432021-08-12 11:39:31 -07002482 x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2483 torch.randn(20, 100, 35, 45)
James Reed47c1de22018-09-07 19:38:44 -07002484 self.assertEqual(traced(x), imported(x))
2485
2486 def test_export_rnn(self):
2487 for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2488 class RNNTest(torch.nn.Module):
2489 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002490 super().__init__()
James Reed47c1de22018-09-07 19:38:44 -07002491 self.rnn = clazz
2492
2493 def forward(self, x, lengths, h0):
2494 packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2495 out, h = self.rnn(packed, h0)
2496 padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2497 return padded_outs
2498
2499 test = RNNTest()
2500
Shen Li10224432021-08-12 11:39:31 -07002501 traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
James Reed47c1de22018-09-07 19:38:44 -07002502 imported = self.getExportImportCopy(traced)
James Reeddb0b5c72018-10-29 13:53:46 -07002503 # NB: We make sure to pass in a batch with a different max sequence
2504 # length to ensure that the argument stashing for pad_packed works
2505 # properly.
Shen Li10224432021-08-12 11:39:31 -07002506 x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
James Reed47c1de22018-09-07 19:38:44 -07002507 self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2508
2509 def test_export_lstm(self):
2510 class LSTMTest(torch.nn.Module):
2511 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002512 super().__init__()
James Reed47c1de22018-09-07 19:38:44 -07002513 self.rnn = nn.LSTM(10, 20, 2)
2514
2515 def forward(self, x, lengths, hiddens):
2516 h0, c0 = hiddens
2517 packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2518 out, (h, c) = self.rnn(packed, (h0, c0))
2519 padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2520 return padded_outs
2521
2522 test = LSTMTest()
2523
Shen Li10224432021-08-12 11:39:31 -07002524 traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
2525 torch.LongTensor([3, 2, 1]),
2526 (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
James Reed47c1de22018-09-07 19:38:44 -07002527 imported = self.getExportImportCopy(traced)
Shen Li10224432021-08-12 11:39:31 -07002528 x, lengths, h0, c0 = \
2529 torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
James Reed47c1de22018-09-07 19:38:44 -07002530 self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2531
Lu Fang0c237f12019-04-03 23:14:07 -07002532 def test_unique_state_dict(self):
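        # The same Parameter is registered under two names; _unique_state_dict
        # should report a single entry with or without keep_vars.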
2533 class MyModule(torch.nn.Module):
2534 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002535 super().__init__()
Lu Fang0c237f12019-04-03 23:14:07 -07002536 shared_param = torch.nn.Parameter(torch.ones(1))
Shen Li10224432021-08-12 11:39:31 -07002537 self.register_parameter('w1', shared_param)
2538 self.register_parameter('w2', shared_param)
Lu Fang0c237f12019-04-03 23:14:07 -07002539
2540 def forward(self, input):
2541 return input + self.w1 + self.w2
2542
2543 model = MyModule()
2544 unittest.TestCase.assertEqual(
Shen Li10224432021-08-12 11:39:31 -07002545 self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1)
Lu Fang0c237f12019-04-03 23:14:07 -07002546 unittest.TestCase.assertEqual(
Shen Li10224432021-08-12 11:39:31 -07002547 self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1)
Lu Fang0c237f12019-04-03 23:14:07 -07002548
Wanchao Liang27d78952020-06-17 17:24:52 -07002549 def test_export_dropout(self):
2550 test = torch.nn.Dropout()
2551 test.eval()
Michael Suobd75fba2019-02-01 14:36:02 -08002552
Wanchao Liang27d78952020-06-17 17:24:52 -07002553 traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
2554 imported = self.getExportImportCopy(traced)
2555 x = torch.randn(3, 4)
2556 self.assertEqual(traced(x), imported(x))
Elias Ellison9cbeb0f2020-04-15 17:32:25 -07002557
David Riazatia79f5d72018-09-18 17:36:07 -07002558 def test_pretty_printer(self):
2559 @torch.jit.script
2560 def if_test(a, b):
2561 # FIXME: use 0 instead of a.
2562 # c = 0
2563 c = a
2564 if bool(a < b):
2565 c = b
2566 else:
2567 c = a
2568 return c
2569
2570 @torch.jit.script
2571 def if_one(a, b):
2572 c = b
2573 if bool(a < b):
2574 c = a
2575 return c
2576
2577 @torch.jit.script
2578 def while_test(a, i):
2579 while bool(i < 3):
2580 a *= a
2581 i += 1
2582 return a
2583
2584 @torch.jit.script
2585 def while_if_test(a, b):
2586 c = 0
2587 while bool(a < 10):
2588 a = a + 1
2589 b = b + 1
2590 if bool(a > b):
2591 c = 2
2592 else:
2593 c = 3
David Riazatif0b73ff2018-10-04 15:00:33 -07002594 return a + 1 + c
David Riazatia79f5d72018-09-18 17:36:07 -07002595
2596 @torch.jit.script
2597 def loop_use_test(y):
2598 x = y + 1
2599 z = x + 5
2600 while bool(y < 8):
2601 y += 1
2602 z = x
2603 return x, z
2604
davidriazati7a370db2019-07-16 12:50:02 -07002605 @torch.jit.ignore
David Riazatif0b73ff2018-10-04 15:00:33 -07002606 def python_fn(x):
2607 return x + 10
2608
2609 @torch.jit.script
2610 def python_op_name_test(y):
2611 return python_fn(y)
2612
Zachary DeVito30676bd2018-11-13 16:33:51 -08002613 @torch.jit.script
2614 def empty_int_list_test(y):
2615 x = torch.jit.annotate(List[int], [])
2616 return x[0]
2617
2618 @torch.jit.script
2619 def empty_float_list_test(y):
2620 return [1.0, 2.0, 3.0]
2621
2622 @torch.jit.script
2623 def print_weird_test(y):
2624 print("hi\016")
2625
Zachary DeVito330990d2019-04-25 15:43:53 -07002626 self.assertExpected(if_test.code, "if_test")
2627 self.assertExpected(if_one.code, "if_one")
2628 self.assertExpected(while_test.code, "while_test")
2629 self.assertExpected(while_if_test.code, "while_if_test")
2630 self.assertExpected(loop_use_test.code, "loop_use_test")
2631 self.assertExpected(python_op_name_test.code, "python_op_name_test")
2632 self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
2633 self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
2634 self.assertExpected(print_weird_test.code, "print_weird_test")
Zachary DeVito30676bd2018-11-13 16:33:51 -08002635
2636 def test_cu_escaped_number(self):
Shen Li10224432021-08-12 11:39:31 -07002637 cu = torch.jit.CompilationUnit('''
Zachary DeVito30676bd2018-11-13 16:33:51 -08002638 def foo(a):
2639 print("hi\016")
Shen Li10224432021-08-12 11:39:31 -07002640 ''')
Zachary DeVito330990d2019-04-25 15:43:53 -07002641 self.assertExpected(cu.foo.code)
David Riazatia79f5d72018-09-18 17:36:07 -07002642
Zachary DeVito05731692018-11-15 15:28:56 -08002643 def test_import_method(self):
Michael Suoca1b8eb2020-07-13 16:57:41 -07002644 with torch._jit_internal._disable_emit_hooks():
Michael Suo16aa2352019-07-15 13:05:03 -07002645 class Foo(torch.jit.ScriptModule):
Michael Suo16aa2352019-07-15 13:05:03 -07002646 @torch.jit.script_method
2647 def forward(self, x, y):
2648 return 2 * x + y
2649
2650 foo = Foo()
2651 buffer = io.BytesIO()
2652 torch.jit.save(foo, buffer)
2653
2654 buffer.seek(0)
2655 foo_loaded = torch.jit.load(buffer)
2656 self.assertExpected(foo_loaded.forward.code)
Zachary DeVito05731692018-11-15 15:28:56 -08002657
Zino Benaissa690946c2020-07-09 09:07:35 -07002658 @unittest.skip("temporarily disable the test for fwd compatibility")
2659 def test_non_ascii_string(self):
2660 class Foo(torch.jit.ScriptModule):
2661 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002662 super().__init__()
Zino Benaissa690946c2020-07-09 09:07:35 -07002663 self.a = "Over \u0e55\u0e57 57"
2664
2665 @torch.jit.script_method
2666 def forward(self, x, y):
2667 return self.a + "hi\xA1"
2668
2669 foo = Foo()
2670 buffer = io.BytesIO()
2671 torch.jit.save(foo, buffer)
2672
2673 buffer.seek(0)
2674 foo_loaded = torch.jit.load(buffer)
2675 self.assertExpected(foo_loaded.forward.code)
2676
David Riazatieb5fdc52018-10-11 20:47:00 -07002677 def test_function_default_values(self):
2678 outer_var = torch.tensor(20)
2679 outer_var2 = torch.tensor(30)
2680 a = torch.tensor(0.5)
2681 b = torch.tensor(10)
2682
2683 @torch.jit.script
2684 def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2685 return x + a + b + c
2686
Shen Li10224432021-08-12 11:39:31 -07002687 self.assertEqual(
2688 simple_fn(torch.ones(1)),
2689 torch.ones(1) + 0.5 + 10 + (20 + 30))
David Riazatieb5fdc52018-10-11 20:47:00 -07002690 self.assertEqual(
2691 simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
Shen Li10224432021-08-12 11:39:31 -07002692 torch.ones(1) + 1 + 3 + 4)
David Riazatieb5fdc52018-10-11 20:47:00 -07002693
2694 outer_c = torch.tensor(9)
2695 outer_flag = torch.tensor(False)
2696
2697 @torch.jit.script
2698 def bool_fn(x, a=outer_c, flag=outer_flag):
2699 if bool(flag):
2700 result = x
2701 else:
2702 result = x + a
2703 return result
2704
David Riazatieb5fdc52018-10-11 20:47:00 -07002705 self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2706 self.assertEqual(
Shen Li10224432021-08-12 11:39:31 -07002707 bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
2708 torch.ones(1))
David Riazatieb5fdc52018-10-11 20:47:00 -07002709
2710 @torch.jit.script
Wanchao Liang7ca995c2018-10-26 11:35:12 -07002711 def none_fn(x=None):
2712 # type: (Optional[int]) -> Optional[int]
2713 return x
2714
Wanchao Liang7ca995c2018-10-26 11:35:12 -07002715 self.assertEqual(none_fn(), None)
2716 self.assertEqual(none_fn(1), 1)
2717
2718 @torch.jit.script
David Riazati1e8064d2018-10-21 14:03:48 -07002719 def hints(x, a=0.5, b=10):
David Riazatieb5fdc52018-10-11 20:47:00 -07002720 # type: (Tensor, float, int) -> Tensor
2721 return x + a + b
2722
David Riazatieb5fdc52018-10-11 20:47:00 -07002723 self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2724
2725 with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2726
2727 @torch.jit.script
Elias Ellison561037a2019-03-07 09:12:35 -08002728 def hints_bad_types(x, a=10, b=0.5): # noqa: T484
David Riazatieb5fdc52018-10-11 20:47:00 -07002729 # type: (Tensor, float, int) -> Tensor
2730 return x + a + b
Elias Ellisonfb24f7c2019-12-18 20:25:55 -08002731 with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2732 @torch.jit.script
2733 def bad_no_optional(x=None):
2734 # type: (Dict[str, int]) -> Dict[str, int]
2735 return x
2736
Shen Li10224432021-08-12 11:39:31 -07002737
David Riazatieb5fdc52018-10-11 20:47:00 -07002738 def test_module_default_values(self):
2739 four = torch.tensor(4)
2740
2741 class Test(torch.jit.ScriptModule):
David Riazatieb5fdc52018-10-11 20:47:00 -07002742 @torch.jit.script_method
2743 def forward(self, input, other=four):
2744 return input + other
2745
2746 t = Test()
2747 self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2748
Elias Ellison902c1f92019-11-14 18:26:24 -08002749 def test_mutable_default_values(self):
2750 with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2751 @torch.jit.script
2752 def foo(x=(1, [])):
2753 # type: (Tuple[int, List[Tensor]])
2754 return x
2755
2756 class Test(torch.nn.Module):
Michael Suo62b10722019-12-06 17:48:20 -08002757 def forward(self, input=[]): # noqa: B006
Elias Ellison902c1f92019-11-14 18:26:24 -08002758 return input
2759
2760 with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2761 torch.jit.script(Test())
2762
Animesh Jain1d90d6e2022-07-07 18:57:31 +00002763 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
David Riazati67271332018-10-24 16:36:56 -07002764 def test_warnings(self):
2765 import warnings
2766
David Riazati67271332018-10-24 16:36:56 -07002767 def fn(x):
2768 if bool(x < 2):
2769 warnings.warn("x is less than 2")
2770 return x
2771
Thomas Viehmann481e7f22020-02-24 10:26:18 -08002772 class M(torch.nn.Module):
2773 def forward(self, x):
2774 if bool(x < 2):
2775 warnings.warn("x is less than 2")
2776 return x
2777
Shen Li10224432021-08-12 11:39:31 -07002778
Thomas Viehmann481e7f22020-02-24 10:26:18 -08002779 scripted_mod = torch.jit.script(M())
davidriazatif172fad2019-06-05 11:03:22 -07002780 scripted_fn = torch.jit.script(fn)
2781
2782 with warnings.catch_warnings(record=True) as warns:
2783 fn(torch.ones(1))
2784
2785 with warnings.catch_warnings(record=True) as script_warns:
2786 scripted_fn(torch.ones(1))
2787
Thomas Viehmann481e7f22020-02-24 10:26:18 -08002788 with warnings.catch_warnings(record=True) as script_mod_warns:
2789 scripted_mod(torch.ones(1))
2790
Alban Desmaisonb14c5942019-11-07 08:32:51 -08002791 self.assertEqual(str(warns[0]), str(script_warns[0]))
Thomas Viehmann481e7f22020-02-24 10:26:18 -08002792 self.assertEqual(len(script_mod_warns), 1)
2793 self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message))
David Riazati67271332018-10-24 16:36:56 -07002794
Richard Zoub34ab432018-12-12 11:32:05 -08002795 def test_no_erroneous_warnings(self):
2796 import warnings
2797
2798 def fn(x):
2799 if bool(x > 0):
Shen Li10224432021-08-12 11:39:31 -07002800 warnings.warn('This should NOT be printed')
Richard Zoub34ab432018-12-12 11:32:05 -08002801 x += 1
2802 return x
2803
2804 with warnings.catch_warnings(record=True) as warns:
2805 fn_script = torch.jit.script(fn)
2806 fn_script(torch.tensor(0))
2807 warns = [str(w.message) for w in warns]
2808 self.assertEqual(len(warns), 0)
2809
Shen Li10224432021-08-12 11:39:31 -07002810 @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339")
David Riazati692898f2018-12-28 13:52:01 -08002811 def test_torch_load_error(self):
2812 class J(torch.jit.ScriptModule):
David Riazati692898f2018-12-28 13:52:01 -08002813 @torch.jit.script_method
2814 def forward(self, input):
2815 return input + 100
2816
2817 j = J()
peter8d7338e2020-12-29 09:56:01 -08002818 with TemporaryFileName() as fname:
2819 j.save(fname)
David Riazati692898f2018-12-28 13:52:01 -08002820 with self.assertRaisesRegex(RuntimeError, "is a zip"):
peter8d7338e2020-12-29 09:56:01 -08002821 torch.load(fname)
David Riazati692898f2018-12-28 13:52:01 -08002822
davidriazati7a921ba2019-08-30 16:43:45 -07002823 def test_torch_load_zipfile_check(self):
2824 @torch.jit.script
2825 def fn(x):
2826 return x + 10
2827
peter8d7338e2020-12-29 09:56:01 -08002828 with TemporaryFileName() as fname:
2829 fn.save(fname)
Shen Li10224432021-08-12 11:39:31 -07002830 with io.open(fname, 'rb') as f:
peter8d7338e2020-12-29 09:56:01 -08002831 self.assertTrue(torch.serialization._is_zipfile(f))
davidriazati7a921ba2019-08-30 16:43:45 -07002832
Andras Tantosf3a860b2019-03-12 08:46:16 -07002833 def test_python_bindings(self):
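        # Walk the optimized graph of a scripted LSTM loop and check that the
        # Python IR bindings expose typed inputs/outputs for every node and block.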
2834 lstm_cell = torch.jit.script(LSTMCellS)
2835
2836 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2837 for i in range(x.size(0)):
2838 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2839 return hx
2840
2841 slstm = torch.jit.script(lstm)
2842
Shen Li10224432021-08-12 11:39:31 -07002843 inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
Andras Tantosf3a860b2019-03-12 08:46:16 -07002844 slstm(*inputs).sum().backward()
2845 global fw_graph
2846 fw_graph = slstm.graph_for(*inputs)
Hong Xua6a72ac2020-02-21 08:29:32 -08002847 nodes = list(fw_graph.nodes())
Andras Tantosf3a860b2019-03-12 08:46:16 -07002848 tested_blocks = False
2849 for node in nodes:
Hong Xua6a72ac2020-02-21 08:29:32 -08002850 for output in node.outputs():
Shen Li10224432021-08-12 11:39:31 -07002851 self.assertTrue(hasattr(output, 'type'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002852 self.assertTrue(output.type() is not None)
Hong Xua6a72ac2020-02-21 08:29:32 -08002853 for input in node.inputs():
Shen Li10224432021-08-12 11:39:31 -07002854 self.assertTrue(hasattr(input, 'type'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002855 self.assertTrue(input.type() is not None)
Hong Xua6a72ac2020-02-21 08:29:32 -08002856 for block in node.blocks():
Andras Tantosf3a860b2019-03-12 08:46:16 -07002857 tested_blocks = True
Shen Li10224432021-08-12 11:39:31 -07002858 self.assertTrue(hasattr(block, 'inputs'))
2859 self.assertTrue(hasattr(block, 'outputs'))
Hong Xua6a72ac2020-02-21 08:29:32 -08002860 for output in block.outputs():
Shen Li10224432021-08-12 11:39:31 -07002861 self.assertTrue(hasattr(output, 'type'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002862 self.assertTrue(output.type() is not None)
Hong Xua6a72ac2020-02-21 08:29:32 -08002863 for input in block.inputs():
Shen Li10224432021-08-12 11:39:31 -07002864 self.assertTrue(hasattr(input, 'type'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002865 self.assertTrue(input.type() is not None)
Shen Li10224432021-08-12 11:39:31 -07002866 self.assertTrue(hasattr(block, 'returnNode'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002867 self.assertTrue(type(block.returnNode()) == torch._C.Node)
Shen Li10224432021-08-12 11:39:31 -07002868 self.assertTrue(hasattr(block, 'paramNode'))
Andras Tantosf3a860b2019-03-12 08:46:16 -07002869 self.assertTrue(type(block.paramNode()) == torch._C.Node)
2870 self.assertTrue(tested_blocks)
2871
Martin Yuan11854bc2019-12-20 13:37:14 -08002872 def test_export_opnames(self):
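        # export_opnames should list the ATen ops used anywhere in the module
        # hierarchy, including ops reached only through submodule methods.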
2873 class Foo(torch.jit.ScriptModule):
Martin Yuan11854bc2019-12-20 13:37:14 -08002874 def one(self, x, y):
2875 # type: (Tensor, Tensor) -> Tensor
2876 return x + y
2877
2878 def two(self, x):
2879 # type: (Tensor) -> Tensor
2880 return 2 * x
2881
2882 @torch.jit.script_method
2883 def forward(self, x):
2884 # type: (Tensor) -> Tensor
2885 return self.one(self.two(x), x)
2886
2887 class Bar(torch.jit.ScriptModule):
2888 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00002889 super().__init__()
Martin Yuan11854bc2019-12-20 13:37:14 -08002890 self.sub = Foo()
2891
Martin Yuanda4e68f2020-03-26 22:43:27 -07002892 @torch.jit.script_method
Martin Yuan11854bc2019-12-20 13:37:14 -08002893 def forward(self, x):
2894 # type: (Tensor) -> Tensor
2895 return self.sub.forward(x)
2896
2897 bar = Bar()
2898 ops = torch.jit.export_opnames(bar)
Shen Li10224432021-08-12 11:39:31 -07002899 expected = ['aten::add.Tensor', 'aten::mul.Scalar']
Martin Yuanda4e68f2020-03-26 22:43:27 -07002900 self.assertTrue(set(expected).issubset(set(ops)))
Martin Yuan11854bc2019-12-20 13:37:14 -08002901
davidriazati877b7c12019-05-06 17:37:16 -07002902 def test_pytorch_jit_env_off(self):
2903 import subprocess
2904 env = os.environ.copy()
Shen Li10224432021-08-12 11:39:31 -07002905 env['PYTORCH_JIT'] = '0'
davidriazati877b7c12019-05-06 17:37:16 -07002906 try:
Shen Li10224432021-08-12 11:39:31 -07002907 subprocess.check_output([sys.executable, '-c', 'import torch'], env=env)
davidriazati877b7c12019-05-06 17:37:16 -07002908 except subprocess.CalledProcessError as e:
Akihiro Nitta84949672020-09-14 14:15:37 -07002909 raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e
davidriazati877b7c12019-05-06 17:37:16 -07002910
Ailing Zhangf363a332019-06-18 08:43:29 -07002911 def test_print_op_module(self):
2912 # Issue #19351: python2 and python3 go through different paths.
2913 # python2 returns '<module 'torch.ops' (built-in)>'
2914 # python3 uses __file__ and return
2915 # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>'
2916 s = str(torch.ops)
Shen Li10224432021-08-12 11:39:31 -07002917 self.assertRegex(s, r'ops')
Adam Paszkea58f2d22018-03-22 16:58:36 +01002918
Yida Wang4ea6a3a2021-08-11 09:36:49 -07002919 def test_print_classes_module(self):
2920 s = str(torch.classes)
Shen Li10224432021-08-12 11:39:31 -07002921 self.assertRegex(s, r'classes')
Yida Wang4ea6a3a2021-08-11 09:36:49 -07002922
2923 def test_print_torch_ops_modules(self):
2924 s = str(torch._ops.ops.quantized)
Shen Li10224432021-08-12 11:39:31 -07002925 self.assertRegex(s, r'torch.ops')
Yida Wang4ea6a3a2021-08-11 09:36:49 -07002926 s = str(torch._ops.ops.atan)
Shen Li10224432021-08-12 11:39:31 -07002927 self.assertRegex(s, r'torch.ops')
Yida Wang4ea6a3a2021-08-11 09:36:49 -07002928
Shen Li10224432021-08-12 11:39:31 -07002929 @unittest.skipIf(IS_WINDOWS, 'TODO: fix occasional windows failure')
Ilia Cherniavskii235f6242020-05-19 15:46:56 -07002930 def test_profiler(self):
2931 prev_opt = torch._C._get_graph_executor_optimize()
2932 torch._C._set_graph_executor_optimize(False)
2933
2934 def other_fn(x):
2935 return x * 2
2936
2937 x = torch.rand(3, 4)
2938 traced_other_fn = torch.jit.trace(other_fn, x)
2939
2940 def fn(x):
2941 y = traced_other_fn(x)
2942 fut = torch.jit._fork(traced_other_fn, x)
2943 y = torch.jit._wait(fut)
2944 return y
2945
2946 traced_fn = torch.jit.trace(fn, x)
2947 with torch.autograd.profiler.profile() as prof:
2948 traced_fn(x)
2949
2950 # expecting to see other_fn TS function call
2951 # with cpu time >= mul cpu time and
2952 # a forked other_fn
2953
2954 mul_events = defaultdict(int)
2955 other_fn_events = defaultdict(int)
2956 for e in prof.function_events:
Ilia Cherniavskiie7a09b42020-07-17 22:18:35 -07002957 if e.name == "aten::mul":
Ilia Cherniavskii235f6242020-05-19 15:46:56 -07002958 self.assertTrue(e.thread not in mul_events)
Ilia Cherniavskiif7a8bf22020-11-25 04:30:15 -08002959 mul_events[e.thread] = e.time_range.elapsed_us()
Ilia Cherniavskii235f6242020-05-19 15:46:56 -07002960 elif e.name == "other_fn":
2961 self.assertTrue(e.thread not in other_fn_events)
Ilia Cherniavskiif7a8bf22020-11-25 04:30:15 -08002962 other_fn_events[e.thread] = e.time_range.elapsed_us()
Ilia Cherniavskii235f6242020-05-19 15:46:56 -07002963
2964 self.assertTrue(len(mul_events) == 2)
2965 self.assertTrue(len(other_fn_events) == 2)
2966
2967 for thread, mul_time in mul_events.items():
2968 self.assertTrue(thread in other_fn_events)
2969 self.assertTrue(other_fn_events[thread] >= mul_time)
2970
2971 torch._C._set_graph_executor_optimize(prev_opt)
2972
Joel Schlosser6557ea02021-03-04 09:08:46 -08002973 def test_hide_source_ranges_context_manager(self):
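        # Source range comments should disappear from the graph repr only while
        # inside the _hide_source_ranges() context manager.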
2974 @torch.jit.script
2975 def foo(x):
2976 return torch.add(x, x)
2977
2978 graph = foo.graph
2979 source_range_regex = "# .*\\.py"
2980 self.assertRegex(graph.__repr__(), source_range_regex)
2981 with torch.jit._hide_source_ranges():
2982 self.assertNotRegex(graph.__repr__(), source_range_regex)
2983 self.assertRegex(graph.str(print_source_ranges=True), source_range_regex)
2984 self.assertRegex(graph.__repr__(), source_range_regex)
2985
Ilia Cherniavskii235f6242020-05-19 15:46:56 -07002986
Jason Anselae57bd62023-02-14 19:06:50 +00002987@skipIfTorchDynamo()
davidriazati43c4b9f2019-08-28 17:12:28 -07002988class TestFrontend(JitTestCase):
Shen Li10224432021-08-12 11:39:31 -07002989
davidriazati43c4b9f2019-08-28 17:12:28 -07002990 def test_instancing_error(self):
2991 @torch.jit.ignore
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00002992 class MyScriptClass:
davidriazati43c4b9f2019-08-28 17:12:28 -07002993 def unscriptable(self):
2994 return "a" + 200
2995
Shen Li10224432021-08-12 11:39:31 -07002996
davidriazati43c4b9f2019-08-28 17:12:28 -07002997 class TestModule(torch.nn.Module):
davidriazati43c4b9f2019-08-28 17:12:28 -07002998 def forward(self, x):
2999 return MyScriptClass()
3000
3001 with self.assertRaises(torch.jit.frontend.FrontendError) as cm:
3002 torch.jit.script(TestModule())
3003
3004 checker = FileCheck()
3005 checker.check("Cannot instantiate class")
3006 checker.check("def forward")
3007 checker.run(str(cm.exception))
3008
tangleintel7980ed92022-10-15 05:33:07 +00003009 def test_dictionary_as_example_inputs_for_jit_trace(self):
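        # example_kwarg_inputs lets trace()/trace_module() take example inputs as a
        # dict of keyword arguments; calling with a wrong keyword should raise.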
3010 class TestModule_v1(torch.nn.Module):
tangleintel7980ed92022-10-15 05:33:07 +00003011 def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
3012 return key1 + key2 + key3
3013
3014 class TestModule_v2(torch.nn.Module):
tangleintel7980ed92022-10-15 05:33:07 +00003015 def forward(self, x, y):
3016 return x + y
3017
3018 def test_func(x, y):
3019 return x + y
3020 model_1 = TestModule_v1()
3021 model_2 = TestModule_v2()
3022 value1 = torch.ones(1)
3023 value2 = torch.ones(1)
3024 value3 = torch.ones(1)
3025 example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
3026 example_input_dict_func = {'x': value1, 'y': value2}
3027 traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
3028 traced_model_1_m = torch.jit.trace_module(
3029 model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
3030 traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
3031 traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
3032 res_1 = traced_model_1(**example_input_dict)
3033 res_1_m = traced_model_1_m(**example_input_dict)
3034 self.assertEqual(res_1, 3 * torch.ones(1))
3035 self.assertEqual(res_1_m, 3 * torch.ones(1))
3036 res_func = traced_func(**example_input_dict_func)
3037 self.assertEqual(res_func, 2 * torch.ones(1))
3038 with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
3039 res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})
3040 with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
3041 res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})
3042
davidriazati43c4b9f2019-08-28 17:12:28 -07003043
Jason Anselae57bd62023-02-14 19:06:50 +00003044@skipIfTorchDynamo()
Adam Paszkef45a3d52018-06-06 09:36:12 +02003045class TestScript(JitTestCase):
Yanan Cao75ee5752021-02-03 02:04:28 -08003046
3047 # Tests that calling torch.jit.script repeated on function is allowed.
3048 def test_repeated_script_on_function(self):
3049 @torch.jit.script
3050 @torch.jit.script
3051 def fn(x):
3052 return x
3053
3054 torch.jit.script(torch.jit.script(fn))
3055
Michael Suo92b90892020-04-28 21:27:59 -07003056 def test_pretty_print_function(self):
3057 @torch.jit.script
3058 def foo(x):
3059 return torch.nn.functional.interpolate(x)
3060
3061 FileCheck().check("interpolate").run(foo.code)
3062
Michael Suo416413d2020-02-19 15:39:18 -08003063 def test_inlined_graph(self):
3064 """
3065 Check that the `inlined_graph` property correctly returns an inlined
3066 graph, both through function calls and method calls.
3067 """
3068 @torch.jit.script
3069 def foo(x):
3070 return torch.add(x, x)
3071
3072 class MyNestedMod(torch.nn.Module):
Michael Suo416413d2020-02-19 15:39:18 -08003073 def forward(self, x):
3074 return torch.sub(x, x)
3075
Shen Li10224432021-08-12 11:39:31 -07003076
Michael Suo416413d2020-02-19 15:39:18 -08003077 class MyMod(torch.nn.Module):
3078 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003079 super().__init__()
Michael Suo416413d2020-02-19 15:39:18 -08003080 self.nested = MyNestedMod()
3081
3082 def forward(self, x):
3083 x = self.nested(x) # sub
3084 x = foo(x) # add
3085 return torch.mul(x, x)
3086
3087 m = torch.jit.script(MyMod())
Shen Li10224432021-08-12 11:39:31 -07003088 FileCheck().check("aten::sub") \
3089 .check("aten::add") \
3090 .check("aten::mul") \
3091 .run(m.inlined_graph)
Michael Suo416413d2020-02-19 15:39:18 -08003092
Michael Voznesensky960f4b52020-05-14 21:10:14 -07003093 def test_static_method_on_module(self):
3094 """
                              3095        Check that the `@staticmethod` annotation on a method of a module works.
3096 """
3097 class MyCell(torch.nn.Module):
Michael Voznesensky960f4b52020-05-14 21:10:14 -07003098 @staticmethod
3099 def do_it(x, h):
3100 new_h = torch.tanh(x + h)
3101 return new_h, new_h
3102
3103 def forward(self, x, h):
3104 return self.do_it(x, h)
3105
3106 my_cell = torch.jit.script(MyCell())
3107 x = torch.rand(3, 4)
3108 h = torch.rand(3, 4)
3109 jitted_cell = my_cell(x, h)
3110 non_jitted_cell = MyCell().do_it(x, h)
3111
3112 self.assertEqual(jitted_cell, non_jitted_cell)
3113
Michael Voznesensky91e74fd2020-04-30 20:41:44 -07003114 def test_code_with_constants(self):
3115 """
3116 Check that the `code_with_constants` property correctly returns graph CONSTANTS in the
3117 CONSTANTS.cN format used in the output of the `code` property.
3118 """
3119 @torch.jit.script
3120 def foo(x=torch.ones(1)):
3121 return x
3122
3123 class Moddy(torch.nn.Module):
Michael Voznesensky91e74fd2020-04-30 20:41:44 -07003124 def forward(self, x):
3125 return foo()
3126
3127 m = torch.jit.script(Moddy())
3128 src, CONSTANTS = m.code_with_constants
3129
3130 self.assertEqual(CONSTANTS.c0, torch.ones(1))
3131 self.assertEqual(src, m.code)
3132
3133 def test_code_with_constants_restore(self):
3134 """
                              3135        Check that the `code_with_constants` property still works after a save() + load() round trip
3136 """
3137 @torch.jit.script
3138 def foo(x=torch.ones(1)):
3139 return x
3140
3141 class Moddy(torch.nn.Module):
Michael Voznesensky91e74fd2020-04-30 20:41:44 -07003142 def forward(self, x):
3143 return foo()
3144
3145 m = torch.jit.script(Moddy())
3146 src, CONSTANTS = m.code_with_constants
3147 eic = self.getExportImportCopy(m)
3148
3149 src_eic, CONSTANTS_eic = eic.code_with_constants
3150
3151 self.assertEqual(src, src_eic)
3152 self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0)
3153
Michael Suodf1d68d2020-02-05 13:05:55 -08003154
Shen Li10224432021-08-12 11:39:31 -07003155 def test_oneline_func(self):
3156 def fn(x): return x # noqa: E704
3157
3158 self.checkScript(fn, (torch.ones(2, 2), ))
Nikolay Korovaiko7d0f0b62020-01-24 11:16:49 -08003159
3160 def test_request_bailout(self):
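        # After profiling and optimizing, request each bailout point in turn and
        # check that the de-optimized code still produces the expected result.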
Elias Ellison0e3a05e2020-05-06 11:27:59 -07003161 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko7d0f0b62020-01-24 11:16:49 -08003162
3163 def fct_loop(x):
3164 for i in range(3):
3165 x = torch.cat((x, x), 0)
3166 return x
3167
3168 x = torch.ones(2, 3, 4, dtype=torch.float32)
3169 expected = fct_loop(x)
3170 jitted = torch.jit.script(fct_loop)
3171 # profile
3172 jitted(x)
3173 # optimize
3174 jitted(x)
3175 dstate = jitted.get_debug_state()
3176 eplan = get_execution_plan(dstate)
3177 num_bailouts = eplan.code.num_bailouts()
3178
3179 for i in range(0, num_bailouts):
3180 eplan.code.request_bailout(i)
3181 self.assertEqual(jitted(x), expected)
3182
Nikolay Korovaikofe261022020-09-13 15:56:30 -07003183 @unittest.skip("bailouts are being deprecated")
Elias Ellisona55d80e2020-04-28 23:18:29 -07003184 def test_dominated_bailout(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07003185 with enable_profiling_mode_for_profiling_tests():
Elias Ellisona55d80e2020-04-28 23:18:29 -07003186 # functional dominated guard
3187 @torch.jit.script
3188 def foo(x):
3189 dim = x.dim()
3190 if dim == 0:
3191 y = int(x)
3192 else:
3193 y = x.size()[dim - 1]
3194 return y
3195
3196 x = torch.zeros(2)
3197 self.assertEqual(foo(x), 2)
3198 self.assertEqual(foo(x), 2)
3199 g = torch.jit.last_executed_optimized_graph()
3200 g_s = str(g)
Shen Li10224432021-08-12 11:39:31 -07003201 g_s = g_s[0:g_s.find("return")]
Elias Ellisona55d80e2020-04-28 23:18:29 -07003202 FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s)
3203
3204 # dominated guard of non-functional value
3205 @torch.jit.script
3206 def foo(x):
3207 dim = x.dim()
3208 x.add_(3)
3209 if dim == 0:
3210 return 0
3211 else:
3212 return x.size()[dim - 1]
3213
3214 x = torch.zeros(2)
3215 self.assertEqual(foo(x), 2)
3216 self.assertEqual(foo(x), 2)
3217 g = torch.jit.last_executed_optimized_graph()
Shen Li10224432021-08-12 11:39:31 -07003218 FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g)
Elias Ellisona55d80e2020-04-28 23:18:29 -07003219
3220 with torch.enable_grad():
3221 @torch.jit.ignore
3222 def disable_grad():
3223 torch.set_grad_enabled(False)
3224
3225 @torch.jit.ignore
3226 def enable_grad():
3227 torch.set_grad_enabled(True)
3228
3229 @torch.jit.script
3230 def foo(x):
3231 x = x + 1
3232 dim = x.dim()
3233 disable_grad()
3234 if dim == 0:
3235 y = int(x)
3236 else:
3237 y = x.size()[dim - 1]
3238 enable_grad()
3239 return y
3240
3241 x = torch.zeros(2, requires_grad=True)
3242 self.assertEqual(foo(x), 2)
3243 self.assertEqual(foo(x), 2)
3244 g = torch.jit.last_executed_optimized_graph()
3245 # there should still be a Bailout after disable_grad call
Shen Li10224432021-08-12 11:39:31 -07003246 FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g)
Elias Ellisona55d80e2020-04-28 23:18:29 -07003247
Yanbo Liang490c1cf2022-12-19 04:14:11 +00003248 @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
Shen Li10224432021-08-12 11:39:31 -07003249 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
Nikolay Korovaiko831c8f32020-05-01 12:49:18 -07003250 def test_profiling_merge(self):
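        # Two profiled runs with different first dimensions should merge into a
        # wildcard (*) profiled type rather than keeping the fixed size 1.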
3251 @torch.jit.script
3252 def test_not_const(x):
3253 if x.size(0) == 1:
3254 return 1
3255 else:
3256 return 2
3257
Elias Ellison28ac5cd2020-05-06 15:02:09 -07003258 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko9b95f752020-05-27 01:11:18 -07003259 with num_profiled_runs(2):
3260 test_not_const(torch.rand([1, 2]))
3261 test_not_const(torch.rand([2, 2]))
Nikolay Korovaiko831c8f32020-05-01 12:49:18 -07003262
Nikolay Korovaiko9b95f752020-05-27 01:11:18 -07003263 graph_str = torch.jit.last_executed_optimized_graph()
Shen Li10224432021-08-12 11:39:31 -07003264 FileCheck().check("profiled_type=Double(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3265 FileCheck().check_not("profiled_type=Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3266
Nikolay Korovaiko831c8f32020-05-01 12:49:18 -07003267
Nikolay Korovaikoe3334722019-11-21 09:51:18 -08003268 def test_nested_bailouts(self):
3269 @torch.jit.script
3270 def fct_loop(x):
3271 for i in range(3):
3272 x = torch.cat((x, x), 0)
3273 return x
3274
3275 x = torch.ones(2, 3, 4, dtype=torch.float32)
3276 out = fct_loop(x)
3277 jit_trace = torch.jit.trace(fct_loop, x)
3278 out_trace = jit_trace(x)
3279
Elias Ellisonef944962020-01-23 14:24:52 -08003280 def test_no_self_arg_ignore_function(self):
3281 class MyModule(nn.Module):
3282 @torch.jit.ignore # noqa: B902
3283 def call_np(): # noqa: B902
3284 # type: () -> int
Shen Li10224432021-08-12 11:39:31 -07003285 return np.random.choice(2, p=[.95, .05])
Elias Ellisonef944962020-01-23 14:24:52 -08003286
3287 def forward(self):
3288 return self.call_np()
3289
3290 with self.assertRaisesRegex(Exception, "does not have a self argument"):
3291 torch.jit.script(MyModule())
3292
Nikolay Korovaiko53708e22020-01-16 15:10:36 -08003293 def test_loop_liveness(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07003294 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko53708e22020-01-16 15:10:36 -08003295 @torch.jit.script
3296 def f(i):
3297 # type: (int) -> Tensor
3298 l = []
3299 for n in [2, 1]:
3300 l.append(torch.zeros(n, i))
3301
3302 return l[0]
3303
3304 f(2)
3305 f(1)
3306
Nikolay Korovaikofc3103b2019-12-19 00:32:43 -08003307 def test_bailout_loop_carried_deps_name_clash(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07003308 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaikofc3103b2019-12-19 00:32:43 -08003309 NUM_ITERATIONS = 10
Edward Yangda2004e2020-06-04 12:53:53 -07003310
Nikolay Korovaikofc3103b2019-12-19 00:32:43 -08003311 @torch.jit.script
3312 def fct_loop(z, size):
3313 # type: (int, int) -> Tuple[Tensor, List[int]]
3314 counters = torch.jit.annotate(List[int], [])
3315 j = 0
3316 y = torch.ones(2)
3317 for i in range(size):
3318 counters.append(i + j)
3319 y = torch.cat((y, torch.ones(z)), 0)
3320 j = j + 1
3321 return y, counters
3322
3323 inputs = [1, 2, 3, 4]
3324 expected = [x * 2 for x in range(NUM_ITERATIONS)]
3325 for inp in inputs:
3326 results = fct_loop(inp, NUM_ITERATIONS)
3327 self.assertEqual(results[1], expected)
3328
Nikolay Korovaikod4c25ad2019-12-03 14:57:21 -08003329 def test_bailout_loop_counter_transition(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07003330 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaikod4c25ad2019-12-03 14:57:21 -08003331 NUM_ITERATIONS = 10
Edward Yangda2004e2020-06-04 12:53:53 -07003332
Nikolay Korovaikod4c25ad2019-12-03 14:57:21 -08003333 @torch.jit.script
3334 def fct_loop(z, size):
3335 # type: (int, int) -> Tuple[Tensor, List[int]]
3336 counters = torch.jit.annotate(List[int], [])
3337 y = torch.ones(2)
3338 for i in range(size):
3339 counters.append(i)
3340 y = torch.cat((y, torch.ones(z)), 0)
3341 return y, counters
3342
3343 inputs = [1, 2, 3, 4]
3344 expected = list(range(NUM_ITERATIONS))
3345 for inp in inputs:
3346 results = fct_loop(inp, NUM_ITERATIONS)
3347 self.assertEqual(results[1], expected)
3348
davidriazati8d66f882020-04-15 17:29:24 -07003349 def test_ignored_method_binding(self):
3350 class Bar(torch.nn.Module):
3351 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003352 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07003353 self.x : int = 0
davidriazati8d66f882020-04-15 17:29:24 -07003354
3355 @torch.jit.export
Shen Li10224432021-08-12 11:39:31 -07003356 def setx(self, x : int):
davidriazati8d66f882020-04-15 17:29:24 -07003357 self.x = x
3358
3359 @torch.jit.export
3360 def getx(self):
3361 return self.x
3362
3363 @torch.jit.ignore
3364 def ignored_getx(self):
3365 return self.x
3366
3367 b = Bar()
3368 b.setx(123)
3369 sb = torch.jit.script(b)
3370 self.assertEqual(sb.getx(), 123)
3371 self.assertEqual(sb.ignored_getx(), 123)
3372
3373 sb.setx(456)
3374 self.assertEqual(sb.getx(), 456)
3375 self.assertEqual(sb.ignored_getx(), 456)
3376
Michael Suo34126272019-10-12 09:49:56 -07003377 def test_set_attribute_through_optional(self):
3378 class A(torch.nn.Module):
3379 __annotations__ = {"x": Optional[torch.Tensor]}
3380
3381 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003382 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07003383 self.x = None
3384
3385 @torch.jit.ignore
3386 def foo(self):
3387 if self.x is None:
3388 self.x = torch.tensor([3])
3389 return self.x
3390
3391 def forward(self, x):
3392 a = self.foo()
3393 return x + 1
3394
3395 m = torch.jit.script(A())
3396 self.assertEqual(m.x, None)
3397 m(torch.rand(1))
3398 self.assertEqual(m.x, torch.tensor([3]))
3399
3400 def test_mutate_constant(self):
3401 class M(torch.jit.ScriptModule):
3402 __constants__ = ["foo"]
3403
3404 def __init__(self, foo):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003405 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07003406 self.foo = foo
3407
3408 m = M(5)
Jerry Zhangebe69232020-01-04 11:07:37 -08003409 # m has a constant attribute, but we can't
3410 # assign to it
3411 with self.assertRaises(RuntimeError):
Michael Suo34126272019-10-12 09:49:56 -07003412 m.foo = 6
3413
Michael Suo34126272019-10-12 09:49:56 -07003414 def test_class_attribute(self):
3415 class M(torch.jit.ScriptModule):
3416 FOO = 0
3417
3418 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003419 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07003420 self.foo = self.FOO
3421 m = M()
3422 self.assertEqual(m.foo, M.FOO)
3423
3424 def test_class_attribute_in_script(self):
3425 class M(torch.jit.ScriptModule):
3426 FOO = 0
3427
Michael Suo34126272019-10-12 09:49:56 -07003428 @torch.jit.script_method
3429 def forward(self):
3430 return self.FOO
3431 with self.assertRaises(RuntimeError):
3432 M()
3433
3434 def test_not_initialized_err(self):
3435 class M(torch.jit.ScriptModule):
3436 def __init__(self):
3437 self.foo = torch.rand(2, 3)
3438 with self.assertRaises(RuntimeError):
3439 M()
3440
3441 def test_attribute_in_init(self):
3442 class M(torch.jit.ScriptModule):
3443 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003444 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07003445 self.foo = torch.jit.Attribute(0.1, float)
3446 # we should be able to use self.foo as a float here
3447 assert 0.0 < self.foo
3448 M()
3449
3450 def test_scriptable_fn_as_attr(self):
3451 class M(torch.nn.Module):
3452 def __init__(self, fn):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003453 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07003454 self.fn = fn
3455
3456 def forward(self, x):
3457 return self.fn(x)
3458
Michael Suobd7e9c42020-02-27 19:02:21 -08003459 m = M(torch.sigmoid)
Michael Suo34126272019-10-12 09:49:56 -07003460 inp = torch.rand(2, 3)
Shen Li10224432021-08-12 11:39:31 -07003461 self.checkModule(m, (inp, ))
Michael Suo34126272019-10-12 09:49:56 -07003462
Zachary DeVitoe41aa0e2019-05-07 19:32:15 -07003463 def test_sequence_parsing(self):
3464 tests = [
3465 ("return [x, x,]", True),
3466 ("return [x x]", "expected ]"),
3467 ("return x, x,", True),
3468 ("return bar(x, x,)", True),
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -07003469 ("return bar()", "Argument x not provided"),
Zachary DeVitoe41aa0e2019-05-07 19:32:15 -07003470 ("for a, b, in x, x,:\n pass", "List of iterables"),
Shen Li10224432021-08-12 11:39:31 -07003471 ("a, b, = x, x,\n return a + b", True)
Zachary DeVitoe41aa0e2019-05-07 19:32:15 -07003472 ]
3473 for exp, result in tests:
3474 cu = torch.jit.CompilationUnit()
3475 full = """
3476def bar(x, y):
3477 return x + y
3478def foo(x):
3479 {}
Shen Li10224432021-08-12 11:39:31 -07003480 """.format(exp)
Zachary DeVitoe41aa0e2019-05-07 19:32:15 -07003481 if isinstance(result, str):
3482 with self.assertRaisesRegex(RuntimeError, result):
3483 cu.define(full)
3484 else:
3485 cu.define(full)
3486
davidriazati4c40dbc2019-09-20 10:46:14 -07003487 def test_namedtuple_python(self):
Zachary DeVitoeb9000b2019-10-09 12:13:56 -07003488 global MyTuple, MyMod # see [local resolution in python]
Shen Li10224432021-08-12 11:39:31 -07003489 MyTuple = namedtuple('MyTuple', ['a'])
davidriazati4c40dbc2019-09-20 10:46:14 -07003490
3491 @torch.jit.unused
3492 def fn():
3493 # type: () -> MyTuple
3494 return MyTuple(1)
3495
3496 # Only check compilation
3497 @torch.jit.script
3498 def fn2():
3499 # type: () -> MyTuple
3500 return fn()
3501
3502 FileCheck().check("NamedTuple").run(fn2.graph)
3503
Elias Ellisonefaa65d2019-09-24 08:14:31 -07003504 class MyMod(torch.nn.Module):
Elias Ellisonefaa65d2019-09-24 08:14:31 -07003505 @torch.jit.unused
3506 def fn(self):
3507 # type: () -> MyTuple
3508 return MyTuple(1)
3509
3510 def forward(self, x):
Elias Ellisond1b8da72020-11-20 11:14:59 -08003511 if 1 == 1:
Michael Suo34126272019-10-12 09:49:56 -07003512 return MyTuple(torch.rand(2, 3))
3513 else:
3514 return self.fn()
Elias Ellisonefaa65d2019-09-24 08:14:31 -07003515
Michael Suo34126272019-10-12 09:49:56 -07003516 # shouldn't throw a type error
3517 torch.jit.script(MyMod())
Elias Ellisonefaa65d2019-09-24 08:14:31 -07003518
Yuxin Wu488ee372020-07-15 16:52:54 -07003519 def test_unused_decorator(self):
3520 class MyMod(torch.nn.Module):
Yuxin Wu488ee372020-07-15 16:52:54 -07003521 @torch.jit.unused
3522 @torch.no_grad()
3523 def fn(self, x):
3524 # type: (Tensor) -> int
3525 return next(x) # invalid, but should be ignored
3526
3527 def forward(self, x):
3528 return self.fn(x)
3529
3530 torch.jit.script(MyMod())
3531
davidriazati44622bb2020-03-24 13:39:12 -07003532 @_inline_everything
3533 def test_lazy_script(self):
3534 def untraceable(x):
3535 if x.ndim > 2:
3536 print("hello")
3537 else:
3538 print("goodbye")
3539 return x + 2
3540
3541 # Non-working example
3542 def fn(x):
3543 return untraceable(x)
3544
3545 with self.capture_stdout():
3546 traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)])
3547
3548 FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph)
3549
3550 # Working example
Yanan Cao6a2f40d2020-10-17 17:30:09 -07003551 untraceable = torch.jit.script_if_tracing(untraceable)
davidriazati44622bb2020-03-24 13:39:12 -07003552
3553 def fn2(x):
3554 return untraceable(x)
3555
3556 with self.capture_stdout():
3557 traced = torch.jit.trace(fn, [torch.ones(2, 2)])
3558
Elias Ellison6bc8ffe2020-04-07 09:39:56 -07003559 FileCheck().check("goodbye").run(traced.graph)
davidriazati44622bb2020-03-24 13:39:12 -07003560
Elias Ellison6468bc42020-06-23 17:10:27 -07003561 def foo(x: int):
3562 return x + 1
3563
Yanan Cao6a2f40d2020-10-17 17:30:09 -07003564 @torch.jit.script_if_tracing
Elias Ellison6468bc42020-06-23 17:10:27 -07003565 def fee(x: int = 2):
3566 return foo(1) + x
3567
3568 # test directly compiling function
3569 fee_compiled = torch.jit.script(fee)
3570 self.assertEqual(fee_compiled(), fee())
3571
3572 # test compiling it within another function
3573 @torch.jit.script
3574 def hum():
3575 return fee(x=3)
3576
3577 self.assertEqual(hum(), 5)
3578
davidriazati1f50cfc2019-12-18 13:55:08 -08003579 def test_big_int_literals(self):
3580 def ok():
3581 # signed 64 bit max
3582 a = 9223372036854775807
3583 return a
3584
3585 def toobig():
3586 a = 9223372036854775808
3587 return a
3588
3589 def waytoobig():
3590 a = 99999999999999999999
3591 return a
3592
3593 self.checkScript(ok, [])
3594
3595 with self.assertRaisesRegex(RuntimeError, "out of range"):
3596 torch.jit.script(toobig)
3597
3598 with self.assertRaisesRegex(RuntimeError, "out of range"):
3599 torch.jit.script(waytoobig)
3600
davidriazati76924942019-12-18 13:55:08 -08003601 def test_hex_literals(self):
3602 def test1():
Shen Li10224432021-08-12 11:39:31 -07003603 return 0xaaaaaa
davidriazati76924942019-12-18 13:55:08 -08003604
3605 def test2():
Shen Li10224432021-08-12 11:39:31 -07003606 return 0xaaaaaa
davidriazati76924942019-12-18 13:55:08 -08003607
3608 def test3():
Shen Li10224432021-08-12 11:39:31 -07003609 return -0xaaaaaa
davidriazati76924942019-12-18 13:55:08 -08003610
3611 self.checkScript(test1, [])
3612 self.checkScript(test2, [])
3613 self.checkScript(test3, [])
3614
3615 def ok():
3616 a = 0x7FFFFFFFFFFFFFFF
3617 return a
3618
3619 def toobig():
3620 a = 0xFFFFFFFFFFFFFFFF
3621 return a
3622
3623 def waytoobig():
3624 a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
3625 return a
3626
3627 self.checkScript(ok, [])
3628
3629 with self.assertRaisesRegex(RuntimeError, "out of range"):
3630 torch.jit.script(toobig)
3631
3632 with self.assertRaisesRegex(RuntimeError, "out of range"):
3633 torch.jit.script(waytoobig)
3634
davidriazati446e9af2019-12-23 17:18:39 -08003635 def test_big_float_literals(self):
3636 def ok():
3637 # Python interprets this as inf
Shen Li10224432021-08-12 11:39:31 -07003638 a = 1.2E400
davidriazati446e9af2019-12-23 17:18:39 -08003639 return a
3640
3641 def check(fn):
3642 self.assertTrue(fn() == ok())
3643
3644 # checkScript doesn't work since assertEqual doesn't consider
3645 # `inf` == `inf`
3646 check(torch.jit.script(ok))
3647
3648 cu = torch.jit.CompilationUnit()
3649 cu.define(dedent(inspect.getsource(ok)))
3650 check(cu.ok)
3651
davidriazatif61b45f2020-02-12 18:56:55 -08003652 def _test_device_type(self, dest):
3653 def fn(x):
3654 # type: (Device) -> Tuple[str, Optional[int]]
3655 return x.type, x.index
3656
3657 device = torch.ones(2).to(dest).device
3658 self.checkScript(fn, [device])
3659
3660 def test_device_type(self):
Shen Li10224432021-08-12 11:39:31 -07003661 self._test_device_type('cpu')
davidriazatif61b45f2020-02-12 18:56:55 -08003662
3663 @unittest.skipIf(not RUN_CUDA, "Requires CUDA")
3664 def test_device_type_cuda(self):
Shen Li10224432021-08-12 11:39:31 -07003665 self._test_device_type('cuda')
davidriazatif61b45f2020-02-12 18:56:55 -08003666
Meghan Lele751c3002021-02-02 10:53:17 -08003667 def test_string_device_implicit_conversion(self):
3668 @torch.jit.script
3669 def fn(x: torch.device):
3670 return x
3671
3672 self.assertEqual(fn("cpu"), torch.device("cpu"))
3673
3674 with self.assertRaisesRegex(RuntimeError, "Expected one of"):
3675 fn("invalid_device")
3676
davidriazati00460922019-10-07 13:50:23 -07003677 def test_eval_python(self):
3678 def _test(m):
3679 self.assertTrue(m(torch.ones(2, 2)))
3680 self.assertTrue(m.training)
Shen Li10224432021-08-12 11:39:31 -07003681 self.assertTrue(m._c.getattr('training'))
davidriazati00460922019-10-07 13:50:23 -07003682
3683 m.eval()
3684
3685 self.assertFalse(m.training)
Shen Li10224432021-08-12 11:39:31 -07003686 self.assertFalse(m._c.getattr('training'))
davidriazati00460922019-10-07 13:50:23 -07003687 self.assertFalse(m(torch.ones(2, 2)))
3688
David Reisse75fb432020-04-22 09:20:13 -07003689 buffer = io.BytesIO()
3690 torch.jit.save(m, buffer)
3691 buffer.seek(0)
davidriazati00460922019-10-07 13:50:23 -07003692
David Reisse75fb432020-04-22 09:20:13 -07003693 loaded = torch.jit.load(buffer)
davidriazati00460922019-10-07 13:50:23 -07003694
David Reisse75fb432020-04-22 09:20:13 -07003695 self.assertFalse(loaded.training)
Shen Li10224432021-08-12 11:39:31 -07003696 self.assertFalse(loaded._c.getattr('training'))
davidriazati00460922019-10-07 13:50:23 -07003697
3698 class M(nn.Module):
davidriazati00460922019-10-07 13:50:23 -07003699 def forward(self, x):
3700 return self.training
3701
3702 class OldM(torch.jit.ScriptModule):
davidriazati00460922019-10-07 13:50:23 -07003703 @torch.jit.script_method
3704 def forward(self, x):
3705 return self.training
3706
3707 _test(torch.jit.script(M()))
3708 _test(OldM())
3709
Zachary DeVito5243fe02019-05-15 12:38:47 -07003710 def test_inherit_method(self):
3711 class A(torch.jit.ScriptModule):
Zachary DeVito5243fe02019-05-15 12:38:47 -07003712 @torch.jit.script_method
3713 def forward(self, x):
3714 return x + self.bar(x)
3715
3716 class B(A):
Zachary DeVito5243fe02019-05-15 12:38:47 -07003717 @torch.jit.script_method
3718 def bar(self, x):
3719 return x * x
3720
Shen Li10224432021-08-12 11:39:31 -07003721 with self.assertRaisesRegex(RuntimeError, 'attribute'):
Zachary DeVito5243fe02019-05-15 12:38:47 -07003722 A() # cannot use because bar is not defined
3723
3724 v = torch.rand(3, 4)
3725 b = B()
3726 self.assertEqual(b(v), v + v * v)
3727
3728 class C(torch.jit.ScriptModule):
Zachary DeVito5243fe02019-05-15 12:38:47 -07003729 @torch.jit.script_method
3730 def bar(self, x):
3731 return x
3732
3733 class D(C, B):
3734 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003735 super().__init__()
Zachary DeVito5243fe02019-05-15 12:38:47 -07003736
3737 self.assertEqual(D()(v), v + v)
3738
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003739 def test_tensor_subclasses(self):
3740 def check_subclass(x, tensor):
Shen Li10224432021-08-12 11:39:31 -07003741 template = dedent("""
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003742 def func(input: {}) -> {}:
3743 return torch.zeros((input.shape[0], 1), dtype=input.dtype)
Shen Li10224432021-08-12 11:39:31 -07003744 """)
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003745
3746 self._check_code(template.format(x, x), "func", [tensor])
3747
3748 check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]]))
Shen Li10224432021-08-12 11:39:31 -07003749 check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]]))
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003750 check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]]))
Shen Li10224432021-08-12 11:39:31 -07003751 check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]]))
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003752
3753 def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor:
3754 return torch.zeros((input.shape[0], 1), dtype=input.dtype)
3755
3756 with warnings.catch_warnings(record=True) as warns:
3757 scripted = torch.jit.script(check_subclass_warn)
Shen Li10224432021-08-12 11:39:31 -07003758 FileCheck().check("TorchScript will treat type annotations of Tensor").run(str(warns[0]))
Tugsbayasgalan Manlaibaatar10abbb82021-04-07 12:08:58 -07003759
Zachary DeVitode31f672019-06-08 20:54:17 -07003760 def test_first_class_module(self):
Zachary DeVito972ec672019-06-16 14:24:02 -07003761 class Foo(torch.jit.ScriptModule):
3762 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003763 super().__init__()
Zachary DeVito972ec672019-06-16 14:24:02 -07003764 self.foo = nn.Parameter(torch.rand(3, 4))
Zachary DeVitode31f672019-06-08 20:54:17 -07003765
Zachary DeVito972ec672019-06-16 14:24:02 -07003766 @torch.jit.script_method
3767 def forward(self, input):
3768 self.foo = input
3769 return self.foo
3770 foo = Foo()
3771 input = torch.rand(3, 4)
3772 foo.forward(input)
3773 self.assertEqual(input, foo.foo)
Zachary DeVito5243fe02019-05-15 12:38:47 -07003774
Michael Suo755f91b2019-08-19 18:41:08 -07003775 @_tmp_donotuse_dont_inline_everything
Zachary DeVitoea822d92019-06-08 20:54:18 -07003776 def test_first_class_calls(self):
Michael Suo755f91b2019-08-19 18:41:08 -07003777 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00003778 class Foo:
Michael Suo755f91b2019-08-19 18:41:08 -07003779 def __init__(self, x):
3780 self.bar = x
Zachary DeVitoea822d92019-06-08 20:54:18 -07003781
Michael Suo755f91b2019-08-19 18:41:08 -07003782 def stuff(self, x):
3783 return self.bar + x
Zachary DeVitoea822d92019-06-08 20:54:18 -07003784
Michael Suo755f91b2019-08-19 18:41:08 -07003785 @torch.jit.script
3786 def foo(x):
3787 return x * x + Foo(x).stuff(2 * x)
Zachary DeVitoea822d92019-06-08 20:54:18 -07003788
Michael Suo755f91b2019-08-19 18:41:08 -07003789 @torch.jit.script
3790 def bar(x):
3791 return foo(x) * foo(x)
Zachary DeVitoea822d92019-06-08 20:54:18 -07003792
Michael Suo755f91b2019-08-19 18:41:08 -07003793 x = torch.rand(3, 4)
3794 self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x))
Zachary DeVitoea822d92019-06-08 20:54:18 -07003795
davidriazati8cdc2622019-10-16 10:34:54 -07003796 def test_static_methods(self):
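        # Modules whose forward calls @staticmethod helpers (including helpers
        # defined on another module class) should be scriptable.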
3797 class M(nn.Module):
davidriazati8cdc2622019-10-16 10:34:54 -07003798 @staticmethod
3799 def my_method(x):
3800 return x + 100
3801
3802 def forward(self, x):
3803 return x + M.my_method(x)
3804
3805 class N(nn.Module):
davidriazati8cdc2622019-10-16 10:34:54 -07003806 @staticmethod
3807 def my_method(x):
3808 return x * 100
3809
3810 def forward(self, x):
3811 return x - M.my_method(x) + N.my_method(x)
3812
3813 self.checkModule(M(), (torch.ones(2, 2),))
3814
3815 self.checkModule(N(), (torch.ones(2, 2),))
3816
Nikolay Korovaikocbf2a4f2019-05-29 11:50:40 -07003817 def test_invalid_prefix_annotation(self):
3818 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3819 with self.capture_stdout() as captured:
3820 @torch.jit.script
3821 def invalid_prefix_annotation1(a):
Shen Li10224432021-08-12 11:39:31 -07003822 #type: (Int) -> Int # noqa: E265
Nikolay Korovaikocbf2a4f2019-05-29 11:50:40 -07003823 return a + 2
3824
3825 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3826 with self.capture_stdout() as captured:
3827 @torch.jit.script
3828 def invalid_prefix_annotation2(a):
Shen Li10224432021-08-12 11:39:31 -07003829 #type : (Int) -> Int # noqa: E265
Nikolay Korovaikocbf2a4f2019-05-29 11:50:40 -07003830 return a + 2
3831
3832 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3833 with self.capture_stdout() as captured:
3834 @torch.jit.script
3835 def invalid_prefix_annotation3(a):
3836 # type: (Int) -> Int
3837 return a + 2
3838
davidriazati148bcd32019-12-18 15:22:44 -08003839 def test_builtin_function_attributes(self):
3840 class Add(nn.Module):
3841 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00003842 super().__init__()
davidriazati148bcd32019-12-18 15:22:44 -08003843 self.add = torch.add
3844
3845 def forward(self, input):
3846 return self.add(input, input)
3847
3848 self.checkModule(Add(), [torch.randn(2, 2)])
3849
davidriazati6e13a772020-04-03 11:30:29 -07003850 def test_pybind_type_comparisons(self):
3851 @torch.jit.script
3852 def f():
3853 return None
3854
3855 node = list(f.graph.nodes())[0]
3856 t = node.outputsAt(0).type()
Sam Estepe3900d22021-04-19 13:14:27 -07003857 self.assertIsNotNone(t)
davidriazati6e13a772020-04-03 11:30:29 -07003858
Nikita Shulga5976f0b2023-01-29 18:28:46 +00003859 @unittest.skipIf(IS_WINDOWS, 'TODO: need to fix the test case')
Hiroshi Ogawa97b39a22019-10-16 10:36:06 -07003860 def test_unmatched_type_annotation(self):
Shen Li10224432021-08-12 11:39:31 -07003861 message1 = re.escape("Number of type annotations (2) did not match the number of function parameters (1):")
3862 message2 = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
3863 message3 = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
Hiroshi Ogawa97b39a22019-10-16 10:36:06 -07003864 with self.assertRaisesRegex(RuntimeError, message1):
3865 @torch.jit.script
3866 def invalid1(a):
3867 # type: (Int, Int) -> Int
3868 return a + 2
3869
3870 with self.assertRaisesRegex(RuntimeError, message2):
3871 @torch.jit.script
3872 def invalid2(a):
3873 # type: (Int, Int) -> Int
3874 return a + 2
3875
3876 with self.assertRaisesRegex(RuntimeError, message1):
3877 def invalid3(a):
3878 # type: (Int, Int) -> Int
3879 return a + 2
3880 torch.jit.script(invalid3)
3881
3882 with self.assertRaisesRegex(RuntimeError, message3):
3883 def invalid4(a):
3884 # type: (Int, Int) -> Int
3885 return a + 2
3886 torch.jit.script(invalid4)
3887
Nikita Shulga767f6aa2022-11-17 22:05:27 +00003888 def test_calls_in_type_annotations(self):
3889 with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
3890 def spooky(a):
3891 # type: print("Hello") -> Tensor # noqa: F723
3892 return a + 2
3893 print(torch.__file__)
3894 torch.jit.annotations.get_signature(spooky, None, 1, True)
3895
davidriazatid0fff0e2019-09-24 10:41:43 -07003896 def test_is_optional(self):
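        # Smoke test: is_optional should accept a Union of list types without
        # raising (the return value is not asserted here).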
3897 ann = Union[List[int], List[float]]
3898 torch._jit_internal.is_optional(ann)
3899
Zachary DeVito18996a82019-06-08 20:54:17 -07003900 def test_interpreter_fuzz(self):
Nikita Shulgac6b69a42020-06-15 08:14:54 -07003901 import builtins
Zachary DeVito18996a82019-06-08 20:54:17 -07003902 # This test generates random tree-like programs to fuzz test
3903 # that the interpreter does not have a bug in its stack manipulation
3904 # code. An assert in that code ensures individual operators are
3905 # not reordered.
3906 templates = [
3907 "torch.rand(3, 4)",
3908 "({} + {})",
3909 "-{}",
3910 "({} * {})",
3911 "torch.tanh({})",
3912 "VAR {}",
3913 ]
3914
3915 def gen_code():
Shen Li10224432021-08-12 11:39:31 -07003916 src_lines = ['def f():']
Zachary DeVito18996a82019-06-08 20:54:17 -07003917 exprs = []
3918 n_variables = 0
3919
3920 def get_expr(idx):
3921 elem = exprs[idx]
3922 exprs[idx] = exprs[-1]
3923 exprs.pop()
3924 return elem
3925
3926 def select_expr_or_var():
3927 idx = random.randrange(0, len(exprs) + n_variables)
3928 if idx < len(exprs):
3929 return get_expr(idx)
3930 else:
Shen Li10224432021-08-12 11:39:31 -07003931 return 'v{}'.format(idx - len(exprs))
Zachary DeVito18996a82019-06-08 20:54:17 -07003932
3933 for i in range(50):
3934 n = None
3935 while n is None or n > len(exprs) + n_variables:
3936 template = random.choice(templates)
Shen Li10224432021-08-12 11:39:31 -07003937 n = template.count('{}')
Zachary DeVito18996a82019-06-08 20:54:17 -07003938
Shen Li10224432021-08-12 11:39:31 -07003939 if 'VAR' in template:
3940 src_lines.append(' v{} = {}'.format(n_variables, select_expr_or_var()))
Zachary DeVito18996a82019-06-08 20:54:17 -07003941 n_variables += 1
3942 else:
Shen Li10224432021-08-12 11:39:31 -07003943 exprs.append(template.format(*(select_expr_or_var() for _ in range(n))))
Zachary DeVito18996a82019-06-08 20:54:17 -07003944
Shen Li10224432021-08-12 11:39:31 -07003945 src_lines.append(' return ({})\n'.format(''.join('v{},'.format(i) for i in range(n_variables))))
3946 return '\n'.join(src_lines)
Zachary DeVito18996a82019-06-08 20:54:17 -07003947
3948 for i in range(100):
Shen Li10224432021-08-12 11:39:31 -07003949 g = {'torch': torch}
Zachary DeVito18996a82019-06-08 20:54:17 -07003950 code = gen_code()
Nikita Shulgac6b69a42020-06-15 08:14:54 -07003951 builtins.exec(code, g, None)
Zachary DeVito18996a82019-06-08 20:54:17 -07003952 cu = torch.jit.CompilationUnit(code)
3953 with freeze_rng_state():
Shen Li10224432021-08-12 11:39:31 -07003954 o1 = g['f']()
Zachary DeVito18996a82019-06-08 20:54:17 -07003955 with freeze_rng_state():
3956 o2 = cu.f()
3957 self.assertEqual(o1, o2)
3958
Animesh Jain1d90d6e2022-07-07 18:57:31 +00003959 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Zachary DeVito79636312019-11-06 22:56:46 -08003960 def test_cpp_module_iterator(self):
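        # Build a small module tree and check that the C++ debug iterators
        # report the expected buffers, parameters, children and attributes.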
3961 a = nn.Module()
Shen Li10224432021-08-12 11:39:31 -07003962 a.name = 'a'
Zachary DeVito79636312019-11-06 22:56:46 -08003963 a.p = nn.Parameter(torch.rand(3, 4))
3964 a.foo = nn.Module()
Shen Li10224432021-08-12 11:39:31 -07003965 a.foo.name = 'foo'
3966 a.foo.register_buffer('b', torch.rand(1, 1))
Zachary DeVito79636312019-11-06 22:56:46 -08003967 a.foo.bar = nn.Module()
Shen Li10224432021-08-12 11:39:31 -07003968 a.foo.bar.name = 'bar'
Zachary DeVito79636312019-11-06 22:56:46 -08003969 a.foo.bar.an_int = 4
3970 a.another = nn.Module()
Shen Li10224432021-08-12 11:39:31 -07003971 a.another.name = 'another'
Zachary DeVito79636312019-11-06 22:56:46 -08003972 sa = torch.jit.script(a)
3973 result = torch._C._jit_debug_module_iterators(sa._c)
3974
3975 def replace(e):
3976 if e is a.p:
Shen Li10224432021-08-12 11:39:31 -07003977 return 'P'
Zachary DeVito79636312019-11-06 22:56:46 -08003978 elif e is a.foo.b:
Shen Li10224432021-08-12 11:39:31 -07003979 return 'B'
Zachary DeVito79636312019-11-06 22:56:46 -08003980 elif isinstance(e, torch._C.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07003981 return e.getattr('name')
Zachary DeVito79636312019-11-06 22:56:46 -08003982
3983 return e
3984 for k, v in result.items():
3985 for i in range(len(v)):
3986 if isinstance(v[i], tuple):
3987 n, v2 = v[i]
3988 v[i] = (n, replace(v2))
3989 else:
3990 v[i] = replace(v[i])
3991 # module type creation is not deterministic, so we have to sort
3992 # the result
3993 v.sort()
Shen Li10224432021-08-12 11:39:31 -07003994 expected = {'buffers': [],
3995 'buffers_r': ['B'],
3996 'children': ['another', 'foo'],
3997 'modules': ['a', 'another', 'bar', 'foo'],
3998 'named_attributes': [('_is_full_backward_hook', None),
3999 ('another', 'another'),
4000 ('foo', 'foo'),
4001 ('name', 'a'),
4002 ('p', 'P'),
4003 ('training', True)],
4004 'named_attributes_r': [('_is_full_backward_hook', None),
4005 ('another', 'another'),
4006 ('another._is_full_backward_hook', None),
4007 ('another.name', 'another'),
4008 ('another.training', True),
4009 ('foo', 'foo'),
4010 ('foo._is_full_backward_hook', None),
4011 ('foo.b', 'B'),
4012 ('foo.bar', 'bar'),
4013 ('foo.bar._is_full_backward_hook', None),
4014 ('foo.bar.an_int', 4),
4015 ('foo.bar.name', 'bar'),
4016 ('foo.bar.training', True),
4017 ('foo.name', 'foo'),
4018 ('foo.training', True),
4019 ('name', 'a'),
4020 ('p', 'P'),
4021 ('training', True)],
4022 'named_buffers': [],
4023 'named_buffers_r': [('foo.b', 'B')],
4024 'named_children': [('another', 'another'), ('foo', 'foo')],
4025 'named_modules': [('', 'a'),
4026 ('another', 'another'),
4027 ('foo', 'foo'),
4028 ('foo.bar', 'bar')],
4029 'named_parameters': [('p', 'P')],
4030 'named_parameters_r': [('p', 'P')],
4031 'parameters': ['P'],
4032 'parameters_r': ['P']}
Zachary DeVito79636312019-11-06 22:56:46 -08004033 self.assertEqual(expected, result)
4034
Adam Paszkee3d50c42020-03-09 10:21:50 -07004035 def test_parameter_order(self):
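        # Scripting should preserve the order in which parameters were
        # registered on the module.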
4036 m = nn.Module()
4037 for i, name in enumerate(string.ascii_letters):
4038 setattr(m, name, nn.Parameter(torch.tensor([float(i)])))
4039 ms = torch.jit.script(m)
4040 print(torch.cat(list(m.parameters())))
4041 print(torch.cat(list(ms.parameters())))
4042 self.assertEqual(list(m.parameters()), list(ms.parameters()))
4043
davidriazatief8d1c52019-09-24 17:58:31 -07004044 def test_python_op_builtins(self):
4045 @torch.jit.unused
4046 def fn(x):
4047 # type: (List[int]) -> int
4048 return sum(x)
4049
4050 @torch.jit.script
4051 def script_fn(x):
4052 # type: (List[int]) -> int
4053 return fn(x)
4054
Nikolay Korovaiko7ddd5d02019-05-03 16:09:36 -07004055 def test_submodule_twice(self):
Zachary DeVito53458c92019-04-05 13:33:14 -07004056 @torch.jit.script
4057 def foo(x):
4058 return x * x
4059
4060 class What(torch.jit.ScriptModule):
4061 def __init__(self, x):
Xuehai Pan046e88a2023-02-12 22:20:50 +00004062 super().__init__()
Zachary DeVito53458c92019-04-05 13:33:14 -07004063 self.foo = x
4064 a = What(foo)
4065 c = What(foo)
4066
Zachary DeVitobb546b22018-12-03 20:29:51 -08004067 def test_training_param(self):
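        # self.training should be usable (and re-usable) inside script methods,
        # respond to train(False), and not appear in the state_dict.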
4068 class What(torch.jit.ScriptModule):
4069 @torch.jit.script_method
4070 def forward(self, x):
4071 # type: (int) -> int
4072 if self.training:
4073 r = x
4074 else:
4075 r = x + 4
4076 # check double use of training
4077 if self.training:
4078 r = r + 1
4079 return r
4080
4081 w = What()
4082 self.assertEqual(4, w(3))
4083 w.train(False)
4084 self.assertEqual(7, w(3))
davidriazati61cc03f2019-06-06 11:55:44 -07004085 self.assertFalse("training" in w.state_dict())
Zachary DeVitobb546b22018-12-03 20:29:51 -08004086
James Reedd68592a2020-01-27 20:36:13 -08004087 def test_class_as_attribute(self):
4088 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004089 class Foo321:
James Reedd68592a2020-01-27 20:36:13 -08004090 def __init__(self):
4091 self.x = 3
4092
4093 class FooBar1234(torch.nn.Module):
4094 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00004095 super().__init__()
James Reedd68592a2020-01-27 20:36:13 -08004096 self.f = Foo321()
4097
4098 def forward(self, x):
4099 return x + self.f.x
4100
4101 scripted = torch.jit.script(FooBar1234())
4102 eic = self.getExportImportCopy(scripted)
4103 x = torch.rand(3, 4)
4104 self.assertEqual(scripted(x), eic(x))
4105
James Reeda5539352020-05-11 14:59:11 -07004106 def test_module_str(self):
4107 class Foo(torch.nn.Module):
4108 def forward(self, x):
4109 return torch.relu(x)
4110
4111 f = torch.jit.script(Foo())
Shen Li10224432021-08-12 11:39:31 -07004112 self.assertEqual('ScriptObject', str(f._c))
James Reeda5539352020-05-11 14:59:11 -07004113
Zachary DeVitofd31eae2018-11-29 17:51:45 -08004114 def test_jitter_bug(self):
4115 @torch.jit.script
4116 def fn2(input, kernel_size):
4117 # type: (Tensor, List[int]) -> Tensor
4118 if kernel_size[0] > 1:
4119 _stride = [2]
4120 else:
4121 _stride = kernel_size
4122 print(_stride, kernel_size)
4123 return input
4124
4125 @torch.jit.script
4126 def fn(input):
4127 # type: (Tensor) -> Tensor
4128 return fn2(input, [1])
4129
Zachary DeVito4c6da642019-02-21 15:24:23 -08004130 def test_parser_kwargonly(self):
Shen Li10224432021-08-12 11:39:31 -07004131 cu = torch.jit.CompilationUnit('''
Zachary DeVito4c6da642019-02-21 15:24:23 -08004132 def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4133 return x, x
4134 def bar(x):
4135 return foo(x, y=x)
Shen Li10224432021-08-12 11:39:31 -07004136 ''')
4137 self.assertTrue('*' in str(cu.foo.schema))
Zachary DeVito4c6da642019-02-21 15:24:23 -08004138 with self.assertRaisesRegex(RuntimeError, "not provided"):
Shen Li10224432021-08-12 11:39:31 -07004139 torch.jit.CompilationUnit('''
Zachary DeVito4c6da642019-02-21 15:24:23 -08004140 def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4141 return x, x
4142 def bar(x):
4143 return foo(x, x)
Shen Li10224432021-08-12 11:39:31 -07004144 ''')
Zachary DeVito4c6da642019-02-21 15:24:23 -08004145
Zachary DeVito788d2e82018-11-21 06:36:26 -08004146 def test_annoying_doubles(self):
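        # Special floating-point values (inf, -inf, nan, and a value near
        # DBL_MIN) must round-trip exactly through save/load of a script module.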
4147 mod = types.ModuleType("temp")
4148 mod.inf = float("inf")
4149 mod.ninf = float("-inf")
4150 mod.nan = float("nan")
4151
Michael Suoca1b8eb2020-07-13 16:57:41 -07004152 with torch._jit_internal._disable_emit_hooks():
Michael Suo16aa2352019-07-15 13:05:03 -07004153 class Foo(torch.jit.ScriptModule):
Michael Suo16aa2352019-07-15 13:05:03 -07004154 @torch.jit.script_method
4155 def forward(self):
Shen Li10224432021-08-12 11:39:31 -07004156 return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
Michael Suo16aa2352019-07-15 13:05:03 -07004157
4158 foo = Foo()
4159 buffer = io.BytesIO()
4160 torch.jit.save(foo, buffer)
4161
4162 buffer.seek(0)
4163 foo_loaded = torch.jit.load(buffer)
4164
Zachary DeVito788d2e82018-11-21 06:36:26 -08004165 r = foo()
Michael Suo16aa2352019-07-15 13:05:03 -07004166 r2 = foo_loaded()
Zachary DeVito788d2e82018-11-21 06:36:26 -08004167 # use precise assert, we are checking floating point details
4168 self.assertTrue(r[:-1] == r2[:-1])
4169 self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1]))
4170
Zachary DeVito44fb23a2018-11-08 20:22:57 -08004171 def test_type_annotate(self):
Shen Li10224432021-08-12 11:39:31 -07004172
Zachary DeVito44fb23a2018-11-08 20:22:57 -08004173 def foo(a):
4174 return torch.jit.annotate(torch.Tensor, a)
4175
4176 self.checkScript(foo, (torch.rand(3),))
4177
4178 def bar():
4179 a = torch.jit.annotate(List[int], [])
James Reedde6bb3f2019-01-26 17:39:34 -08004180 for _ in range(10):
Zachary DeVito44fb23a2018-11-08 20:22:57 -08004181 a.append(4)
4182 return a
4183
4184 self.checkScript(bar, ())
4185
Zachary DeVito05731692018-11-15 15:28:56 -08004186 def baz(a):
4187 return torch.jit.annotate(float, a)
4188 self.checkScript(baz, (torch.rand(()),))
Zachary DeVito44fb23a2018-11-08 20:22:57 -08004189
Wanchao Liangac00e852019-02-07 10:32:02 -08004190 # test annotate none types
4191 def annotate_none():
4192 return torch.jit.annotate(Optional[torch.Tensor], None)
4193
Wanchao Liangac00e852019-02-07 10:32:02 -08004194 self.checkScript(annotate_none, ())
Wanchao Liangac00e852019-02-07 10:32:02 -08004195
Shen Li10224432021-08-12 11:39:31 -07004196
Zachary DeVito22c9bc32018-08-28 11:19:39 -07004197 def test_robust_op_resolution(self):
4198 neg = torch.add # misleading name to make sure we resolve by function
4199
4200 def stuff(x):
4201 return neg(x, x)
4202
4203 a = (torch.rand(3),)
4204 self.checkScript(stuff, a)
4205
davidriazatiee288312020-02-21 08:40:10 -08004206 def test_nested_aug_assign(self):
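        # Augmented assignment on nested module attributes should dispatch to
        # __iadd__ when present, fall back to __add__ otherwise, and match
        # eager-mode results; classes with neither should fail to compile.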
4207 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004208 class SomeClass:
davidriazatiee288312020-02-21 08:40:10 -08004209 def __init__(self):
4210 self.num = 99
4211
4212 def __iadd__(self, x):
4213 # type: (int)
4214 self.num += x
4215 return self
4216
4217 def __eq__(self, other):
4218 # type: (SomeClass) -> bool
4219 return self.num == other.num
4220
4221 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004222 class SomeOutOfPlaceClass:
davidriazatiee288312020-02-21 08:40:10 -08004223 def __init__(self):
4224 self.num = 99
4225
4226 def __add__(self, x):
4227 # type: (int)
4228 self.num = x
4229 return self
4230
4231 def __eq__(self, other):
4232 # type: (SomeClass) -> bool
4233 return self.num == other.num
4234
4235 class Child(nn.Module):
4236 def __init__(self):
4237 super().__init__()
4238 self.x = 2
4239 self.o = SomeClass()
4240 self.oop = SomeOutOfPlaceClass()
4241 self.list = [1, 2, 3]
4242
4243 class A(nn.Module):
4244 def __init__(self):
4245 super().__init__()
4246 self.child = Child()
4247
4248 def forward(self):
4249 self.child.x += 1
4250 self.child.o += 5
4251 self.child.oop += 5
4252 some_list = [1, 2]
4253 self.child.list += some_list
4254 self.child.list *= 2
4255 return self.child.x, self.child.o, self.child.list, self.child.oop
4256
4257 a = A()
4258 sa = torch.jit.script(A())
4259 eager_result = a()
4260 script_result = sa()
4261 self.assertEqual(eager_result, script_result)
4262 self.assertEqual(a.child.x, sa.child.x)
4263 self.assertEqual(a.child.o, sa.child.o)
4264 self.assertEqual(a.child.list, sa.child.list)
4265
4266 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004267 class SomeNonAddableClass:
davidriazatiee288312020-02-21 08:40:10 -08004268 def __init__(self):
4269 self.num = 99
4270
4271 def __eq__(self, other):
4272 # type: (SomeClass) -> bool
4273 return self.num == other.num
4274
4275 # with self.assertRaisesRegex(RuntimeError, "")
4276 class A(nn.Module):
4277 def __init__(self):
4278 super().__init__()
4279 self.x = SomeNonAddableClass()
4280
4281 def forward(self):
4282 self.x += SomeNonAddableClass()
4283 return self.x
4284
4285 with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4286 torch.jit.script(A())
4287
davidriazaticea0cc82020-02-26 18:58:28 -08004288 def test_var_aug_assign(self):
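        # Augmented assignment on local variables should work for classes
        # defining __iadd__ or __add__, for lists, and for tensors; classes
        # with neither operator cannot be used with += in script.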
4289 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004290 class SomeNonAddableClass:
davidriazaticea0cc82020-02-26 18:58:28 -08004291 def __init__(self):
4292 self.num = 99
4293
4294 def __eq__(self, other):
4295 # type: (SomeNonAddableClass) -> bool
4296 return self.num == other.num
4297
4298 with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4299 @torch.jit.script
4300 def fn():
4301 a = SomeNonAddableClass()
4302 a += SomeNonAddableClass()
4303 return a
4304
4305 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004306 class SomeClass:
davidriazaticea0cc82020-02-26 18:58:28 -08004307 def __init__(self):
4308 self.num = 99
4309
4310 def __iadd__(self, x):
4311 # type: (int)
4312 self.num += x
4313 return self
4314
4315 def __eq__(self, other):
4316 # type: (SomeClass) -> bool
4317 return self.num == other.num
4318
4319 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004320 class SomeOutOfPlaceClass:
davidriazaticea0cc82020-02-26 18:58:28 -08004321 def __init__(self):
4322 self.num = 99
4323
4324 def __add__(self, x):
4325 # type: (int)
4326 self.num = x
4327 return self
4328
4329 def __eq__(self, other):
4330 # type: (SomeClass) -> bool
4331 return self.num == other.num
4332
4333 def fn2():
4334 a = SomeClass()
4335 a_copy = a
4336 a += 20
4337 assert a is a_copy
4338 b = SomeOutOfPlaceClass()
4339 b_copy = b
4340 b += 99
4341 assert b is b_copy
4342 c = [1, 2, 3]
4343 c_copy = c
4344 c *= 2
4345 assert c is c_copy
4346 c += [4, 5, 6]
4347 d = torch.ones(2, 2)
4348 d_copy = d
4349 d += torch.ones(2, 2)
4350 assert d is d_copy
4351 return a, b, c, d
4352
4353 self.checkScript(fn2, [])
4354
Zachary DeVito478803a2018-09-26 16:55:07 -07004355 def test_nested_list_construct(self):
4356 def foo():
4357 return [[4]] + [[4, 5]]
4358 self.checkScript(foo, ())
4359
James Reed76deb452019-05-30 15:34:53 -07004360 def test_file_line_error(self):
4361 def foobar(xyz):
4362 return torch.blargh(xyz)
4363
4364 _, lineno = inspect.getsourcelines(foobar)
Shen Li10224432021-08-12 11:39:31 -07004365 with self.assertRaisesRegex(RuntimeError, "test_jit.py\", line {}".format(lineno + 1)):
James Reed76deb452019-05-30 15:34:53 -07004366 scripted = torch.jit.script(foobar)
4367
4368 def test_file_line_error_class_defn(self):
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00004369 class FooBar:
James Reed76deb452019-05-30 15:34:53 -07004370 def baz(self, xyz):
4371 return torch.blargh(xyz)
4372
4373 _, lineno = inspect.getsourcelines(FooBar)
Shen Li10224432021-08-12 11:39:31 -07004374 with self.assertRaisesRegex(RuntimeError, "test_jit.py\", line {}".format(lineno + 2)):
James Reed76deb452019-05-30 15:34:53 -07004375 torch.jit.script(FooBar)
4376
James Reeddaa1e2d2019-05-31 18:10:49 -07004377 def test_file_line_graph(self):
4378 def foobar(xyz):
4379 return torch.neg(xyz)
4380
4381 scripted = torch.jit.script(foobar)
4382
4383 _, lineno = inspect.getsourcelines(foobar)
Shen Li10224432021-08-12 11:39:31 -07004384 fc = FileCheck().check('test_jit.py:{}:19'.format(lineno + 1))
James Reed45d23052019-06-04 16:06:23 -07004385 fc.run(scripted.graph)
4386 fc.run(str(scripted.graph))
James Reeddaa1e2d2019-05-31 18:10:49 -07004387
James Reed619261d2019-05-31 23:37:38 -07004388 def test_file_line_save_load(self):
4389 class Scripted(torch.jit.ScriptModule):
4390 @torch.jit.script_method
4391 def forward(self, xyz):
4392 return torch.neg(xyz)
4393
4394 scripted = Scripted()
4395
4396 # NB: not using getExportImportCopy because that takes a different
4397 # code path that calls CompilationUnit._import rather than
4398 # going through the full save/load pathway
4399 buffer = scripted.save_to_buffer()
4400 bytesio = io.BytesIO(buffer)
4401 scripted = torch.jit.load(bytesio)
4402
James Reede63bfb72019-10-08 10:25:12 -07004403 _, lineno = inspect.getsourcelines(Scripted)
Shen Li10224432021-08-12 11:39:31 -07004404 fc = FileCheck().check(':{}'.format(lineno + 3))
James Reed45d23052019-06-04 16:06:23 -07004405 fc.run(scripted.graph)
4406 fc.run(str(scripted.graph))
James Reed619261d2019-05-31 23:37:38 -07004407
4408 def test_file_line_string(self):
Shen Li10224432021-08-12 11:39:31 -07004409 scripted = torch.jit.CompilationUnit('''
James Reed619261d2019-05-31 23:37:38 -07004410def foo(xyz):
4411 return torch.neg(xyz)
Shen Li10224432021-08-12 11:39:31 -07004412 ''')
James Reed619261d2019-05-31 23:37:38 -07004413
Shen Li10224432021-08-12 11:39:31 -07004414 fc = FileCheck().check('<string>:3:11')
James Reed45d23052019-06-04 16:06:23 -07004415 fc.run(scripted.foo.graph)
4416 fc.run(str(scripted.foo.graph))
James Reed619261d2019-05-31 23:37:38 -07004417
Edward Z. Yangee955b82022-04-19 19:56:43 -07004418 @skipIfCrossRef
James Reed4ac732e2019-06-03 11:09:14 -07004419 def test_file_line_trace(self):
4420 def foobar(xyz):
4421 return torch.neg(xyz)
4422
4423 scripted = torch.jit.trace(foobar, (torch.rand(3, 4)))
4424
4425 _, lineno = inspect.getsourcelines(foobar)
Shen Li10224432021-08-12 11:39:31 -07004426 fc = FileCheck().check('test_jit.py:{}:0'.format(lineno + 1))
James Reed45d23052019-06-04 16:06:23 -07004427 fc.run(scripted.graph)
4428 fc.run(str(scripted.graph))
James Reed4ac732e2019-06-03 11:09:14 -07004429
James Reedffa15d22019-07-01 21:11:12 -07004430 def test_serialized_source_ranges(self):
Shen Li10224432021-08-12 11:39:31 -07004431
James Reedffa15d22019-07-01 21:11:12 -07004432 class FooTest(torch.jit.ScriptModule):
4433 @torch.jit.script_method
4434 def forward(self, x, w):
4435 return torch.mm(x, w.t())
4436
4437 ft = FooTest()
4438 loaded = self.getExportImportCopy(ft)
4439 _, lineno = inspect.getsourcelines(FooTest)
4440
Shen Li10224432021-08-12 11:39:31 -07004441 with self.assertRaisesRegex(RuntimeError, 'test_jit.py\", line {}'.format(lineno + 3)):
James Reedffa15d22019-07-01 21:11:12 -07004442 loaded(torch.rand(3, 4), torch.rand(30, 40))
4443
James Reede63bfb72019-10-08 10:25:12 -07004444 def test_serialized_source_ranges_graph(self):
Shen Li10224432021-08-12 11:39:31 -07004445
James Reede63bfb72019-10-08 10:25:12 -07004446 class FooTest3(torch.jit.ScriptModule):
4447 @torch.jit.script_method
4448 def forward(self, x, w):
4449 return torch.mm(x, w.t())
4450
4451 ft = FooTest3()
4452 loaded = self.getExportImportCopy(ft)
4453 _, lineno = inspect.getsourcelines(FooTest3)
4454
Shen Li10224432021-08-12 11:39:31 -07004455 fc = FileCheck().check('test_jit.py:{}'.format(lineno + 3))
James Reede63bfb72019-10-08 10:25:12 -07004456 fc.run(loaded.graph)
4457
James Reedffa15d22019-07-01 21:11:12 -07004458 def test_serialized_source_ranges2(self):
Shen Li10224432021-08-12 11:39:31 -07004459
James Reedffa15d22019-07-01 21:11:12 -07004460 class FooTest2(torch.jit.ScriptModule):
4461 @torch.jit.script_method
4462 def forward(self):
Shen Li10224432021-08-12 11:39:31 -07004463 raise RuntimeError('foo')
James Reedffa15d22019-07-01 21:11:12 -07004464
4465 _, lineno = inspect.getsourcelines(FooTest2)
4466
Shen Li10224432021-08-12 11:39:31 -07004467 with self.assertRaisesRegex(torch.jit.Error, 'test_jit.py\", line {}'.format(lineno + 3)):
James Reedffa15d22019-07-01 21:11:12 -07004468 ft = FooTest2()
4469 loaded = self.getExportImportCopy(ft)
4470 loaded()
4471
4472 def test_serialized_source_ranges_dont_jitter(self):
4473 class FooTest3(torch.jit.ScriptModule):
4474 @torch.jit.script_method
4475 def forward(self, lim):
4476 first = 1
4477 second = 1
4478 i = 1
4479 somenum = 5
4480 dontmutateme = 3
4481 third = 0
4482 while bool(i < lim):
4483 third = first + second
4484 first = second
4485 second = third
4486 j = 0
4487 while j < 10:
4488 somenum = somenum * 2
4489 j = j + 1
4490 i = i + j
4491 i = i + dontmutateme
4492
4493 st = second + third
4494 fs = first + second
4495 return third, st, fs
4496
4497 ft3 = FooTest3()
4498
Michael Suoa69a62c2019-08-14 11:11:52 -07004499 def debug_records_from_mod(self, mod):
James Reedffa15d22019-07-01 21:11:12 -07004500 buffer = io.BytesIO()
4501 torch.jit.save(ft3, buffer)
4502 buffer.seek(0)
4503 archive = zipfile.ZipFile(buffer)
Shen Li10224432021-08-12 11:39:31 -07004504 files = filter(lambda x: x.startswith('archive/code/'), archive.namelist())
4505 debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files))
Michael Suoa69a62c2019-08-14 11:11:52 -07004506 self.assertEqual(len(debug_files), 1)
4507 debug_file = archive.open(debug_files[0])
James Reedffa15d22019-07-01 21:11:12 -07004508 return pickle.load(debug_file), buffer
4509
Michael Suoa69a62c2019-08-14 11:11:52 -07004510 records1, buffer = debug_records_from_mod(self, ft3)
James Reedffa15d22019-07-01 21:11:12 -07004511
4512 buffer.seek(0)
4513 loaded = torch.jit.load(buffer)
Michael Suoa69a62c2019-08-14 11:11:52 -07004514 records2, buffer = debug_records_from_mod(self, loaded)
James Reedffa15d22019-07-01 21:11:12 -07004515
4516 buffer.seek(0)
4517 loaded2 = torch.jit.load(buffer)
Michael Suoa69a62c2019-08-14 11:11:52 -07004518 records3, _ = debug_records_from_mod(self, loaded2)
James Reedffa15d22019-07-01 21:11:12 -07004519
4520 self.assertEqual(records1, records2)
4521 self.assertEqual(records2, records3)
4522
James Reed23e526e2019-07-26 17:43:55 -07004523 def test_serialized_source_ranges_no_dups(self):
4524 class FooTest3(torch.jit.ScriptModule):
4525 @torch.jit.script_method
4526 def forward(self, lim):
4527 first = 1
4528 second = 1
4529 i = 1
4530 somenum = 5
4531 dontmutateme = 3
4532 third = 0
4533 while bool(i < lim):
4534 third = first + second
4535 first = second
4536 second = third
4537 j = 0
4538 while j < 10:
4539 somenum = somenum * 2
4540 j = j + 1
4541 i = i + j
4542 i = i + dontmutateme
4543
4544 st = second + third
4545 fs = first + second
4546 return third, st, fs
4547
4548 ft3 = FooTest3()
4549
4550 def debug_records_from_mod(mod):
4551 buffer = io.BytesIO()
4552                torch.jit.save(mod, buffer)
4553 buffer.seek(0)
4554 archive = zipfile.ZipFile(buffer)
Shen Li10224432021-08-12 11:39:31 -07004555 files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
4556 debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
Alexander Grund93719442020-10-22 09:42:34 -07004557 debug_files = (archive.open(f) for f in debug_files)
4558 debug_files = (pickle.load(f) for f in debug_files)
Han Qiaca55942022-05-03 18:42:30 +00004559 debug_files = (f[2] for f in debug_files)
Michael Suo77c08aa2019-08-11 15:43:28 -07004560 return list(debug_files)
James Reed23e526e2019-07-26 17:43:55 -07004561
Michael Suo77c08aa2019-08-11 15:43:28 -07004562 debug_files = debug_records_from_mod(ft3)
Alban Desmaison3bd15072022-02-24 07:47:48 -08004563 for debug_file in debug_files:
Michael Suo77c08aa2019-08-11 15:43:28 -07004564 for i in range(len(debug_file) - 1):
Kimish Patelf4a92162021-05-04 09:17:43 -07004565 offset, source_range_tag, source_range = debug_file[i]
4566 offset2, source_range_tag2, source_range2 = debug_file[i + 1]
Michael Suo77c08aa2019-08-11 15:43:28 -07004567 self.assertNotEqual(source_range, source_range2)
James Reed23e526e2019-07-26 17:43:55 -07004568
Zachary DeVito0e3389d2019-09-26 11:36:53 -07004569 def test_circular_dependency(self):
4570 """
4571 https://github.com/pytorch/pytorch/issues/25871
4572 """
4573 class A(torch.jit.ScriptModule):
Zachary DeVito0e3389d2019-09-26 11:36:53 -07004574 @torch.jit.script_method
4575 def forward(self, x):
4576 return x
4577
4578 class B(torch.jit.ScriptModule):
4579 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00004580 super().__init__()
Zachary DeVito0e3389d2019-09-26 11:36:53 -07004581 self.foo = torch.nn.ModuleList([A()])
4582
4583 @torch.jit.script_method
4584 def forward(self, x):
4585 for f in self.foo:
4586 x = f(x)
4587 return x
4588
4589 class C(torch.jit.ScriptModule):
4590 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00004591 super().__init__()
Zachary DeVito0e3389d2019-09-26 11:36:53 -07004592 self.foo = torch.nn.Sequential(B())
4593
4594 @torch.jit.script_method
4595 def forward(self, x):
4596 for f in self.foo:
4597 x = f(x)
4598 return x
4599 self.getExportImportCopy(C())
4600
Zachary DeVitob6bb6442020-04-24 15:12:12 -07004601 def test_serialize_long_lines(self):
4602 class OrderModuleLong(torch.nn.Module):
4603 def forward(self, long_arg_name: List[torch.Tensor]):
4604 return [(long_arg_name[1],), (long_arg_name[0].argmax(),)]
4605 src = str(torch.jit.script(OrderModuleLong()).code)
4606        # make sure long_arg_name[1] does not get reordered after the argmax
4607 FileCheck().check("long_arg_name[1]").check("argmax").run(src)
Zachary DeVito0e3389d2019-09-26 11:36:53 -07004608
Xiang Gao97eec332018-10-12 11:27:02 -07004609 def test_tensor_shape(self):
4610 x = torch.empty(34, 56, 78)
4611
4612 def f(x):
4613 return x.shape
4614
4615 self.checkScript(f, (x,))
4616
Shen Li10224432021-08-12 11:39:31 -07004617
Nikolay Korovaiko88c0d882020-07-04 13:57:10 -07004618 def test_block_input_grad_in_loop(self):
4619
4620 x = torch.randn(3, 3, requires_grad=False)
4621 y = torch.randn(3, 3, requires_grad=True)
4622
4623 def grad_in_loop(x, y):
4624 for i in range(100):
4625 x = y @ x
4626 return x
4627
4628 scripted = torch.jit.script(grad_in_loop)
4629 outer = scripted.graph_for(x, y)
4630 loop = outer.findNode("prim::Loop")
4631 loop_block = next(loop.blocks())
4632 param_node = loop_block.paramNode()
4633 x_value = list(param_node.outputs())[1]
4634 self.assertTrue(x_value.requires_grad())
4635
Elias Ellison105fa582018-11-28 18:12:22 -08004636 def test_tensor_grad(self):
Wanchao Liangca7e2a72019-08-09 16:42:34 -07004637 x = torch.randn(3, 4, requires_grad=True)
4638 y = torch.randn(3, 4, requires_grad=False)
Elias Ellison105fa582018-11-28 18:12:22 -08004639
Wanchao Liangca7e2a72019-08-09 16:42:34 -07004640 def f_requires_grad(x):
Elias Ellison105fa582018-11-28 18:12:22 -08004641 return x.requires_grad
4642
Wanchao Liangca7e2a72019-08-09 16:42:34 -07004643 self.checkScript(f_requires_grad, (x,))
4644 self.checkScript(f_requires_grad, (y,))
4645
4646 def f_grad(x):
4647 return x.grad
4648
4649 x.sum().backward()
4650 self.checkScript(f_grad, (x,))
4651 self.checkScript(f_grad, (y,))
4652
Shen Li10224432021-08-12 11:39:31 -07004653 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy")
Nikolay Korovaiko82238582020-07-04 14:08:49 -07004654 def test_prim_grad_undefined(self):
4655
4656 x = torch.ones(2)
4657
4658 def f_grad(x):
4659 return x.grad
4660
4661 scripted = self.checkScript(f_grad, (x,))
4662 g = scripted.graph_for(x)
4663
4664 prim_grad_node = g.findNode("prim::grad")
4665 self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None)
4666
Wanchao Liangca7e2a72019-08-09 16:42:34 -07004667 def test_tensor_data(self):
Wanchao Lianga7eaec62019-09-10 14:33:28 -07004668 x = torch.randn(3, 4, requires_grad=True)
4669 y = torch.randn(4, 5)
Wanchao Liangca7e2a72019-08-09 16:42:34 -07004670
4671 def f_data(x):
4672 return x.data
4673
Wanchao Lianga7eaec62019-09-10 14:33:28 -07004674 scripted_f_data = torch.jit.script(f_data)
4675
4676 scripted_x = scripted_f_data(x)
4677 self.assertEqual(scripted_x, f_data(x))
4678 self.assertEqual(scripted_x.requires_grad, False)
4679
4680 scripted_y = scripted_f_data(y)
4681 self.assertEqual(scripted_y, f_data(y))
4682 self.assertEqual(scripted_x.requires_grad, False)
4683
Xiang Gao97eec332018-10-12 11:27:02 -07004684 def test_tensor_dtype(self):
4685 x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
4686 x_long = torch.empty(34, 56, 78, dtype=torch.long)
4687 x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
4688
4689 @torch.jit.script
4690 def byte(x):
4691 return x.dtype == torch.uint8
4692
4693 @torch.jit.script
4694 def long(x):
4695 return x.dtype == torch.long
4696
4697 @torch.jit.script
4698 def float32(x):
4699 return x.dtype == torch.float32
4700
4701 self.assertTrue(byte(x_byte))
4702 self.assertFalse(byte(x_long))
4703 self.assertFalse(byte(x_float32))
4704 self.assertFalse(long(x_byte))
4705 self.assertTrue(long(x_long))
4706 self.assertFalse(long(x_float32))
4707 self.assertFalse(float32(x_byte))
4708 self.assertFalse(float32(x_long))
4709 self.assertTrue(float32(x_float32))
4710
4711 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4712 def test_tensor_device(self):
Shen Li10224432021-08-12 11:39:31 -07004713 cpu = torch.empty(34, 56, 78, device='cpu')
4714 gpu = torch.empty(34, 56, 78, device='cuda')
Xiang Gao97eec332018-10-12 11:27:02 -07004715
4716 @torch.jit.script
4717 def same_device(x, y):
4718 return x.device == y.device
4719
4720 self.assertTrue(same_device(cpu, cpu))
4721 self.assertTrue(same_device(gpu, gpu))
4722 self.assertFalse(same_device(cpu, gpu))
4723
Zachary DeVito78d594f2018-12-03 16:52:34 -08004724 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4725 def test_tensor_to_device(self):
4726 def to_device(x):
4727 return x.to(device="cuda").to(device=torch.device("cpu"))
4728
4729 self.checkScript(to_device, (torch.ones(3, 4),))
4730
Elias Ellison3d0d16d2019-01-11 10:00:37 -08004731 def test_tensor_to_cpu(self):
4732 def to_cpu(x):
4733 return x.cpu()
4734
4735 x = torch.ones(3, 4)
4736 script_fn = torch.jit.script(to_cpu)
4737 self.assertEqual(to_cpu(x).device, script_fn(x).device)
4738 self.checkScript(to_cpu, (x,))
4739
4740 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4741 def test_tensor_to_cuda(self):
4742 def to_cuda(x):
4743 return x.cuda()
4744
4745 x = torch.ones(3, 4)
4746 script_fn = torch.jit.script(to_cuda)
4747 self.assertEqual(to_cuda(x).device, script_fn(x).device)
4748 self.checkScript(to_cuda, (x,))
4749
Zachary DeVito478803a2018-09-26 16:55:07 -07004750 def test_generic_list_errors(self):
4751 with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
Zachary DeVitoab3a2d22018-09-13 11:05:09 -07004752 @torch.jit.script
4753 def foo(x):
Zachary DeVito478803a2018-09-26 16:55:07 -07004754 return [[x]] + [[1]]
Zachary DeVitoab3a2d22018-09-13 11:05:09 -07004755
Adam Paszkea58f2d22018-03-22 16:58:36 +01004756 def test_script_cu(self):
Shen Li10224432021-08-12 11:39:31 -07004757 cu = torch.jit.CompilationUnit('''
Adam Paszkea58f2d22018-03-22 16:58:36 +01004758 def foo(a):
4759 b = a
4760 return b
Shen Li10224432021-08-12 11:39:31 -07004761 ''')
Adam Paszkea58f2d22018-03-22 16:58:36 +01004762 a = Variable(torch.rand(1))
4763 self.assertEqual(a, cu.foo(a))
James Reed60415cf2018-03-02 15:03:44 -08004764
Elias Ellison170d2972018-08-02 10:59:36 -07004765    # Because the compilation unit ingests Python source as strings,
4766    # escape the backslash so that an escape sequence reaches the parser (\\n = \n).
4767 def test_string_cu(self):
Shen Li10224432021-08-12 11:39:31 -07004768 cu = torch.jit.CompilationUnit('''
Elias Ellison170d2972018-08-02 10:59:36 -07004769 def foo(a):
4770 print(a, """a\\n\tb\\n""", 2, "a\
4771a")
4772 return a
Shen Li10224432021-08-12 11:39:31 -07004773 ''')
Elias Ellison4d2f6f12019-03-12 11:25:37 -07004774 FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
Elias Ellison170d2972018-08-02 10:59:36 -07004775
Elias Ellisond38f9112019-12-04 12:43:38 -08004776 def test_function_compilation_caching(self):
4777 def fun():
4778 return 1 + 2
4779
4780 fun_compiled = torch.jit.script(fun)
4781 # python wrapper around the script function is a different pointer,
4782 # but the underlying script function graph is the same
4783 self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph)
4784
4785 def fun():
4786 return 3 + 4
4787
4788 num_ref_counts = sys.getrefcount(fun)
4789
4790 # caching doesn't get tripped up by same qualname
4791 fun_compiled_2 = torch.jit.script(fun)
4792 self.assertIsNot(fun_compiled, fun_compiled_2)
4793 self.assertEqual(fun_compiled_2(), 7)
4794
4795        # caching doesn't increase refcounts to the function (it holds a weak reference)
4796        self.assertEqual(sys.getrefcount(fun), num_ref_counts)
4797
Elias Ellison9e6a6952018-10-29 10:11:34 -07004798 def test_string_ops(self):
4799 def foo():
4800 a = "a" + "b"
4801 return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"
4802
4803 self.checkScript(foo, ())
4804
Yanan Cao2558e572020-09-17 11:18:24 -07004805 def test_string_sorted(self):
Yanan Cao317b9d32020-08-04 15:24:14 -07004806 def foo(strs: List[str]):
4807 return sorted(strs)
4808
Shen Li10224432021-08-12 11:39:31 -07004809 FileCheck() \
4810 .check("graph") \
4811 .check_next("str[] = aten::sorted") \
4812 .check_next("return") \
4813 .run(str(torch.jit.script(foo).graph))
Yanan Cao43613b42020-08-11 13:24:21 -07004814
Yanan Cao317b9d32020-08-04 15:24:14 -07004815 inputs = ["str3", "str2", "str1"]
Yanan Cao2558e572020-09-17 11:18:24 -07004816 self.checkScript(foo, (inputs,))
4817
4818 def test_string_sort(self):
4819 def foo(strs: List[str]):
4820 strs.sort()
4821 return strs
4822
4823 inputs = ["str3", "str2", "str1"]
4824 self.checkScript(foo, (inputs,))
4825
4826 def test_tuple_sorted(self):
4827 def foo(tups: List[Tuple[int, int]]):
4828 return sorted(tups)
4829
4830 inputs = [(1, 2), (0, 2), (1, 3)]
4831 self.checkScript(foo, (inputs,))
4832
4833 def test_tuple_sort(self):
4834 def foo(tups: List[Tuple[int, int]]):
4835 tups.sort()
4836 return tups
4837
4838 inputs = [(1, 2), (0, 2), (1, 3)]
4839 self.checkScript(foo, (inputs,))
4840
4841 def test_tuple_sort_reverse(self):
4842 def foo(tups: List[Tuple[int, int]]):
4843 tups.sort(reverse=True)
4844 return tups
4845
4846 inputs = [(1, 2), (0, 2), (1, 3)]
4847 self.checkScript(foo, (inputs,))
4848
4849 def test_tuple_unsortable_element_type(self):
4850 @torch.jit.script
4851 def foo():
4852 tups = [({1: 2}, {2: 3})]
4853 tups.sort()
4854 return tups
4855
Shen Li10224432021-08-12 11:39:31 -07004856 with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"):
Yanan Cao2558e572020-09-17 11:18:24 -07004857 foo()
4858
4859 def test_tuple_unsortable_diff_type(self):
4860 @torch.jit.script
4861 def foo(inputs: List[Any]):
4862 inputs.sort()
4863 return inputs
4864
4865 inputs = [(1, 2), ("foo", "bar")]
Shen Li10224432021-08-12 11:39:31 -07004866 with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
Yanan Cao2558e572020-09-17 11:18:24 -07004867 foo(inputs)
4868
4869 def test_tuple_nested_sort(self):
4870 def foo(inputs: List[Tuple[int, Tuple[int, str]]]):
4871 inputs.sort()
4872 return inputs
4873
4874 inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))]
4875 self.checkScript(foo, (inputs,))
4876
4877 def test_tuple_unsortable_nested_diff_type(self):
4878 @torch.jit.script
4879 def foo(inputs: List[Any]):
4880 inputs.sort()
4881 return inputs
4882
4883 inputs = [(1, (2, 3)), (2, ("foo", "bar"))]
Shen Li10224432021-08-12 11:39:31 -07004884 with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
Yanan Cao2558e572020-09-17 11:18:24 -07004885 foo(inputs)
Yanan Cao317b9d32020-08-04 15:24:14 -07004886
Elias Ellison170d2972018-08-02 10:59:36 -07004887 def test_string_new_line(self):
4888 with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
Shen Li10224432021-08-12 11:39:31 -07004889 torch.jit.CompilationUnit('''
Elias Ellison170d2972018-08-02 10:59:36 -07004890 def test_while(a):
4891 print("
4892 a")
4893 return a
Shen Li10224432021-08-12 11:39:31 -07004894 ''')
Elias Ellison170d2972018-08-02 10:59:36 -07004895
4896 def test_string_single_escape(self):
4897 with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
Shen Li10224432021-08-12 11:39:31 -07004898 torch.jit.CompilationUnit('''
Elias Ellison170d2972018-08-02 10:59:36 -07004899 def test_while(a):
4900 print("\\")
4901 return a
Shen Li10224432021-08-12 11:39:31 -07004902 ''')
Elias Ellison170d2972018-08-02 10:59:36 -07004903
Adam Paszkea58f2d22018-03-22 16:58:36 +01004904 def test_script_annotation(self):
4905 @torch.jit.script
4906 def foo(a):
4907 return a + a + a
4908 s = Variable(torch.rand(2))
4909 self.assertEqual(s + s + s, foo(s))
James Reed60415cf2018-03-02 15:03:44 -08004910
nikithamalgie677b712021-02-21 19:53:57 -08004911 def test_torch_pow(self):
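        # pow should be scriptable for tensor, int and float operands,
        # including mixed int/float signatures and .item() results.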
4912 def func(a, b):
4913 return pow(a, b)
4914
4915 def func2(a, b, c, d):
4916 return pow(pow(c + a, b), d)
4917
Shen Li10224432021-08-12 11:39:31 -07004918 def func3(a : int, b : float):
nikithamalgie677b712021-02-21 19:53:57 -08004919 # type: (int, float) -> float
4920 return pow(a, b)
4921
4922 def func4():
4923 # type: () -> float
4924 return pow(2, -2)
4925
4926 def func5(x, y):
4927 return pow(x.item(), y.item())
4928
Shen Li10224432021-08-12 11:39:31 -07004929 def func6(a : int, b : int):
nikithamalgie677b712021-02-21 19:53:57 -08004930 # type: (int, int) -> float
4931 return pow(a, b)
4932
4933 a = torch.rand(1)
4934 b = torch.rand(1)
4935 c = torch.rand(1)
4936 d = torch.rand(1)
4937 self.checkScript(func, (a, b))
4938 self.checkScript(func2, (a, b, c, d))
4939 self.checkScript(func3, (4, -0.5))
4940 self.checkScript(func4, ())
4941 self.checkScript(func6, (2, 4))
4942
Shen Li10224432021-08-12 11:39:31 -07004943 inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)]
nikithamalgie677b712021-02-21 19:53:57 -08004944 for x in inputs:
4945 for y in inputs:
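                # skip negative bases; the fractional exponents in `inputs`
                # would make pow ill-defined for them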
4946 if x < 0:
4947 continue
4948 else:
4949 self.checkScript(func5, (x, y))
4950
Thomas Viehmannb9291f52019-04-18 17:52:33 -07004951 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4952 def test_pow_scalar_backward_cuda(self):
4953 # see that scalar exponent works with cuda base (#19253)
Elias Ellison0e3a05e2020-05-06 11:27:59 -07004954 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004955 for dtype in [torch.float, torch.double]:
4956 @torch.jit.script
4957 def func(a, b):
4958 # type: (Tensor, float) -> Tensor
4959 return (a * 2) ** b
Thomas Viehmannb9291f52019-04-18 17:52:33 -07004960
Shen Li10224432021-08-12 11:39:31 -07004961 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004962 func(a, 1, profile_and_replay=True).backward()
Thomas Viehmannb9291f52019-04-18 17:52:33 -07004963
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004964 @torch.jit.script
4965 def func(a, b):
4966 # type: (float, Tensor) -> Tensor
Sam Estepe3900d22021-04-19 13:14:27 -07004967 return a ** (b * 2 + 1)
Thomas Viehmannb9291f52019-04-18 17:52:33 -07004968
Shen Li10224432021-08-12 11:39:31 -07004969 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004970 func(2, a, profile_and_replay=True).backward()
Thomas Viehmannb9291f52019-04-18 17:52:33 -07004971
Richard Zou4d678792018-09-06 16:33:59 -07004972 def _check_code(self, code_str, fn_name, inputs):
4973 scope = {}
4974 exec(code_str, globals(), scope)
4975 cu = torch.jit.CompilationUnit(code_str)
4976 self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
4977
Shen Li10224432021-08-12 11:39:31 -07004978 @unittest.skipIf(not RUN_CUDA, 'no CUDA')
Richard Zou13b05c82018-09-12 11:18:42 -07004979 def test_scriptmodule_releases_tensors_cuda(self):
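        # Repeatedly running (and backpropagating through) a scripted function
        # must not leak CUDA tensors.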
Elias Ellison0e3a05e2020-05-06 11:27:59 -07004980 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004981 @torch.jit.script
4982 def fn(x, y):
4983 return x.sigmoid() * y.tanh()
Richard Zou13b05c82018-09-12 11:18:42 -07004984
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004985 def test(backward=False):
Shen Li10224432021-08-12 11:39:31 -07004986 x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
4987 y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004988 out = fn(x, y, profile_and_replay=True)
4989 if backward:
4990 out.sum().backward()
Richard Zou13b05c82018-09-12 11:18:42 -07004991
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004992 with self.assertLeaksNoCudaTensors():
4993 test()
4994 test()
4995 test()
Richard Zou13b05c82018-09-12 11:18:42 -07004996
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08004997 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
4998 with self.assertLeaksNoCudaTensors():
4999 test(backward=True)
5000 test(backward=True)
5001 test(backward=True)
Richard Zou13b05c82018-09-12 11:18:42 -07005002
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005003 def test_index(self):
5004 def consec(size, start=0):
5005 numel = torch.tensor(size).prod().item()
5006 return torch.arange(numel).view(size)
5007
5008 def consec_list(size):
5009 return list(range(size))
5010
5011 def random_string(size):
5012 letters = string.ascii_lowercase
5013 return "".join(random.choice(letters) for i in range(size))
5014
5015 def check_indexing(indexing, tensor):
Shen Li10224432021-08-12 11:39:31 -07005016 template = dedent("""
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005017 def func(x):
5018 return x{}
Shen Li10224432021-08-12 11:39:31 -07005019 """)
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005020
5021 self._check_code(template.format(indexing), "func", [tensor])
5022
5023 def check_dynamic_indexing(indexing, tensor, value1, value2):
5024 value1 = torch.tensor(value1)
5025 value2 = torch.tensor(value2)
5026
Shen Li10224432021-08-12 11:39:31 -07005027 template = dedent("""
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005028 def func(x, value1, value2):
5029 i = int(value1)
5030 j = int(value2)
5031 return x{}
Shen Li10224432021-08-12 11:39:31 -07005032 """)
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005033
Shen Li10224432021-08-12 11:39:31 -07005034 self._check_code(template.format(indexing), "func", [tensor, value1, value2])
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005035
5036 # Torchscript assumes type Tensor by default, so we need this explicit
5037 # declaration.
5038 def check_indexing_list_int(indexing, list):
Shen Li10224432021-08-12 11:39:31 -07005039 template = dedent("""
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005040 def func(x):
5041 # type: (List[int]) -> Any
5042 return x{}
Shen Li10224432021-08-12 11:39:31 -07005043 """)
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005044
5045 self._check_code(template.format(indexing), "func", [list])
5046
5047 def check_indexing_str(indexing, str):
Shen Li10224432021-08-12 11:39:31 -07005048 template = dedent("""
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005049 def func(x):
5050 # type: (str) -> Any
5051 return x{}
Shen Li10224432021-08-12 11:39:31 -07005052 """)
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005053
5054 self._check_code(template.format(indexing), "func", [str])
5055
5056 # basic slices
Shen Li10224432021-08-12 11:39:31 -07005057 check_indexing('[0]', consec((3, 3)))
5058 check_indexing('[1]', consec((3, 3), 10))
5059 check_indexing('[2]', consec((3, 3), 19))
5060 check_indexing('[2]', consec((3,)))
5061 check_indexing('[-1]', consec((3, 3), 19))
5062 check_indexing('[0:2]', consec((3, 3, 3)))
5063 check_indexing('[1:-1]', consec((3, 3, 3)))
5064 check_indexing('[-3:-1]', consec((6, 3)))
5065 check_indexing('[1:]', consec((3, 3)))
5066 check_indexing('[:1]', consec((3, 3)))
5067 check_indexing('[:]', consec((3, 2)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005068
5069 # multi-dim: indexes
Shen Li10224432021-08-12 11:39:31 -07005070 check_indexing('[0, 1]', consec((3, 3)))
5071 check_indexing('[0, 1]', consec((3, 3, 2)))
5072 check_indexing('[1, 0, 2]', consec((3, 3, 3)))
5073 check_indexing('[2, -1]', consec((3, 3)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005074
5075 # multi-dim: mixed slicing and indexing
Shen Li10224432021-08-12 11:39:31 -07005076 check_indexing('[0, 1:2]', consec((3, 3)))
5077 check_indexing('[0, :1]', consec((3, 3, 2)))
5078 check_indexing('[1, 2:]', consec((3, 3, 3)))
5079 check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5080 check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
5081 check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
5082 check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5083 check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005084
5085 # zero-sized slices
Shen Li10224432021-08-12 11:39:31 -07005086 check_indexing('[0:0]', consec((2, 2)))
5087 check_indexing('[0:0, 1]', consec((3, 3)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005088
5089 # trivial expression usage
Shen Li10224432021-08-12 11:39:31 -07005090 check_indexing('[1+1]', consec((3, 3)))
5091 check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005092
5093 # None for new dimensions
Shen Li10224432021-08-12 11:39:31 -07005094 check_indexing('[None, 0]', consec((3, 3)))
5095 check_indexing('[1, None]', consec((3, 3), 10))
5096 check_indexing('[None, None, 2]', consec((3, 3), 19))
5097 check_indexing('[None, 2, None]', consec((3,)))
5098 check_indexing('[0:2, None]', consec((3, 3, 3)))
5099 check_indexing('[None, 1:-1]', consec((3, 3, 3)))
5100 check_indexing('[None, -3:-1, None]', consec((6, 3)))
5101 check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
5102 check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005103
5104 # dynamic expression usage
5105 check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
5106 check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
5107
5108 # positive striding
Shen Li10224432021-08-12 11:39:31 -07005109 check_indexing_list_int('[0]', consec_list(6))
5110 check_indexing_list_int('[1]', consec_list(7))
5111 check_indexing_list_int('[2]', consec_list(8))
5112 check_indexing_list_int('[2]', consec_list(9))
5113 check_indexing_list_int('[-1]', consec_list(10))
5114 check_indexing_list_int('[0:2]', consec_list(11))
5115 check_indexing_list_int('[1:-1]', consec_list(12))
5116 check_indexing_list_int('[-3:-1]', consec_list(13))
5117 check_indexing_list_int('[1:]', consec_list(15))
5118 check_indexing_list_int('[:1]', consec_list(16))
5119 check_indexing_list_int('[:]', consec_list(17))
5120 check_indexing_list_int('[::]', consec_list(0))
5121 check_indexing_list_int('[1000::]', consec_list(0))
5122 check_indexing_list_int('[:1000:]', consec_list(0))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005123
5124 # negative striding
Shen Li10224432021-08-12 11:39:31 -07005125 check_indexing_list_int('[::-1]', consec_list(7))
5126 check_indexing_list_int('[:3:-1]', consec_list(7))
5127 check_indexing_list_int('[3::-1]', consec_list(7))
5128 check_indexing_list_int('[1000::-1]', consec_list(7))
5129 check_indexing_list_int('[3:0:-1]', consec_list(7))
5130 check_indexing_list_int('[3:-1000:-1]', consec_list(7))
5131 check_indexing_list_int('[0:0:-1]', consec_list(7))
5132 check_indexing_list_int('[0:-1000:-1]', consec_list(7))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005133
5134 # only step is specified
Shen Li10224432021-08-12 11:39:31 -07005135 check_indexing_list_int('[::-1]', consec_list(0))
5136 check_indexing_list_int('[::-1]', consec_list(7))
5137 check_indexing_list_int('[::-2]', consec_list(7))
5138 check_indexing_list_int('[::2]', consec_list(7))
5139 check_indexing_list_int('[::42]', consec_list(7))
5140 check_indexing_list_int('[::-42]', consec_list(7))
5141 check_indexing_list_int('[::42]', consec_list(0))
5142 check_indexing_list_int('[::-42]', consec_list(0))
5143 check_indexing_list_int('[::9223372036854775807]', consec_list(42))
5144 check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005145 with self.assertRaisesRegex(RuntimeError, "out of bounds"):
Shen Li10224432021-08-12 11:39:31 -07005146 check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005147 with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
Shen Li10224432021-08-12 11:39:31 -07005148 check_indexing_list_int('[::0]', consec_list(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005149
5150 # striding strings
Shen Li10224432021-08-12 11:39:31 -07005151 check_indexing_str('[0]', random_string(6))
5152 check_indexing_str('[1]', random_string(7))
5153 check_indexing_str('[2]', random_string(8))
5154 check_indexing_str('[2]', random_string(9))
5155 check_indexing_str('[-1]', random_string(10))
5156 check_indexing_str('[0:2]', random_string(11))
5157 check_indexing_str('[1:-1]', random_string(12))
5158 check_indexing_str('[-3:-1]', random_string(13))
5159 check_indexing_str('[1:]', random_string(15))
5160 check_indexing_str('[:1]', random_string(16))
5161 check_indexing_str('[:]', random_string(17))
5162 check_indexing_str('[::]', random_string(0))
5163 check_indexing_str('[1000::]', random_string(0))
5164 check_indexing_str('[:1000:]', random_string(0))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005165
Shen Li10224432021-08-12 11:39:31 -07005166 check_indexing_str('[::-1]', random_string(7))
5167 check_indexing_str('[:3:-1]', random_string(7))
5168 check_indexing_str('[3::-1]', random_string(7))
5169 check_indexing_str('[1000::-1]', random_string(7))
5170 check_indexing_str('[3:0:-1]', random_string(7))
5171 check_indexing_str('[3:-1000:-1]', random_string(7))
5172 check_indexing_str('[0:0:-1]', random_string(7))
5173 check_indexing_str('[0:-1000:-1]', random_string(7))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005174
Shen Li10224432021-08-12 11:39:31 -07005175 check_indexing_str('[::-1]', random_string(0))
5176 check_indexing_str('[::-1]', random_string(7))
5177 check_indexing_str('[::-2]', random_string(7))
5178 check_indexing_str('[::2]', random_string(7))
5179 check_indexing_str('[::42]', random_string(7))
5180 check_indexing_str('[::-42]', random_string(7))
5181 check_indexing_str('[::42]', random_string(0))
5182 check_indexing_str('[::-42]', random_string(0))
5183 check_indexing_str('[::9223372036854775807]', random_string(42))
5184 check_indexing_str('[::-9223372036854775807]', random_string(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005185 with self.assertRaisesRegex(RuntimeError, "out of bounds"):
Shen Li10224432021-08-12 11:39:31 -07005186 check_indexing_str('[::-9223372036854775808]', random_string(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005187 with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
Shen Li10224432021-08-12 11:39:31 -07005188 check_indexing_str('[::0]', random_string(42))
Tugsbayasgalan (Tugsuu) Manlaibaatarfca931d2021-06-22 11:20:59 -07005189
David Riazatifea0a0b2019-04-11 15:33:51 -07005190 def test_module_copy_with_attributes(self):
5191 class Vocabulary(torch.jit.ScriptModule):
5192 def __init__(self, vocab_list):
Xuehai Pan046e88a2023-02-12 22:20:50 +00005193 super().__init__()
David Riazatifea0a0b2019-04-11 15:33:51 -07005194 self._vocab = torch.jit.Attribute(vocab_list, List[str])
5195 self.some_idx = torch.jit.Attribute(2, int)
5196 self.idx = torch.jit.Attribute(
5197 {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
5198 )
5199
5200 @torch.jit.script_method
5201 def lookup_indices_1d(self, values):
5202 # type: (List[str]) -> List[int]
5203 result = torch.jit.annotate(List[int], [])
5204 # Direct list iteration not supported
5205 for i in range(len(values)):
5206 value = values[i]
5207 result.append(self.idx.get(value, self.some_idx))
5208 return result
5209
5210 @torch.jit.script_method
5211 def forward(self, values):
5212 # type: (List[List[str]]) -> List[List[int]]
5213 result = torch.jit.annotate(List[List[int]], [])
5214 # Direct list iteration not supported
5215 for i in range(len(values)):
5216 result.append(self.lookup_indices_1d(values[i]))
5217 return result
5218
Shen Li10224432021-08-12 11:39:31 -07005219 v = Vocabulary(list('uabcdefg'))
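        # copying a ScriptModule that only carries attributes (no parameters)
        # is expected to succeed; the call below is a smoke test with no
        # further assertions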
Jerry Zhangcbd53bf2020-06-23 15:59:56 -07005220 v.__copy__()
David Riazatifea0a0b2019-04-11 15:33:51 -07005221
Elias Ellison862b8ca2018-12-03 23:59:36 -08005222 def test_tuple_to_opt_list(self):
5223 @torch.jit.script
5224 def foo(x):
5225 # type: (Optional[List[int]]) -> int
5226 return 1
5227
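        # scripting tuple_call is the whole test: the (1, 2) tuple literal is
        # expected to be accepted where an Optional[List[int]] is declared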
5228 @torch.jit.script
5229 def tuple_call():
5230 return foo((1, 2))
5231
Zachary DeVito8995ddd2018-04-12 10:32:49 -07005232 def test_keyword(self):
5233 @torch.jit.script
5234 def func(x):
Zachary DeVitoce69d312018-05-14 14:46:36 -07005235 return torch.sum(x, dim=0)
Zachary DeVito8995ddd2018-04-12 10:32:49 -07005236
5237 x = torch.rand(10, dtype=torch.float, requires_grad=True)
5238 y = func(x)
Zachary DeVitoce69d312018-05-14 14:46:36 -07005239 y2 = torch.sum(x, dim=0)
Zachary DeVito8995ddd2018-04-12 10:32:49 -07005240 self.assertEqual(y, y2)
5241
Elias Ellison2ff0e3b2019-01-07 09:58:08 -08005242 def test_constant_pooling_none(self):
5243 @torch.jit.script
5244 def typed_nones(a=None, b=None, c=None):
Sam Estepe3900d22021-04-19 13:14:27 -07005245 # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]]
Elias Ellison2ff0e3b2019-01-07 09:58:08 -08005246 return a, b, c
5247
5248 @torch.jit.script
5249 def test(a):
5250 # type: (bool) -> None
5251 if a:
5252 print(typed_nones())
5253 else:
5254 print(typed_nones())
5255
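        # the None constants from the typed_nones() calls should be pooled into
        # a single node; roughly, constant pooling rewrites
        #   %a : NoneType = prim::Constant()
        #   %b : NoneType = prim::Constant()
        # into one prim::Constant that every use shares, which the count below
        # verifies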
5256 graph_str = str(test.graph)
James Reed68e07962021-04-12 17:33:20 -07005257 self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1)
Elias Ellison2ff0e3b2019-01-07 09:58:08 -08005258
Elias Ellison8ecd3f72020-01-08 16:44:58 -08005259 def test_constant_pooling_same_identity(self):
5260 def foo():
5261 a = torch.tensor([4])
5262 b = (a,)
5263 index = len(a) - 1
5264 c = b[index]
5265 d = b[index]
5266 return c, d
5267
5268 foo_script = torch.jit.script(foo)
Shen Li10224432021-08-12 11:39:31 -07005269 self.run_pass('constant_propagation', foo_script.graph)
5270 self.run_pass('constant_pooling', foo_script.graph)
Elias Ellison8ecd3f72020-01-08 16:44:58 -08005271 # even though the c & d escape scope, we are still able
5272 # pool them into one constant because they are the same object
5273 FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph)
5274 self.assertEqual(foo(), foo_script())
5275
5276 def test_constant_pooling_introduce_aliasing(self):
5277 @torch.jit.script
5278 def foo():
5279 a = torch.tensor(1)
Elias Ellison38d122e2020-01-22 12:09:46 -08005280 b = torch.tensor(1)
Elias Ellison8ecd3f72020-01-08 16:44:58 -08005281 return a, b
5282
Shen Li10224432021-08-12 11:39:31 -07005283 self.run_pass('constant_propagation', foo.graph)
5284 self.run_pass('constant_pooling', foo.graph)
Elias Ellison8ecd3f72020-01-08 16:44:58 -08005285         # don't pool constants because it would introduce an observable change in alias relationships
Elias Ellison38d122e2020-01-22 12:09:46 -08005286 a, b = foo()
5287 self.assertIsNot(a, b)
Elias Ellison8ecd3f72020-01-08 16:44:58 -08005288
Zachary DeVitob8ada732018-04-23 10:58:07 -07005289 def test_literal(self):
Michael Suo13de6e82018-08-03 00:40:33 -07005290 def func1(a, b):
Zachary DeVitob8ada732018-04-23 10:58:07 -07005291 c = a, b
5292 d, e = c
5293 return d + e
5294
Michael Suo13de6e82018-08-03 00:40:33 -07005295 def func2(a, b):
Zachary DeVitob8ada732018-04-23 10:58:07 -07005296 c = a, (a, b)
5297 d, e = c
5298 f, g = e
5299 return d + f + g
5300
Richard Zou35beecf2018-08-27 12:37:20 -07005301 def func3(a, b):
5302 # type: (float, float) -> float
Shen Li10224432021-08-12 11:39:31 -07005303 c = 0., (0., 0.)
Richard Zou35beecf2018-08-27 12:37:20 -07005304 x = True
5305 while x:
5306 x = False
5307 c = a, (a, b)
5308 d, e = c
5309 f, g = e
5310 return d + f + g
5311
Zachary DeVitob8ada732018-04-23 10:58:07 -07005312 a = torch.rand(1, requires_grad=True)
5313 b = torch.rand(1, requires_grad=True)
Michael Suo13de6e82018-08-03 00:40:33 -07005314 self.checkScript(func1, (a, b), optimize=True)
Zachary DeVitob8ada732018-04-23 10:58:07 -07005315 self.checkScript(func2, (a, b), optimize=True)
Richard Zou35beecf2018-08-27 12:37:20 -07005316 self.checkScript(func3, (a.item(), b.item()), optimize=True)
Zachary DeVitob8ada732018-04-23 10:58:07 -07005317
Richard Zoufea95de2018-05-08 15:40:36 -04005318 def test_expand(self):
5319 @torch.jit.script
5320 def func(x, y):
5321 return x + y
5322
5323 x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
5324 y = torch.rand(3, dtype=torch.float, requires_grad=True)
5325 out = func(x, y)
5326 self.assertEqual(func(x, y), x + y)
5327
Sam Gross12229af2018-06-06 18:09:53 -04005328 grad = torch.randn(2, 3, dtype=torch.float)
Richard Zoufea95de2018-05-08 15:40:36 -04005329 out.backward(grad)
5330 self.assertEqual(x.grad, grad)
5331 self.assertEqual(y.grad, grad.sum(dim=0))
5332
Zachary DeVitof7f95f12018-05-17 10:00:35 -07005333 def test_sum(self):
5334 @torch.jit.script
5335 def func(x):
5336 return x.sum(dim=[4])
5337
5338 @torch.jit.script
5339 def func2(x):
5340 return x.sum(dim=4)
5341
Kurt Mohler23bdb572022-07-09 00:54:42 +00005342 # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
Shen Li10224432021-08-12 11:39:31 -07005343 self.run_pass('constant_propagation', func.graph)
5344 self.run_pass('constant_propagation', func2.graph)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005345 g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
5346 g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
Zachary DeVitof7f95f12018-05-17 10:00:35 -07005347
Zachary DeVitob8ada732018-04-23 10:58:07 -07005348 def test_cat(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005349 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005350 @torch.jit.script
5351 def func(x):
5352 return torch.cat((x, x), dim=0)
Zachary DeVitob8ada732018-04-23 10:58:07 -07005353
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005354 x = torch.rand(10, dtype=torch.float, requires_grad=True)
5355 self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0))
Zachary DeVitob8ada732018-04-23 10:58:07 -07005356
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005357 @torch.jit.script
5358 def func2(x, y):
5359 return torch.cat((x, x), y)
Elias Ellison74e6a662018-08-24 12:59:24 -07005360
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005361 with disable_autodiff_subgraph_inlining():
Natalia Gimelshein09eb3e62021-09-29 09:26:26 -07005362 for sizes in ((2, 2), (0, 2)):
5363 x = torch.rand(sizes).requires_grad_()
5364 y = torch.tensor(1)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005365
Natalia Gimelshein09eb3e62021-09-29 09:26:26 -07005366 output = func2(x, y, profile_and_replay=True)
5367 output_ref = torch.cat((x, x), y)
5368 self.assertEqual(output, output_ref)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005369
Natalia Gimelshein09eb3e62021-09-29 09:26:26 -07005370 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5371 self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
Zachary DeVito1abbee02019-04-10 18:12:38 -07005372
Natalia Gimelshein09eb3e62021-09-29 09:26:26 -07005373 grad = torch.autograd.grad(output.sum(), x)
5374 grad_ref = torch.autograd.grad(output_ref.sum(), x)
5375 self.assertEqual(grad, grad_ref)
Zachary DeVitob8ada732018-04-23 10:58:07 -07005376
Zachary DeVitoa466c122018-06-07 12:38:58 -07005377 def test_cat_lifts(self):
5378 @torch.jit.script
5379 def foo(x):
5380 return torch.cat([x, x], dim=1)
5381
5382 @torch.jit.script
5383 def foo2(x):
Zachary DeVito44fb23a2018-11-08 20:22:57 -08005384 return torch.cat([], dim=1)
Zachary DeVitoa466c122018-06-07 12:38:58 -07005385
5386 @torch.jit.script
5387 def foo3(x):
5388 return torch.cat([x], dim=1)
5389
eellisonbd7fcce2019-03-06 13:41:13 -08005390 for g in [foo.graph, foo2.graph, foo3.graph]:
Shen Li10224432021-08-12 11:39:31 -07005391 FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
Zachary DeVitoa466c122018-06-07 12:38:58 -07005392
Zachary DeVito1abbee02019-04-10 18:12:38 -07005393 def test_stack(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005394 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005395 @torch.jit.script
5396 def func(x):
5397 return torch.stack((x, x), dim=1)
5398 x = torch.rand(10, 10)
Shen Li10224432021-08-12 11:39:31 -07005399 self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1))
Zachary DeVito1abbee02019-04-10 18:12:38 -07005400
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005401 @torch.jit.script
5402 def func2(x, y):
5403 return torch.stack((x, y), dim=0)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005404
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005405 with disable_autodiff_subgraph_inlining():
5406 x = torch.randn([2, 2]).requires_grad_()
5407 y = torch.randn([2, 2]).requires_grad_()
Zachary DeVito1abbee02019-04-10 18:12:38 -07005408
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005409 output = func2(x, y, profile_and_replay=True)
5410 output_ref = torch.stack((x, y), 0)
5411 self.assertEqual(output, output_ref)
5412 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
Shen Li10224432021-08-12 11:39:31 -07005413 self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
Zachary DeVito1abbee02019-04-10 18:12:38 -07005414
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005415 grads = torch.autograd.grad(output.sum(), (x, y))
5416 grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
5417 self.assertEqual(grads, grads_ref)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005418
Shen Li10224432021-08-12 11:39:31 -07005419 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY,
5420 "Profiling executor will be using different heuristics for constructing differentiable graphs")
Zachary DeVito1abbee02019-04-10 18:12:38 -07005421 def test_unbind(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005422 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005423 @torch.jit.script
5424 def func(x, y):
5425 # type: (Tensor, int) -> List[Tensor]
Sam Estepe3900d22021-04-19 13:14:27 -07005426 return torch.unbind(x, y)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005427
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005428 with disable_autodiff_subgraph_inlining():
5429 x = torch.rand([2, 2]).requires_grad_()
5430 y = 0
5431 outputs = func(x, y, profile_and_replay=True)
5432 outputs_ref = torch.unbind(x, dim=y)
5433 self.assertEqual(outputs, outputs_ref)
Nikolay Korovaikoab1d8792021-11-18 14:57:56 -08005434 self.assertAutodiffNode(func.graph_for(x, y), True, [], [])
Zachary DeVito1abbee02019-04-10 18:12:38 -07005435
Nikolay Korovaiko97a29182020-06-05 13:41:53 -07005436 grad = torch.autograd.grad(_sum_of_list(outputs), x)
5437 grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
5438 self.assertEqual(grad, grad_ref)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005439
Shen Li10224432021-08-12 11:39:31 -07005440
5441 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING,
5442 "Profiling executor fails to recognize that tensors in a list require gradients")
Zachary DeVito1abbee02019-04-10 18:12:38 -07005443 def test_meshgrid(self):
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005444 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005445 @torch.jit.script
5446 def func(a):
5447 # type: (List[Tensor]) -> List[Tensor]
Sam Estepe3900d22021-04-19 13:14:27 -07005448 return torch.meshgrid(a)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005449 with disable_autodiff_subgraph_inlining():
5450 a = torch.tensor([1.0, 2, 3]).requires_grad_()
5451 b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
5452 inputs = [a, b]
Zachary DeVito1abbee02019-04-10 18:12:38 -07005453
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005454 outputs_ref = torch.meshgrid(inputs)
5455 outputs = func(inputs, profile_and_replay=True)
5456 self.assertEqual(outputs, outputs_ref)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005457
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005458 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
Nikolay Korovaikoab1d8792021-11-18 14:57:56 -08005459 self.assertAutodiffNode(func.graph_for(inputs), True, [], [])
Zachary DeVito1abbee02019-04-10 18:12:38 -07005460
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005461 grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
5462 grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
5463 self.assertEqual(grads, grads_ref)
Zachary DeVito1abbee02019-04-10 18:12:38 -07005464
David Riazatif9c0a082018-10-31 13:11:01 -07005465 def test_tensor_len(self):
5466 def func(x):
5467 return len(x)
5468
5469 self.checkScript(func, [torch.ones(4, 5, 6)])
5470
Adam Paszkea58f2d22018-03-22 16:58:36 +01005471 def test_func_call(self):
Adam Paszkea58f2d22018-03-22 16:58:36 +01005472 def add(a, b):
5473 return a + b
5474
5475 def mul(a, x):
5476 return a * x
5477
5478 def func(alpha, beta, x, y):
5479 return add(mul(alpha, x), mul(beta, y))
David Riazatidefd23b2019-06-25 16:17:49 -07005480
Adam Paszkea58f2d22018-03-22 16:58:36 +01005481 alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
5482 beta = torch.rand(1, dtype=torch.float, requires_grad=True)
5483 x = torch.rand(3, dtype=torch.float, requires_grad=True)
5484 y = torch.rand(3, dtype=torch.float, requires_grad=True)
David Riazatidefd23b2019-06-25 16:17:49 -07005485
Adam Paszkea58f2d22018-03-22 16:58:36 +01005486 # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
David Riazatidefd23b2019-06-25 16:17:49 -07005487 self.checkScript(func, [alpha, beta, x, y], optimize=False)
Adam Paszkea58f2d22018-03-22 16:58:36 +01005488
Nikolay Korovaikofe261022020-09-13 15:56:30 -07005489 @unittest.skip("bailouts are being deprecated")
Nikolay Korovaiko9499c7b2019-05-10 23:02:41 -07005490 def test_profiling_graph_executor(self):
5491 @torch.jit.script
Nikolay Korovaikoa85305f2019-06-14 16:51:59 -07005492 def def_in_one_branch(x, z):
5493 # type: (Tensor, bool) -> float
5494 y = x
5495 if z is False:
5496 y = x + 1
5497
5498 return y.sum()
Nikolay Korovaiko9499c7b2019-05-10 23:02:41 -07005499
5500 a = torch.rand(2, 3)
Nikolay Korovaiko9499c7b2019-05-10 23:02:41 -07005501
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005502 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005503             # check that prim::profile nodes are inserted
5504 profiled_graph_str = str(def_in_one_branch.graph_for(a, True))
Nikolay Korovaikoa85305f2019-06-14 16:51:59 -07005505 FileCheck().check_count("prim::profile", 4).run(profiled_graph_str)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005506 # this call is optimized for
5507 # the given shape of (2, 3)
Nikolay Korovaikoa85305f2019-06-14 16:51:59 -07005508 def_in_one_branch(a, False)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005509 # change shape to (3)
5510 # so we go down a bailout path
Nikolay Korovaikoa85305f2019-06-14 16:51:59 -07005511 a = torch.ones(3)
5512 # check prim::BailOuts are inserted
5513 bailout_graph_str = str(def_in_one_branch.graph_for(a, True))
Nikolay Korovaikoa3fc6ed2019-06-20 21:19:25 -07005514 FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str)
Nikolay Korovaikoa85305f2019-06-14 16:51:59 -07005515 # this triggers all 3 bailouts
5516 self.assertEqual(def_in_one_branch(a, False), 6.0)
5517 # this triggers 2 bailouts
5518 self.assertEqual(def_in_one_branch(a, True), 3.0)
Nikolay Korovaiko9499c7b2019-05-10 23:02:41 -07005519
Nikolay Korovaikofe261022020-09-13 15:56:30 -07005520 @unittest.skip("bailouts are being deprecated")
Owen Andersona4224882020-03-17 13:48:21 -07005521 def test_maxpool_guard_elimination(self):
5522 @torch.jit.script
5523 def my_maxpool(x):
5524 return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32])
5525
5526 a = torch.rand(32, 32, 32)
5527
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005528 with enable_profiling_mode_for_profiling_tests():
Owen Andersona4224882020-03-17 13:48:21 -07005529 my_maxpool(a)
5530 bailout_graph_str = str(my_maxpool.graph_for(a))
5531 FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5532
Nikolay Korovaikofe261022020-09-13 15:56:30 -07005533 @unittest.skip("bailouts are being deprecated")
Owen Andersond35a4c22020-02-14 22:52:48 -08005534 def test_slice_guard_elimination(self):
5535 @torch.jit.script
5536 def my_slice(x):
5537 return x[0:16:2] + x[0:16:2]
5538
5539 a = torch.rand(32, 4)
5540
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005541 with enable_profiling_mode_for_profiling_tests():
Owen Andersond35a4c22020-02-14 22:52:48 -08005542 my_slice(a)
5543 bailout_graph_str = str(my_slice.graph_for(a))
5544 FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5545
Nikolay Korovaikofe261022020-09-13 15:56:30 -07005546 @unittest.skip("bailouts are being deprecated")
Owen Anderson1d743e32020-02-18 13:21:04 -08005547 def test_unsqueeze_guard_elimination(self):
5548 @torch.jit.script
5549 def my_unsqueeze(x):
5550 return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0)
5551
5552 a = torch.rand(32, 4)
5553
Elias Ellison0e3a05e2020-05-06 11:27:59 -07005554 with enable_profiling_mode_for_profiling_tests():
Owen Anderson1d743e32020-02-18 13:21:04 -08005555 my_unsqueeze(a)
5556 bailout_graph_str = str(my_unsqueeze.graph_for(a))
5557 FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str)
Owen Andersond35a4c22020-02-14 22:52:48 -08005558
Elias Ellison221eddd2019-02-27 18:59:19 -08005559 def test_resize_input_ops(self):
5560 # resize_ and resize_as resize the input tensor. because our shape analysis
5561 # is flow invariant, we set any Tensor that can alias a resized Tensor
5562 # to the base Tensor Type, without size information.
5563
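        # for illustration (not executed here): in eager mode resize_ changes a
        # tensor's shape in place, e.g.
        #   b = torch.zeros(2, 2)
        #   b.resize_([3])   # b.shape is now torch.Size([3])
        # so sizes inferred before the resize cannot be trusted for any value
        # that may alias b
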
5564         # testing that a value which is an input of the graph gets handled
5565 def out_op_graph_input():
5566 @torch.jit.script
5567 def test(x, y, z):
5568 torch.mul(x, y, out=z)
5569 return z
5570
Shen Li10224432021-08-12 11:39:31 -07005571 graph = _propagate_shapes(test.graph,
5572 (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
Zachary DeVito2d079932019-04-02 17:33:06 -07005573 self.assertTrue(next(graph.outputs()).type() == TensorType.get())
Elias Ellison221eddd2019-02-27 18:59:19 -08005574 out_op_graph_input()
5575
5576 def test_resize():
5577 @torch.jit.script
5578 def test(x):
5579 after_resize_alias = torch.zeros([2])
5580 for _i in range(5):
5581 b = x + 1
5582 f = [1]
5583 before_resize_alias = b.sub_(1)
5584 # for i in range(10):
5585 f.append(1)
5586 b.resize_(f)
5587 after_resize_alias = b.add_(1)
5588 return after_resize_alias
5589
Shen Li10224432021-08-12 11:39:31 -07005590 self.run_pass('constant_propagation', test.graph)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005591 g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
Elias Ellison221eddd2019-02-27 18:59:19 -08005592 resize_node = g.findNode("aten::resize_")
5593             # the first input and the output of b.resize_ are both b
5594 self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
5595 self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
5596
5597             # correctly propagates to b's alias set
5598 before_resize = g.findNode("aten::sub_")
5599 self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
5600
5601 after_resize = g.findNode("aten::add_")
5602 self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
5603
5604 test_resize()
5605
5606 def test_resize_as():
5607 @torch.jit.script
5608 def test(x):
5609 b = torch.zeros([2, 2])
5610 b.resize_as_(x)
5611 return b
5612
5613 g = test.graph
Shen Li10224432021-08-12 11:39:31 -07005614 self.run_pass('constant_propagation', g)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005615 g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
Elias Ellison221eddd2019-02-27 18:59:19 -08005616
5617 # x doesn't alias a resized op so it shouldn't be set to base Tensor type
5618 self.assertTrue(next(g.inputs()).type() != TensorType.get())
5619 # return is resized
5620 self.assertTrue(next(g.outputs()).type() == TensorType.get())
5621
5622 test_resize_as()
5623
eellison8a88d332019-06-10 14:43:19 -07005624 def test_uninitialized(self):
5625 graph_str = """graph():
5626 %1 : int = prim::Uninitialized()
5627 %2 : int = prim::Constant[value=1]()
5628 %3 : int = aten::add(%1, %2)
5629 return (%3)
5630 """
5631 g = parse_ir(graph_str)
5632 m = self.createFunctionFromGraph(g)
5633 self.getExportImportCopy(m)
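        # executing the function should fail because %1 comes from
        # prim::Uninitialized and never receives an int value before aten::add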
5634 with self.assertRaisesRegex(RuntimeError, "isInt"):
5635 m()
5636
Shen Li10224432021-08-12 11:39:31 -07005637
5638 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information")
5639 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled")
Elias Ellisonca962f02019-03-24 14:28:22 -07005640 def test_requires_grad_loop(self):
5641 @torch.jit.script
5642 def test(x, y, z):
5643 # type: (Tensor, Tensor, int) -> Tensor
5644 for _ in range(z):
5645 x = y
5646 return x
5647
5648 # x requires grad, y does not
5649             # testing that requires_grad analysis handles the loop correctly, with the input
5650             # to the loop (x) requiring grad, the value the loop body carries out not requiring grad,
5651             # and the output of the loop node conservatively set to require grad
5652
5653 inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005654 test(*inps, profile_and_replay=True)
Elias Ellisonca962f02019-03-24 14:28:22 -07005655
5656 graph = test.graph_for(*inps)
5657 loop = graph.findNode("prim::Loop")
5658 loop_body = next(loop.blocks())
5659 loop_inputs = list(loop_body.inputs())
5660 loop_outputs = list(loop_body.outputs())
5661
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08005662 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
Nikolay Korovaiko7f551972020-06-10 13:46:11 -07005663 # TODO: simplify this test as it's very sensitive
5664 # the optimized graph will have 3 loops
5665 # the original loop is peeled
5666 # peeled loop also gets unrolled
5667 index_of_x_in_peeled_unrolled_loop = -2
Shen Li10224432021-08-12 11:39:31 -07005668 self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad())
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005669 bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False)
Nikolay Korovaiko7f551972020-06-10 13:46:11 -07005670 last_bailout_index_on_loops_output = -1
Shen Li10224432021-08-12 11:39:31 -07005671 self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad())
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005672 else:
Nikolay Korovaikoe16908c2020-03-09 17:09:42 -07005673 self.assertTrue(loop_inputs[1].requires_grad())
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07005674 self.assertTrue(loop.output().requires_grad())
5675 self.assertFalse(loop_outputs[1].requires_grad())
Elias Ellisonca962f02019-03-24 14:28:22 -07005676
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005677 def test_view_shape_prop(self):
Shen Li10224432021-08-12 11:39:31 -07005678 cu = torch.jit.CompilationUnit('''
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005679 def test_view_shape_prop(a):
James Reeda4120fa2018-10-11 10:47:16 -07005680 return a.view(size=[-1])
Shen Li10224432021-08-12 11:39:31 -07005681 ''')
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005682 inputs = [torch.zeros(10, 10)]
5683 outputs = torch.zeros(100)
5684
5685 real_outs = cu.test_view_shape_prop(*inputs)
5686 self.assertEqual(real_outs, outputs)
5687
Yanbo Liang490c1cf2022-12-19 04:14:11 +00005688 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Richard Zoub5f60af2018-09-21 14:17:07 -07005689 def test_view_listconstruct_shape_prop(self):
5690 def fn(x):
5691 B = x.size(0)
5692 C = x.size(1)
5693 T = x.size(2)
5694 return x.view(T, B, C)
5695
5696 x = torch.randn(3, 1, 5, requires_grad=True)
Zachary DeVito2d079932019-04-02 17:33:06 -07005697 fn = torch.jit.script(fn)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005698 graph = _propagate_shapes(fn.graph, (x,), False)
Shen Li10224432021-08-12 11:39:31 -07005699 self.assertTrue(next(graph.outputs()).type().scalarType() == 'Double')
Richard Zoub5f60af2018-09-21 14:17:07 -07005700
James Reed33f47512019-04-04 12:53:44 -07005701 def test_shape_prop_promotion(self):
5702 @torch.jit.script
5703 def fn(x, y):
5704 return x + y
5705
5706 x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005707 graph = _propagate_shapes(fn.graph, (x, y), False)
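        # mixed float/double inputs promote to double, so the propagated
        # output type of the add should be Double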
Shen Li10224432021-08-12 11:39:31 -07005708 FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph)
James Reed33f47512019-04-04 12:53:44 -07005709
5710 def test_shape_prop_promote_scalar_arg(self):
5711 @torch.jit.script
5712 def fn(x):
5713 return math.pi + x
5714
5715 x = torch.zeros(3, 4, dtype=torch.long)
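        # math.pi is a Python float, so adding it to a long tensor should
        # promote the result to the default floating-point dtype checked below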
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07005716 graph = _propagate_shapes(fn.graph, (x,), False)
Brian Vaughan88e4cee2019-09-05 18:24:09 -07005717 default = torch.get_default_dtype()
Shen Li10224432021-08-12 11:39:31 -07005718         if default == torch.float:
5719 FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
Brian Vaughan88e4cee2019-09-05 18:24:09 -07005720 else:
Shen Li10224432021-08-12 11:39:31 -07005721 FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
James Reed33f47512019-04-04 12:53:44 -07005722
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005723 def test_integral_shape_inference(self):
Shen Li10224432021-08-12 11:39:31 -07005724 cu = torch.jit.CompilationUnit('''
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005725 def test_integral_shape_inference(a):
Mike Ruberry594a66d2021-02-10 03:11:14 -08005726 return a * a
Shen Li10224432021-08-12 11:39:31 -07005727 ''')
Ailing Zhang92323562020-05-08 08:14:06 -07005728 inputs = [torch.ones(10, 10, dtype=torch.long)]
Sergii Dymchenko58d1cf72022-08-03 22:45:39 +00005729 outputs = torch.ones(10, 10, dtype=torch.long)
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005730
Sergii Dymchenko58d1cf72022-08-03 22:45:39 +00005731 self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005732
Shen Li10224432021-08-12 11:39:31 -07005733 @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
peter2ce8c832019-09-17 07:27:39 -07005734 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
James Reed60849082019-04-05 17:10:13 -07005735 @enable_cpu_fuser
5736 def test_batchnorm_fuser_cpu(self):
Shen Li10224432021-08-12 11:39:31 -07005737 code = '''
James Reed60849082019-04-05 17:10:13 -07005738 graph(%3 : Tensor,
5739 %7 : Tensor,
5740 %12 : Float(*, *),
5741 %13 : Tensor,
5742 %25 : Tensor):
5743 %23 : int = prim::Constant[value=1]()
5744 %22 : float = prim::Constant[value=1e-05]()
5745 %26 : Tensor = aten::sqrt(%25)
5746 %24 : Tensor = aten::add(%26, %22, %23)
5747 %20 : Tensor = aten::reciprocal(%24)
5748 %norm_invstd : Tensor = aten::mul(%20, %23)
5749 %15 : Tensor = aten::sub(%12, %13, %23)
5750 %11 : Tensor = aten::mul(%15, %norm_invstd)
5751 %8 : Tensor = aten::mul(%11, %7)
5752 %5 : Tensor = aten::add(%8, %3, %23)
5753 %1 : Float(*, *) = aten::relu(%5)
5754 return (%1)
Shen Li10224432021-08-12 11:39:31 -07005755 '''
James Reed60849082019-04-05 17:10:13 -07005756
5757 graph = parse_ir(code)
5758 inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
5759 code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
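        # the kernel is generated for float inputs, so the single-precision
        # sqrtf intrinsic (rather than double-precision sqrt) should show up
        # in the generated code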
Shen Li10224432021-08-12 11:39:31 -07005760 FileCheck().check('sqrtf').run(code)
James Reed60849082019-04-05 17:10:13 -07005761
davidriazati60642232019-12-30 11:43:04 -08005762 @slowTest
Shen Li10224432021-08-12 11:39:31 -07005763 @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
peter2ce8c832019-09-17 07:27:39 -07005764 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
James Reed34382e42019-04-06 17:44:53 -07005765 @enable_cpu_fuser
5766 def test_fuser_double_float_codegen(self):
Shen Li10224432021-08-12 11:39:31 -07005767 fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
5768 'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan',
5769 'atan', 'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc',
5770 'frac']
James Reed34382e42019-04-06 17:44:53 -07005771
5772 def lookup_c_equivalent_fn(aten_fn):
David Berard15c98702022-03-23 14:07:21 -07005773 return aten_fn
James Reed34382e42019-04-06 17:44:53 -07005774
5775 def test_dispatch(op, expects, dtype, binary=False):
5776 if dtype == torch.double:
Shen Li10224432021-08-12 11:39:31 -07005777 dtype_str = 'Double'
James Reed34382e42019-04-06 17:44:53 -07005778 elif dtype == torch.float:
Shen Li10224432021-08-12 11:39:31 -07005779 dtype_str = 'Float'
James Reed34382e42019-04-06 17:44:53 -07005780 else:
Shen Li10224432021-08-12 11:39:31 -07005781 raise RuntimeError('Unknown dtype')
James Reed34382e42019-04-06 17:44:53 -07005782
5783 if binary:
Shen Li10224432021-08-12 11:39:31 -07005784 code = '''
James Reed34382e42019-04-06 17:44:53 -07005785 graph(%3 : Tensor, %4 : Tensor):
5786 %2 : {dtype}(*, *) = aten::{op}(%3, %4)
5787 %1 : {dtype}(*, *) = aten::relu(%2)
5788 return (%1)
Shen Li10224432021-08-12 11:39:31 -07005789 '''.format(op=op, dtype=dtype_str)
James Reed34382e42019-04-06 17:44:53 -07005790 else:
Shen Li10224432021-08-12 11:39:31 -07005791 code = '''
James Reed34382e42019-04-06 17:44:53 -07005792 graph(%3 : Tensor):
5793 %2 : {dtype}(*, *) = aten::{op}(%3)
5794 %1 : {dtype}(*, *) = aten::relu(%2)
5795 return (%1)
Shen Li10224432021-08-12 11:39:31 -07005796 '''.format(op=op, dtype=dtype_str)
James Reed34382e42019-04-06 17:44:53 -07005797
5798 graph = parse_ir(code)
5799 inputs = (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)]
5800 code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
5801 FileCheck().check(expects).run(code)
5802
5803 for fn in fns:
Shen Li10224432021-08-12 11:39:31 -07005804 test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double)
5805 test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float)
James Reed34382e42019-04-06 17:44:53 -07005806
David Berard15c98702022-03-23 14:07:21 -07005807 # 'min', 'max' were previously tested but are now replaced with ternary expressions
5808 # instead of fmin() and fmax()
5809 binary_fns = ['pow']
James Reed34382e42019-04-06 17:44:53 -07005810 for fn in binary_fns:
Shen Li10224432021-08-12 11:39:31 -07005811 test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True)
5812 test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)
James Reed34382e42019-04-06 17:44:53 -07005813
Shen Li10224432021-08-12 11:39:31 -07005814 @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
peter2ce8c832019-09-17 07:27:39 -07005815 @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
James Reed9b69f212019-04-07 00:15:42 -07005816 @enable_cpu_fuser
5817 def test_fuser_double_literal_precision(self):
Shen Li10224432021-08-12 11:39:31 -07005818 code = '''
James Reed9b69f212019-04-07 00:15:42 -07005819 graph(%2 : Float(*, *)):
5820 %4 : int = prim::Constant[value=1]()
5821 %3 : float = prim::Constant[value=1.282549830161864]()
5822 %5 : Float(*, *) = aten::add(%2, %3, %4)
5823 %1 : Float(*, *) = aten::relu(%5)
5824 return (%1)
Shen Li10224432021-08-12 11:39:31 -07005825 '''
James Reed9b69f212019-04-07 00:15:42 -07005826
5827 graph = parse_ir(code)
5828 code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)])
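        # the double literal should survive codegen at full precision instead
        # of being truncated to float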
Shen Li10224432021-08-12 11:39:31 -07005829 FileCheck().check('1.282549830161864').run(code)
James Reed9b69f212019-04-07 00:15:42 -07005830
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005831 def test_fuser_multiple_blocks(self):
Shen Li10224432021-08-12 11:39:31 -07005832 cu = torch.jit.CompilationUnit('''
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005833 def test_fuser_multiple_blocks(this, that, theother, meme):
5834 i = 0
5835 while i < 20:
James Reeda4120fa2018-10-11 10:47:16 -07005836 this = torch.cat([this, meme], dim=0)
5837 that = torch.cat([that, meme], dim=0)
5838 theother = torch.cat([theother, meme], dim=0)
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005839 i = i + 1
5840 return this, that, theother
Shen Li10224432021-08-12 11:39:31 -07005841 ''')
Adam Paszkeb45f2ff2018-05-16 20:03:04 +02005842
5843 inputs = [torch.ones(0, 10, 10)] * 3
5844 inputs += [torch.ones(1, 10, 10)]
5845 outputs = [torch.ones(20, 10, 10)] * 3
5846
5847 self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
5848
Adam Paszkea58f2d22018-03-22 16:58:36 +01005849 @unittest.skip("RuntimeError: VariableType::ID() not implemented")
5850 def test_cast(self):
Shen Li10224432021-08-12 11:39:31 -07005851 script = '''
Adam Paszkea58f2d22018-03-22 16:58:36 +01005852 def to_int(x):
5853 return int(x)
Shen Li10224432021-08-12 11:39:31 -07005854 '''
Adam Paszkea58f2d22018-03-22 16:58:36 +01005855 x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
5856 out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
Shen Li10224432021-08-12 11:39:31 -07005857 self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')
Adam Paszkea58f2d22018-03-22 16:58:36 +01005858
Zachary DeVitof0b5ad82019-05-07 18:43:04 -07005859 def test_str_cast(self):
5860 @torch.jit.script
5861 def to_str(x):
5862 # type: (int) -> str
5863 return str((x, x))
5864
5865 self.assertEqual("(1, 1)", to_str(1))
5866
Vishwak Srinivasanfd5b5cd2020-04-23 14:23:58 -07005867 def test_int_cast(self):
5868 @torch.jit.script
5869 def to_int(x):
5870 # type: (str) -> int
5871 return int(x)
5872
Shen Li10224432021-08-12 11:39:31 -07005873 self.assertEqual(5, to_int('5'))
5874 self.assertEqual(-5, to_int('-5'))
5875 self.assertEqual(2147483647, to_int('2147483647'))
5876 self.assertEqual(-2147483648, to_int('-2147483648'))
Vishwak Srinivasanfd5b5cd2020-04-23 14:23:58 -07005877
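        # only base-10 string literals are accepted; hex and binary prefixes
        # are expected to raise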
5878 with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
Shen Li10224432021-08-12 11:39:31 -07005879 to_int('0x20')
Vishwak Srinivasanfd5b5cd2020-04-23 14:23:58 -07005880
5881 with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
Shen Li10224432021-08-12 11:39:31 -07005882 to_int('0b0001')
Vishwak Srinivasanfd5b5cd2020-04-23 14:23:58 -07005883
Adam Paszkea58f2d22018-03-22 16:58:36 +01005884 def test_python_frontend(self):
5885 def fn(x, y, z):
Wanchao Liang47c1bad2018-07-27 22:47:29 -07005886 q = None
Adam Paszkea58f2d22018-03-22 16:58:36 +01005887 q = x + y - z.sigmoid()
5888 print(q)
5889 w = -z
5890 if not x and not y and z:
5891 m = x if not z else y
5892 while x < y > z:
5893 q = x
Elias Ellisona5b627a2018-11-01 09:58:37 -07005894 assert 1 == 1, "hello"
Adam Paszkea58f2d22018-03-22 16:58:36 +01005895 return x
5896
Michael Suo167a9782020-05-12 23:17:57 -07005897 ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
Adam Paszkea58f2d22018-03-22 16:58:36 +01005898 self.assertExpected(str(ast))
5899
aizjForevercdc3e232020-08-27 12:27:52 -07005900 def test_python_frontend_source_range(self):
5901 def fn():
5902 raise Exception("hello")
5903 ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
Shen Li10224432021-08-12 11:39:31 -07005904 FileCheck().check("SourceRange at:") \
5905 .check("def fn():") \
5906 .check("~~~~~~~~~") \
5907 .check('raise Exception("hello")') \
5908 .check('~~~~~~~~~~~~~~~~~ <--- HERE') \
5909 .run(str(ast.range()))
aizjForevercdc3e232020-08-27 12:27:52 -07005910
Elias Ellison59f8e8a2018-10-30 20:20:26 -07005911 def test_python_frontend_py3(self):
5912 def fn():
5913 raise Exception("hello")
Michael Suo167a9782020-05-12 23:17:57 -07005914 ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
Elias Ellison59f8e8a2018-10-30 20:20:26 -07005915 self.assertExpected(str(ast))
5916
Adam Paszkea58f2d22018-03-22 16:58:36 +01005917 def _make_scalar_vars(self, arr, dtype):
5918 return [torch.tensor(val, dtype=dtype) for val in arr]
5919
Shen Li10224432021-08-12 11:39:31 -07005920
Elias Ellison170d2972018-08-02 10:59:36 -07005921 def test_string_print(self):
5922 def func(a):
Shen Li10224432021-08-12 11:39:31 -07005923 print(a, "a" 'b' '''c''' """d""", 2, 1.5)
Elias Ellison170d2972018-08-02 10:59:36 -07005924 return a
5925
5926 inputs = self._make_scalar_vars([1], torch.int64)
5927 self.checkScript(func, inputs, capture_output=True)
5928
Adam Paszkea58f2d22018-03-22 16:58:36 +01005929 def test_while(self):
5930 def func(a, b, max):
David Riazati6f53b4e2018-09-13 11:10:00 -07005931 while bool(a < max):
Adam Paszkea58f2d22018-03-22 16:58:36 +01005932 a = a + 1
5933 b = b + 1
5934 c = a + b
5935 return c
5936
gchanane1f5d802018-04-18 23:37:54 -04005937 inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01005938 self.checkScript(func, inputs, optimize=True)
5939
5940 def test_fibb(self):
5941 def func(lim):
5942 first = 1
5943 second = 1
5944 i = 1
5945 somenum = 5
5946 dontmutateme = 3
5947 third = 0
David Riazati6f53b4e2018-09-13 11:10:00 -07005948 while bool(i < lim):
Adam Paszkea58f2d22018-03-22 16:58:36 +01005949 third = first + second
5950 first = second
5951 second = third
5952 j = 0
5953 while j < 10:
5954 somenum = somenum * 2
5955 j = j + 1
5956 i = i + j
5957 i = i + dontmutateme
5958
5959 st = second + third
5960 fs = first + second
Richard Zou67f6f932018-08-27 08:53:56 -07005961 return third, st, fs
Adam Paszkea58f2d22018-03-22 16:58:36 +01005962
gchanane1f5d802018-04-18 23:37:54 -04005963 inputs = self._make_scalar_vars([10], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01005964 self.checkScript(func, inputs, optimize=True)
5965
Zachary DeVitocf356a32019-06-03 21:31:53 -07005966 def test_fibb_totally_better(self):
5967 def fib(x):
5968 # type: (int) -> int
5969 prev = 1
5970 v = 1
5971 for i in range(0, x):
5972 save = v
5973 v = v + prev
5974 prev = save
5975 return v
5976
5977 self.checkScript(fib, (10,))
5978
Adam Paszkea58f2d22018-03-22 16:58:36 +01005979 def test_if(self):
5980 def func(a, b):
Richard Zou67f6f932018-08-27 08:53:56 -07005981 # type: (int, int) -> int
Adam Paszkea58f2d22018-03-22 16:58:36 +01005982 d = 3
David Riazati6f53b4e2018-09-13 11:10:00 -07005983 if bool(a > 10):
Richard Zou67f6f932018-08-27 08:53:56 -07005984 a = 3 + d
Adam Paszkea58f2d22018-03-22 16:58:36 +01005985 else:
Richard Zou67f6f932018-08-27 08:53:56 -07005986 b = 3 + d
Adam Paszkea58f2d22018-03-22 16:58:36 +01005987 d = 4
5988 c = a + b
5989 return c
5990
gchanane1f5d802018-04-18 23:37:54 -04005991 inputs = self._make_scalar_vars([1, -1], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01005992 self.checkScript(func, inputs, optimize=True)
5993
Wanchao73ce21a2018-06-20 12:09:24 -07005994 def test_if_for_in_range(self):
5995 def func(a, b):
Richard Zou67f6f932018-08-27 08:53:56 -07005996 # type: (int, int) -> int
5997 d = 3
Wanchao73ce21a2018-06-20 12:09:24 -07005998 for _ in range(20):
David Riazati6f53b4e2018-09-13 11:10:00 -07005999 if bool(a > 10):
Wanchao73ce21a2018-06-20 12:09:24 -07006000 a = 3 + d
6001 else:
6002 b = 3 + d
Richard Zou67f6f932018-08-27 08:53:56 -07006003 d = 4
Wanchao73ce21a2018-06-20 12:09:24 -07006004 c = a + b
6005 return d
6006 inputs = self._make_scalar_vars([1, -1], torch.int64)
6007 self.checkScript(func, inputs, optimize=True)
6008
Adam Paszkea58f2d22018-03-22 16:58:36 +01006009 def test_if_noelse(self):
6010 def func(a, b):
David Riazati6f53b4e2018-09-13 11:10:00 -07006011 if bool(a > 10):
Adam Paszkea58f2d22018-03-22 16:58:36 +01006012 a = 3 + b
6013 c = a + b
6014 return c
6015
gchanane1f5d802018-04-18 23:37:54 -04006016 inputs = self._make_scalar_vars([-1, 1], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006017 self.checkScript(func, inputs, optimize=True)
6018
Wanchao Liang4b315722018-12-03 15:44:45 -08006019 def test_if_is_none_dispatch(self):
Shen Li10224432021-08-12 11:39:31 -07006020
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006021 @torch.jit.script
6022 def test_lhs_none_rhs_none():
6023 # LHS, RHS both alwaysNone, dispatch always_none_branch
6024 # only emit one prim::Constant
6025 if None is None:
6026 return 1
6027 elif None is not None:
6028 return 2
6029 else:
6030 return 3
Wanchao Liang4b315722018-12-03 15:44:45 -08006031
Shen Li10224432021-08-12 11:39:31 -07006032 self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
Wanchao Liang4b315722018-12-03 15:44:45 -08006033
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006034 @torch.jit.script
6035 def test_lhs_opt_rhs_none(lhs=None):
6036 # type: (Optional[Tensor]) -> int
6037 # LHS maybeNone: emit normal if stmt that contains 3 constants
6038 if lhs is not None:
6039 return 2
6040 elif lhs is None:
6041 return 1
6042 else:
6043 return 3
Wanchao Liang4b315722018-12-03 15:44:45 -08006044
Shen Li10224432021-08-12 11:39:31 -07006045 self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
Wanchao Liang4b315722018-12-03 15:44:45 -08006046
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006047 @torch.jit.script
6048 def test_lhs_none_rhs_opt(rhs=None):
6049 # type: (Optional[Tensor]) -> int
6050 # RHS maybeNone, emit normal if stmt that contains 3 constants
6051 if None is rhs:
6052 return 1
6053 elif None is not rhs:
6054 return 2
6055 else:
6056 return 3
Wanchao Liang4b315722018-12-03 15:44:45 -08006057
Shen Li10224432021-08-12 11:39:31 -07006058         self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
Wanchao Liang4b315722018-12-03 15:44:45 -08006059
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006060 @torch.jit.script
6061 def test_lhs_never_rhs_none(lhs):
6062 # LHS neverNone, RHS alwaysNone dispatch never_none_branch
6063 # only emit one prim::Constant
6064 if lhs is None:
6065 return 1
6066 elif lhs is not None:
6067 return 2
6068 else:
6069 return 3
6070
Shen Li10224432021-08-12 11:39:31 -07006071 self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006072
6073 @torch.jit.script
6074 def test_lhs_none_rhs_never(rhs):
6075 # LHS alwaysNone, RHS neverNone dispatch never_none_branch
6076 # only emit one prim::Constant
6077 if None is rhs:
6078 return 1
6079 elif None is not rhs:
6080 return 2
6081 else:
6082 return 3
6083
Shen Li10224432021-08-12 11:39:31 -07006084 self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
Wanchao Liang4b315722018-12-03 15:44:45 -08006085
Zachary DeVitofcd13542019-09-23 14:21:57 -07006086 @torch.jit.script
6087 def test_bool_arith_and(lhs):
6088 if lhs is None and lhs is not None:
6089 return 1
6090 else:
6091 return 2
6092 self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2)
Shen Li10224432021-08-12 11:39:31 -07006093 self.assertTrue(str(test_bool_arith_and.graph).count('if') == 0)
Zachary DeVitofcd13542019-09-23 14:21:57 -07006094
6095 @torch.jit.script
6096 def test_bool_arith_or(lhs):
6097 if lhs is None or lhs is not None:
6098 return 1
6099 else:
6100 return 2
Zsolt Dollensteinb0043072021-08-12 10:56:55 -07006101 self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1)
Shen Li10224432021-08-12 11:39:31 -07006102 self.assertTrue(str(test_bool_arith_or.graph).count('if') == 0)
6103
Zachary DeVitofcd13542019-09-23 14:21:57 -07006104
6105 @torch.jit.script
6106 def test_bool_arith_not(lhs):
6107 if not (lhs is None):
6108 return 1
6109 else:
6110 return 2
Zsolt Dollensteinb0043072021-08-12 10:56:55 -07006111 self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1)
Shen Li10224432021-08-12 11:39:31 -07006112 self.assertTrue(str(test_bool_arith_not.graph).count('if') == 0)
6113
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006114 def test_conditional_casting(self):
6115 def test_bool_cast_tensor(x):
6116 if x:
6117 return 1
6118 else:
6119 return 0
6120
6121 for make_one_dim in [True, False]:
6122 for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
6123 inp_val = [inp_val] if make_one_dim else inp_val
6124 self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
6125
Shen Li10224432021-08-12 11:39:31 -07006126 self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
6127 "Boolean value of Tensor with more than one value")
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006128
Elias Ellison4371cb52019-04-17 16:01:41 -07006129 def test_not_cast(x):
6130 if not x:
6131 return 1
6132 else:
6133 return 0
6134
6135 self.checkScript(test_not_cast, (torch.tensor(1),))
6136 self.checkScript(test_not_cast, (torch.tensor(0),))
6137
Shen Li10224432021-08-12 11:39:31 -07006138 with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"): # noqa: W605
Elias Ellison4371cb52019-04-17 16:01:41 -07006139 @torch.jit.script
6140 def test_mult(x, y):
Shen Li10224432021-08-12 11:39:31 -07006141 return not(x, y)
Elias Ellison4371cb52019-04-17 16:01:41 -07006142
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006143 def test_cast_int(x):
6144 # type: (int) -> int
6145 if x:
6146 return 1
6147 else:
6148 return 0
6149 self.checkScript(test_cast_int, (1,))
6150 self.checkScript(test_cast_int, (0,))
6151 self.checkScript(test_cast_int, (-1,))
6152
6153 def test_cast_float(x):
6154 # type: (float) -> int
6155 if x:
6156 return 1
6157 else:
6158 return 0
Shen Li10224432021-08-12 11:39:31 -07006159 self.checkScript(test_cast_float, (1.,))
6160 self.checkScript(test_cast_float, (0.,))
6161 self.checkScript(test_cast_float, (-1.,))
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006162
Shen Li10224432021-08-12 11:39:31 -07006163 with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"): # noqa: W605
Edward Yangda2004e2020-06-04 12:53:53 -07006164
David Riazati6f53b4e2018-09-13 11:10:00 -07006165 @torch.jit.script
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006166 def test_bad_conditional(x):
Sam Estepe3900d22021-04-19 13:14:27 -07006167 if (1, 2): # noqa: F634
Elias Ellisonb80a4fa2019-04-03 17:09:37 -07006168 return
6169 else:
6170 return 0
David Riazati6f53b4e2018-09-13 11:10:00 -07006171
Adam Paszkea58f2d22018-03-22 16:58:36 +01006172 def test_while_nonexistent_value(self):
6173 with self.assertRaisesRegex(RuntimeError, "undefined value x"):
Shen Li10224432021-08-12 11:39:31 -07006174 torch.jit.CompilationUnit('''
Adam Paszkea58f2d22018-03-22 16:58:36 +01006175 def test_while(a, b):
David Riazati6f53b4e2018-09-13 11:10:00 -07006176 while bool(a < 10):
Adam Paszkea58f2d22018-03-22 16:58:36 +01006177 a = a + x
6178 b = b + 1
6179 return a + b
Shen Li10224432021-08-12 11:39:31 -07006180 ''')
Adam Paszkea58f2d22018-03-22 16:58:36 +01006181
6182 def test_while_nonexistent_cond_value(self):
6183 with self.assertRaisesRegex(RuntimeError, "undefined value x"):
Shen Li10224432021-08-12 11:39:31 -07006184 torch.jit.CompilationUnit('''
Adam Paszkea58f2d22018-03-22 16:58:36 +01006185 def test_while(a, b):
6186 while a < x:
6187 a = a + 1
6188 b = b + 1
6189 return a + b
Shen Li10224432021-08-12 11:39:31 -07006190 ''')
Adam Paszkea58f2d22018-03-22 16:58:36 +01006191
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006192 @torch.jit.script
6193 def test_ternary(x):
6194 # type: (Optional[int]) -> int
6195 x = x if x is not None else 2
6196 return x
6197
6198 @torch.jit.script
6199 def test_not_none(x):
6200 # type: (Optional[int]) -> None
6201 if x is not None:
6202 print(x + 1)
6203
6204 @torch.jit.script
6205 def test_and(x, y):
6206 # type: (Optional[int], Optional[int]) -> None
6207 if x is not None and y is not None:
6208 print(x + y)
6209
6210 @torch.jit.script
6211 def test_not(x, y):
6212 # type: (Optional[int], Optional[int]) -> None
6213 if not (x is not None and y is not None):
6214 pass
6215 else:
6216 print(x + y)
6217
6218 @torch.jit.script
6219 def test_bool_expression(x):
6220 # type: (Optional[int]) -> None
6221 if x is not None and x < 2:
6222 print(x + 1)
6223
6224 @torch.jit.script
6225 def test_nested_bool_expression(x, y):
6226 # type: (Optional[int], Optional[int]) -> int
6227 if x is not None and x < 2 and y is not None:
6228 x = x + y
6229 else:
6230 x = 5
6231 return x + 2
6232
6233 @torch.jit.script
6234 def test_or(x, y):
6235 # type: (Optional[int], Optional[int]) -> None
6236 if y is None or x is None:
6237 pass
6238 else:
6239 print(x + y)
6240
6241 # backwards compatibility
6242 @torch.jit.script
6243 def test_manual_unwrap_opt(x):
6244 # type: (Optional[int]) -> int
6245 if x is None:
6246 x = 1
6247 else:
6248 x = torch.jit._unwrap_optional(x)
Elias Ellison561037a2019-03-07 09:12:35 -08006249 return x # noqa: T484
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006250
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -07006251 with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006252 @torch.jit.script
6253 def or_error(x, y):
Elias Ellison561037a2019-03-07 09:12:35 -08006254 # type: (Optional[int], Optional[int]) -> None
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006255 if x is None or y is None:
Elias Ellison561037a2019-03-07 09:12:35 -08006256 print(x + y) # noqa: T484
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006257
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -07006258 with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006259 @torch.jit.script
6260 def and_error(x, y):
Elias Ellison561037a2019-03-07 09:12:35 -08006261 # type: (Optional[int], Optional[int]) -> None
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006262 if x is None and y is None:
6263 pass
6264 else:
Elias Ellison561037a2019-03-07 09:12:35 -08006265 print(x + y) # noqa: T484
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006266
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -07006267 with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006268 @torch.jit.script
6269 def named_var(x):
6270 # type: (Optional[int]) -> None
6271 x_none = x is not None
6272 if x_none:
Elias Ellison561037a2019-03-07 09:12:35 -08006273 print(x + 1) # noqa: T484
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006274
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -07006275 with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006276 @torch.jit.script
6277 def named_var_and(x, y):
6278 # type: (Optional[int], Optional[int]) -> None
6279 x_none = x is not None
6280 if y is not None and x_none:
Elias Ellison561037a2019-03-07 09:12:35 -08006281 print(x + y) # noqa: T484
Elias Ellisond4f6bef2019-01-18 11:17:34 -08006282
Elias Ellisone90adf52019-08-07 12:58:59 -07006283 def test_assertion_optional_refinement(self):
6284 @torch.jit.script
6285 def test(x, y):
6286 # type: (Optional[int], Optional[int]) -> int
6287 assert x is not None and y is not None
6288 return x + y
6289
6290 self.assertEqual(test(2, 2), 4)
6291 with self.assertRaisesRegex(Exception, ""):
6292 test(1, None)
6293
Shen Li10224432021-08-12 11:39:31 -07006294 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006295 def test_optional_tensor(self):
6296 @torch.jit.script
6297 def fn(x, y):
6298 # type: (Optional[Tensor], int) -> int
6299 if x is None:
6300 return y
6301 else:
6302 return 0
6303
6304 res = fn(None, 1)
6305 self.assertEqual(res, 1)
6306 g = torch.jit.last_executed_optimized_graph()
6307 first_input = next(g.inputs())
6308 # check if input is disconnected
Shen Li10224432021-08-12 11:39:31 -07006309 self.assertEqual(first_input.type().kind(), 'OptionalType')
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006310 self.assertEqual(first_input.uses(), [])
6311 t = torch.ones(1)
6312 res = fn(t, 1)
6313 self.assertEqual(res, 0)
6314 g = torch.jit.last_executed_optimized_graph()
Shen Li10224432021-08-12 11:39:31 -07006315 self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006316
6317 @torch.jit.script
6318 def fn(x, y, b):
6319 # type: (Optional[Tensor], Tensor, bool) -> Tensor
6320 if b:
6321 res = y
6322 else:
6323 res = torch.jit._unwrap_optional(x)
6324 return res
6325
6326 t2 = torch.zeros(1)
6327 res = fn(t, t2, True)
6328 self.assertEqual(res, t2)
6329 with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6330 res = fn(None, t2, False)
6331 res = fn(None, t2, True)
6332 g = torch.jit.last_executed_optimized_graph()
Shen Li10224432021-08-12 11:39:31 -07006333 self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006334
Shen Li10224432021-08-12 11:39:31 -07006335 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006336 def test_optional_list(self):
6337 @torch.jit.script
6338 def fn(x, y):
6339 # type: (Optional[List[int]], int) -> int
6340 if x is None:
6341 return y
6342 else:
6343 res = 0
6344 for d in x:
6345 res += d
6346 return res
6347
6348 res = fn(None, 1)
6349 self.assertEqual(res, 1)
6350 g = torch.jit.last_executed_optimized_graph()
6351 first_input = next(g.inputs())
6352 # check if input is disconnected
Shen Li10224432021-08-12 11:39:31 -07006353 self.assertEqual(first_input.type().kind(), 'OptionalType')
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006354 self.assertEqual(first_input.uses(), [])
6355 l = [2, 3]
6356 res = fn(l, 1)
6357 self.assertEqual(res, 5)
6358 g = torch.jit.last_executed_optimized_graph()
Shen Li10224432021-08-12 11:39:31 -07006359 self.assertEqual(next(g.inputs()).type().kind(), 'ListType')
Thomas Viehmann5c9ab6f2019-05-06 14:54:10 -07006360
6361 @torch.jit.script
6362 def fn(x, y, b):
6363 # type: (Optional[List[int]], List[int], bool) -> List[int]
6364 if b:
6365 l = torch.jit._unwrap_optional(x)
6366 else:
6367 l = y
6368 return l
6369
6370 l2 = [0, 1]
6371 res = fn(l, l2, True)
6372 self.assertEqual(res, l)
6373 with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6374 res = fn(None, l2, True)
6375 res = fn(None, l2, False)
6376 g = torch.jit.last_executed_optimized_graph()
6377 self.assertEqual(next(g.outputs()).type().str(), "int[]")
6378
Elias Ellison85b1c452020-02-28 19:51:49 -08006379 def test_alias_covariant_type_containers(self):
6380 @torch.jit.script
6381 def foo(x):
6382 # type: (bool)
Jerry Zhang5b9f1ad2020-03-03 10:46:42 -08006383 if x:
Elias Ellison85b1c452020-02-28 19:51:49 -08006384 a = (None,)
6385 else:
6386 a = ([],)
6387 return a
6388
6389 @torch.jit.script
6390 def foo2(x, li):
6391 # type: (bool, Tuple[Optional[List[Tensor]]])
6392 if x:
6393 li = (None,)
6394 return li
6395
Adam Paszkea58f2d22018-03-22 16:58:36 +01006396 def test_while_write_outer_then_read(self):
6397 def func(a, b):
David Riazati6f53b4e2018-09-13 11:10:00 -07006398 while bool(a < 10):
Adam Paszkea58f2d22018-03-22 16:58:36 +01006399 a = a + 1
6400 b = a + 1
6401 return a + b
6402
gchanane1f5d802018-04-18 23:37:54 -04006403 inputs = self._make_scalar_vars([42, 1337], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006404 self.checkScript(func, inputs, optimize=True)
6405
Animesh Jain1d90d6e2022-07-07 18:57:31 +00006406 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Adam Paszkea58f2d22018-03-22 16:58:36 +01006407 def test_while_nest_if(self):
6408 def func(a, b):
Richard Zou67f6f932018-08-27 08:53:56 -07006409 # type: (int, int) -> int
6410 c = 0
Adam Paszkea58f2d22018-03-22 16:58:36 +01006411 while a < 10:
6412 a = a + 1
6413 b = b + 1
6414 if a > b:
6415 c = -a
6416 else:
6417 c = -b
6418 return c + 1
6419
gchanane1f5d802018-04-18 23:37:54 -04006420 inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006421 self.checkScript(func, inputs, optimize=True)
6422
Kartikey Pandey2378c122019-06-10 14:51:57 -07006423 def test_divmod(self):
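        # Descriptive note (added): compare scripted divmod against Python for the
        # int/float argument combinations below, then check that division by zero
        # surfaces the expected ZeroDivisionError messages through the TorchScript runtime.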
6424 def func_int(a, b):
6425 # type: (int, int) -> Tuple[int, int]
6426 return divmod(a, b)
6427
6428 def func_float(a, b):
6429 # type: (float, float) -> Tuple[float, float]
6430 return divmod(a, b)
6431
6432 def func_int_float(a, b):
6433 # type: (int, float) -> Tuple[float, float]
6434 return divmod(a, b)
6435
6436 def func_float_int(a, b):
6437 # type: (float, int) -> Tuple[float, float]
6438 return divmod(a, b)
6439
6440 def divmod_test_iterator(func, num, den):
6441 for i in num:
6442 for j in den:
David Riazatidefd23b2019-06-25 16:17:49 -07006443 self.checkScript(func, (i, j), frames_up=2)
Kartikey Pandey2378c122019-06-10 14:51:57 -07006444
6445 num_int = [1024, -1024]
6446 den_int = [10, -10]
6447 num_float = [5.3, -5.3]
6448 den_float = [2.0, -2.0]
6449 divmod_test_iterator(func_int, num_int, den_int)
6450 divmod_test_iterator(func_float, num_float, den_float)
6451 divmod_test_iterator(func_int_float, num_int, den_float)
6452 divmod_test_iterator(func_float_int, num_float, den_int)
6453
Shen Li10224432021-08-12 11:39:31 -07006454 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"):
Kartikey Pandey2378c122019-06-10 14:51:57 -07006455 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int)))
6456 cu.func_int(1024, 0)
6457 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6458 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float)))
6459 cu.func_float(5.3, 0.0)
6460 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6461 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float)))
6462 cu.func_int_float(1024, 0.0)
6463 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6464 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int)))
6465 cu.func_float_int(5.3, 0)
6466
Elias Ellison862b8ca2018-12-03 23:59:36 -08006467 def test_math_ops(self):
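        # Descriptive note (added): checkMath compiles a small wrapper around each math.*
        # function and compares the scripted result with the Python result over a grid of
        # float/int inputs. NaN results are treated as equal, and cases where the Python
        # function raises are skipped rather than compared.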
Horace He2ed6f012019-06-04 13:12:05 -07006468 def checkMathWrap(func_name, num_args=1, is_float=True, **args):
6469 if is_float:
6470 checkMath(func_name, num_args, True, **args)
6471 checkMath(func_name, num_args, False, **args)
6472 else:
6473 checkMath(func_name, num_args, is_float, **args)
Elias Ellison862b8ca2018-12-03 23:59:36 -08006474
Horace He2ed6f012019-06-04 13:12:05 -07006475 inf = float("inf")
6476 NaN = float("nan")
Shen Li10224432021-08-12 11:39:31 -07006477 mx_int = 2**31 - 1
6478 mn_int = -2**31
6479 float_vals = ([inf, NaN, 0.0, 1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] +
6480 [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)])
6481 int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2]
Elias Ellison862b8ca2018-12-03 23:59:36 -08006482
Shen Li10224432021-08-12 11:39:31 -07006483 def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None):
6484 funcs_template = dedent('''
Horace He2ed6f012019-06-04 13:12:05 -07006485 def func(a, b):
6486 # type: {args_type} -> {ret_type}
6487 return math.{func}({args})
Shen Li10224432021-08-12 11:39:31 -07006488 ''')
Horace He2ed6f012019-06-04 13:12:05 -07006489 if num_args == 1:
6490 args = "a"
6491 elif num_args == 2:
6492 args = "a, b"
6493 else:
6494 raise RuntimeError("Test doesn't support more than 2 arguments")
6495 if args_type is None:
6496 args_type = "(float, float)" if is_float else "(int, int)"
Shen Li10224432021-08-12 11:39:31 -07006497 funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type)
Horace He2ed6f012019-06-04 13:12:05 -07006498 scope = {}
6499 execWrapper(funcs_str, globals(), scope)
6500 cu = torch.jit.CompilationUnit(funcs_str)
6501 f_script = cu.func
Shen Li10224432021-08-12 11:39:31 -07006502 f = scope['func']
Alexandr Morevda4ff172019-04-16 10:19:04 -07006503
Horace He2ed6f012019-06-04 13:12:05 -07006504 if vals is None:
6505 vals = float_vals if is_float else int_vals
6506 vals = [(i, j) for i in vals for j in vals]
Alexandr Morevda4ff172019-04-16 10:19:04 -07006507
Horace He2ed6f012019-06-04 13:12:05 -07006508 for a, b in vals:
6509 res_python = None
6510 res_script = None
6511 try:
6512 res_python = f(a, b)
6513 except Exception as e:
6514 res_python = e
6515 try:
6516 res_script = f_script(a, b)
6517 except Exception as e:
6518 res_script = e
6519 if debug:
6520 print("in: ", a, b)
6521 print("out: ", res_python, res_script)
6522 # We can't use assertEqual because of a couple of differences:
6523                # 1. NaN results should compare as equal here (plain nan == nan is False)
6524                # 2. When Python functions throw an exception, we usually want to silently ignore them.
6525                #    (i.e. we want to return `nan` for math.sqrt(-5))
6526 if res_python != res_script:
6527 if isinstance(res_python, Exception):
6528 continue
Karl Ostmo72355322019-07-02 10:09:39 -07006529
Horace He2ed6f012019-06-04 13:12:05 -07006530 if type(res_python) == type(res_script):
Shen Li10224432021-08-12 11:39:31 -07006531 if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])):
Horace He2ed6f012019-06-04 13:12:05 -07006532 continue
Shen Li10224432021-08-12 11:39:31 -07006533 if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script):
Horace He2ed6f012019-06-04 13:12:05 -07006534 continue
Shen Li10224432021-08-12 11:39:31 -07006535 msg = ("Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}"
6536 .format(func_name=func_name, a=a, b=b, res_python=res_python, res_script=res_script))
Nikita Shulgae1a2b0d2023-01-06 00:56:43 +00006537 # math.pow() behavior has changed in 3.11, see https://docs.python.org/3/library/math.html#math.pow
6538 if sys.version_info >= (3, 11) and func_name == "pow" and a == 0.0 and b == -math.inf:
6539 self.assertTrue(res_python == math.inf and type(res_script) is RuntimeError)
6540 else:
6541 self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0)
Alexandr Morevda4ff172019-04-16 10:19:04 -07006542
Shen Li10224432021-08-12 11:39:31 -07006543 unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf",
6544 "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan",
6545 "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"]
Karl Ostmo72355322019-07-02 10:09:39 -07006546 binary_float_ops = ["atan2", "fmod", "copysign"]
6547 for op in unary_float_ops:
6548 checkMathWrap(op, 1)
6549 for op in binary_float_ops:
6550 checkMathWrap(op, 2)
Alexandr Morevda4ff172019-04-16 10:19:04 -07006551
Karl Ostmo72355322019-07-02 10:09:39 -07006552 checkMath("modf", 1, ret_type="Tuple[float, float]")
6553 checkMath("frexp", 1, ret_type="Tuple[float, int]")
6554 checkMath("isnan", 1, ret_type="bool")
6555 checkMath("isinf", 1, ret_type="bool")
Shen Li10224432021-08-12 11:39:31 -07006556 checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)",
6557 vals=[(i, j) for i in float_vals for j in range(-10, 10)])
Elias Ellisonff8b7ef2019-09-27 17:11:43 -07006558 checkMath("pow", 2, is_float=False, ret_type="float")
Karl Ostmo72355322019-07-02 10:09:39 -07006559 checkMath("pow", 2, is_float=True, ret_type="float")
David Reisse75fb432020-04-22 09:20:13 -07006560 checkMathWrap("floor", ret_type="int")
6561 checkMathWrap("ceil", ret_type="int")
6562 checkMathWrap("gcd", 2, is_float=False, ret_type="int")
6563 checkMath("isfinite", 1, ret_type="bool")
Nikita Shulgadc5cda02022-01-19 15:59:11 -08006564 checkMathWrap("remainder", 2)
Shen Li10224432021-08-12 11:39:31 -07006565 checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)])
Yauheni Koran5f7ef092019-05-13 08:43:16 -07006566
Animesh Jain1d90d6e2022-07-07 18:57:31 +00006567 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Adam Paszkea58f2d22018-03-22 16:58:36 +01006568 def test_if_nest_while(self):
6569 def func(a, b):
Richard Zou67f6f932018-08-27 08:53:56 -07006570 # type: (int, int) -> int
6571 c = 0
Adam Paszkea58f2d22018-03-22 16:58:36 +01006572 if a > b:
6573 while a > b:
6574 b = b + 1
6575 c = -b
6576 return c
6577
gchanane1f5d802018-04-18 23:37:54 -04006578 inputs = self._make_scalar_vars([4321, 1234], torch.int64)
David Riazatidefd23b2019-06-25 16:17:49 -07006579 self.checkScript(func, inputs)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006580
Wanchao Liang79ceece2018-11-09 11:26:50 -08006581 def test_script_optional_none(self):
6582 def none_stmt(x):
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006583 output = None
6584 output = x
6585 return output
6586
Wanchao Liang79ceece2018-11-09 11:26:50 -08006587 def none_args(x):
6588 # type: (Optional[Tensor]) -> Optional[Tensor]
6589 return None
6590
6591 self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
6592 self.checkScript(none_args, [None], optimize=True)
6593
6594        # test None (an undefined tensor) as default param
6595 def test_script_optional_tensor_none(x=None):
6596 # type: (Optional[Tensor]) -> Tensor
6597 res = torch.zeros(1, dtype=torch.int8)
6598 if x is None:
6599 res = res + 1
6600 else:
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006601 res = x
Wanchao Liang79ceece2018-11-09 11:26:50 -08006602 return res
6603
6604 fn = test_script_optional_tensor_none
6605 scripted_fn = torch.jit.script(fn)
6606 self.assertEqual(fn(), scripted_fn())
6607 self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
6608
6609 # test typical None as default param
6610        # test None as default param for a non-tensor Optional
6611 # type: (Optional[float]) -> float
6612 res = 2.0
6613 if x is None:
6614 res = res + 1.0
6615 else:
Wanchao Liang5f6ecd12019-02-14 21:37:08 -08006616 res = x
Wanchao Liang79ceece2018-11-09 11:26:50 -08006617 return res
6618
6619 fn = test_script_optional_other_none
6620 scripted_fn = torch.jit.script(fn)
6621 self.assertEqual(fn(), scripted_fn())
6622 self.assertEqual(fn(1.0), scripted_fn(1.0))
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006623
6624 def test_script_clamp_none(self):
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006625 def test_script_clamp_max_none(x):
Wanchao Liang4e1c64c2018-10-25 16:06:00 -07006626 return torch.clamp(x, min=2, max=None)
6627
6628 def test_script_clamp_max(x):
6629 return torch.clamp(x, max=2)
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006630
6631 def test_script_clamp_min_none(x):
Wanchao Liang4e1c64c2018-10-25 16:06:00 -07006632 return torch.clamp(x, min=None, max=2)
6633
6634 def test_script_clamp_min(x):
6635 return torch.clamp(x, min=2)
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006636
6637 input = [torch.arange(0, 3)]
6638 self.checkScript(test_script_clamp_max_none, input, optimize=True)
Wanchao Liang4e1c64c2018-10-25 16:06:00 -07006639 self.checkScript(test_script_clamp_max, input, optimize=True)
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006640 self.checkScript(test_script_clamp_min_none, input, optimize=True)
Wanchao Liang4e1c64c2018-10-25 16:06:00 -07006641 self.checkScript(test_script_clamp_min, input, optimize=True)
Wanchao Liang47c1bad2018-07-27 22:47:29 -07006642
James Reed213fa612018-03-23 08:55:32 -07006643 def test_script_bool_constant(self):
James Reed213fa612018-03-23 08:55:32 -07006644 def test_script_bool_constant():
6645 a = True
6646 return a
David Riazatidefd23b2019-06-25 16:17:49 -07006647 self.checkScript(test_script_bool_constant, [])
James Reed213fa612018-03-23 08:55:32 -07006648
Adam Paszkea58f2d22018-03-22 16:58:36 +01006649 def test_ternary(self):
6650 def func(a, b):
6651 c = 3
David Riazati6f53b4e2018-09-13 11:10:00 -07006652 c = a + b if bool(a > 3) else b
Adam Paszkea58f2d22018-03-22 16:58:36 +01006653 return c
6654
gchanane1f5d802018-04-18 23:37:54 -04006655 inputs_true = self._make_scalar_vars([5, 2], torch.int64)
6656 inputs_false = self._make_scalar_vars([1, 0], torch.int64)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006657 self.checkScript(func, inputs_true, optimize=True)
6658 self.checkScript(func, inputs_false, optimize=True)
6659
Erjia Guanb80a3662020-12-21 10:09:37 -08006660 def test_ternary_module_type_hint(self):
6661 class M1(torch.nn.Module):
6662 def forward(self) -> Any:
Shen Li10224432021-08-12 11:39:31 -07006663 return 'out' if self.training else {}
Erjia Guanb80a3662020-12-21 10:09:37 -08006664
6665 class M2(torch.nn.Module):
6666 def forward(self) -> Any:
Shen Li10224432021-08-12 11:39:31 -07006667 out: Any = 'out' if self.training else {}
Erjia Guanb80a3662020-12-21 10:09:37 -08006668 return out
6669
6670 class M3(torch.nn.Module):
6671 def forward(self) -> Optional[int]:
6672 return None if self.training else 1
6673
6674 for module in [M1, M2, M3]:
6675 self.checkModule(module().train(), ())
6676 self.checkModule(module().eval(), ())
6677
nikithamalgifa701682021-02-06 10:11:30 -08006678 def test_ternary_static_if(self):
6679 # Test for True branch when condition variable
6680 # is annotated as Final
6681 class M1(torch.nn.Module):
6682 flag: torch.jit.Final[bool]
6683
6684 def __init__(self):
6685 super().__init__()
6686 self.flag = True
6687
6688 def forward(self) -> torch.Tensor:
6689 return torch.ones(3) if self.flag else {}
6690
6691        # Test for False branch when condition variable
6692 # is annotated as Final
6693 class M2(torch.nn.Module):
6694 flag: torch.jit.Final[bool]
6695
6696 def __init__(self):
6697 super().__init__()
6698 self.flag = False
6699
6700 def forward(self) -> torch.Tensor:
6701 return {} if self.flag else torch.ones(3)
6702
6703 model1 = M1()
6704 model2 = M2()
6705 script_model_1 = torch.jit.script(model1)
6706 script_model_2 = torch.jit.script(model2)
6707 self.assertEqual(model1.forward(), script_model_1.forward())
6708 self.assertEqual(model2.forward(), script_model_2.forward())
6709
bingdc81ba12022-03-01 15:22:12 -08006710 def test_ternary_right_associative(self):
6711 def plus_123(x: int):
6712 return x + 1 if x == 1 else x + 2 if x == 2 else x + 3
6713 self.checkScript(plus_123, (1,))
6714 self.checkScript(plus_123, (2,))
6715 self.checkScript(plus_123, (3,))
6716
Yanbo Liang490c1cf2022-12-19 04:14:11 +00006717 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Adam Paszkea58f2d22018-03-22 16:58:36 +01006718 def test_print(self):
6719 def func(x, y):
6720 q = (x + y).sigmoid()
Elias Ellison0ef5cfd2018-08-24 13:31:24 -07006721 print(q, 1, 2, [1, 2], [1.0, 2.0])
Adam Paszkea58f2d22018-03-22 16:58:36 +01006722 w = -q
6723 return w * w
6724
Shen Li10224432021-08-12 11:39:31 -07006725 x = torch.arange(4., requires_grad=True)
6726 y = torch.arange(0., 8, 2, requires_grad=True)
Adam Paszkea58f2d22018-03-22 16:58:36 +01006727 self.checkScript(func, [x, y], optimize=True, capture_output=True)
6728
David Riazati404f8662018-10-31 12:48:40 -07006729 def test_format(self):
6730 def func(x):
6731 print("{}, I'm a {}".format("Hello", "test"))
Zachary DeVito86192302018-11-02 00:03:56 -07006732 print("format blank".format())
6733 print("stuff before {}".format("hi"))
6734 print("{} stuff after".format("hi"))
David Riazati404f8662018-10-31 12:48:40 -07006735 return x + 1
6736
Shen Li10224432021-08-12 11:39:31 -07006737 x = torch.arange(4., requires_grad=True)
David Riazati404f8662018-10-31 12:48:40 -07006738 self.checkScript(func, [x], optimize=True, capture_output=True)
6739
Elias Ellison539579a2018-09-04 09:17:32 -07006740 def test_logical_short_circuit(self):
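        # Descriptive note (added): with constant conditions the short-circuit folds away
        # the conditional entirely (no prim::If in the graph), so indexing the empty
        # tensor in the dead branch never runs. When the constant does not decide the
        # result (False or x, True and x), the right-hand side is still evaluated and the
        # out-of-range index raises.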
6741 @torch.jit.script
6742 def testNoThrows(t):
6743 c1 = 1
David Riazati6f53b4e2018-09-13 11:10:00 -07006744 if (False and bool(t[1])) or (True or bool(t[1])):
Elias Ellison539579a2018-09-04 09:17:32 -07006745 c1 = 0
6746 return c1
6747
Elias Ellison3eefc062019-12-09 14:17:58 -08006748 FileCheck().check_not("prim::If").run(testNoThrows.graph)
eellisond8d83712019-02-22 17:54:09 -08006749 self.assertEqual(0, testNoThrows(torch.randn(0)))
Elias Ellison3eefc062019-12-09 14:17:58 -08006750 self.assertEqual(0, testNoThrows(torch.randn([2, 3])))
eellisond8d83712019-02-22 17:54:09 -08006751
Elias Ellison539579a2018-09-04 09:17:32 -07006752 @torch.jit.script
6753 def throwsOr(t):
David Riazati6f53b4e2018-09-13 11:10:00 -07006754 c0 = False or bool(t[1])
Elias Ellison539579a2018-09-04 09:17:32 -07006755 print(c0)
6756
6757 @torch.jit.script
6758 def throwsAnd(t):
David Riazati6f53b4e2018-09-13 11:10:00 -07006759 c0 = True and bool(t[1])
Elias Ellison539579a2018-09-04 09:17:32 -07006760 print(c0)
6761
6762 t = torch.randn(0)
Shen Li10224432021-08-12 11:39:31 -07006763 with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
Elias Ellison539579a2018-09-04 09:17:32 -07006764 throwsOr(t)
Shen Li10224432021-08-12 11:39:31 -07006765 with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
Elias Ellison539579a2018-09-04 09:17:32 -07006766 throwsAnd(t)
6767
Wanchao Liang50cf3262018-08-03 10:52:26 -07006768 def test_type_cast(self):
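        # Descriptive note (added): compile a tiny conversion function for each
        # (from_type, to_type) pair and let checkScript compare it against Python's
        # int()/float()/bool() behavior.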
Shen Li10224432021-08-12 11:39:31 -07006769 template = dedent('''
David Riazatidefd23b2019-06-25 16:17:49 -07006770 def func(v):
David Riazati70f0c472018-12-27 15:58:32 -08006771 # type: ({from_type}) -> {to_type}
6772 return {to_type}(v)
Shen Li10224432021-08-12 11:39:31 -07006773 ''')
Wanchao Liang50cf3262018-08-03 10:52:26 -07006774
David Riazati70f0c472018-12-27 15:58:32 -08006775 def check_cast(from_type, to_type, value, raises=False):
6776 code = template.format(from_type=from_type, to_type=to_type)
David Riazatidefd23b2019-06-25 16:17:49 -07006777 self.checkScript(code, (value,))
David Riazatid1ac1eb2018-10-03 12:25:39 -07006778
Shen Li10224432021-08-12 11:39:31 -07006779 check_cast('int', 'float', 1)
6780 check_cast('int', 'bool', 1)
6781 check_cast('int', 'bool', 0)
Wanchao Liang50cf3262018-08-03 10:52:26 -07006782
Shen Li10224432021-08-12 11:39:31 -07006783 check_cast('float', 'int', 1.)
6784 check_cast('float', 'bool', 1.)
6785 check_cast('float', 'bool', 0.)
David Riazatid1ac1eb2018-10-03 12:25:39 -07006786
Shen Li10224432021-08-12 11:39:31 -07006787 check_cast('bool', 'int', True)
6788 check_cast('bool', 'float', True)
David Riazatid1ac1eb2018-10-03 12:25:39 -07006789
Adam Paszkea58f2d22018-03-22 16:58:36 +01006790 def test_multiple_assignment(self):
6791 def outer_func(x):
6792 return x * 2, x + 2
6793
6794 @torch.jit.script
6795 def func(x):
6796 y, z = outer_func(x)
6797 return y + z
6798
6799 x = torch.arange(4)
6800 self.assertEqual(func(x), x * 2 + x + 2)
6801
6802 def test_literals(self):
6803 def func(a):
6804 return a.view(size=[1, 2, 3])
6805
6806 a = torch.randn(6)
6807 self.checkScript(func, [a], optimize=True)
6808
Adam Paszkeda6c3c92018-03-31 18:35:33 +02006809 def test_return(self):
6810 def no_return(a):
6811 a + 1
6812
6813 def void_return(a):
6814 return
6815
6816 def one_return(a):
Shen Li10224432021-08-12 11:39:31 -07006817 return a + 1.
Adam Paszkeda6c3c92018-03-31 18:35:33 +02006818
6819 def multiple_returns(a):
Shen Li10224432021-08-12 11:39:31 -07006820 return a * 1., a * 2., a * 3.
Adam Paszkeda6c3c92018-03-31 18:35:33 +02006821
6822 a = torch.randn(1, dtype=torch.float)
6823 self.checkScript(no_return, [a], optimize=True)
6824 self.checkScript(void_return, [a], optimize=True)
6825 self.checkScript(one_return, [a], optimize=True)
6826 self.checkScript(multiple_returns, [a], optimize=True)
6827
Sam Estepe3900d22021-04-19 13:14:27 -07006828 with self.assertRaisesRegex(RuntimeError, "does not return along all paths"):
Shen Li10224432021-08-12 11:39:31 -07006829 torch.jit.CompilationUnit('''
David Riazatidf67d412018-11-05 16:49:56 -08006830 def no_return_bad_annotation(a):
6831 # type: (Tensor) -> Tensor
6832 a + 1
Shen Li10224432021-08-12 11:39:31 -07006833 ''')
David Riazatidf67d412018-11-05 16:49:56 -08006834
Adam Paszkea58f2d22018-03-22 16:58:36 +01006835 def test_error(self):
6836 @torch.jit.script
6837 def foo(a):
6838 return a.t()
bhushan4ca1a542019-02-26 14:11:18 -08006839 s = Variable(torch.rand(5, 5, 5))
Adam Paszkea58f2d22018-03-22 16:58:36 +01006840        # XXX: this should stay quiet in shape propagation and only fail in the interpreter
Shen Li10224432021-08-12 11:39:31 -07006841 with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
Adam Paszkea58f2d22018-03-22 16:58:36 +01006842 foo(s)
6843
6844 @torch.jit.script
6845 def bar(c, b):
Richard Zou68c2e012018-09-05 14:51:18 -07006846 return c + b
Adam Paszkea58f2d22018-03-22 16:58:36 +01006847
Shen Li10224432021-08-12 11:39:31 -07006848 with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
6849 bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
Adam Paszkea58f2d22018-03-22 16:58:36 +01006850
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006851 def test_error_stacktrace(self):
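        # Descriptive note (added): a runtime failure deep in a scripted call chain
        # (bar -> foo -> baz) should produce a TorchScript interpreter traceback that
        # names the intermediate functions; FileCheck verifies "in foo" and "in baz"
        # appear in the exception text.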
6852 @torch.jit.script
6853 def baz(c, b):
6854 return c + b
6855
6856 @torch.jit.script
6857 def foo(c, b):
6858 return baz(c, b)
6859
6860 @torch.jit.script
6861 def bar(c, b):
6862 return foo(c, b)
6863
6864 with self.assertRaises(RuntimeError) as cm:
6865 bar(torch.rand(10), torch.rand(9))
Shen Li10224432021-08-12 11:39:31 -07006866 FileCheck().check("The following operation failed in the TorchScript interpreter") \
6867 .check("Traceback") \
6868 .check("in foo").check("in baz").run(str(cm.exception))
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006869
6870 def test_error_stacktrace_interface(self):
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006871 @torch.jit.script
6872 def baz(c, b):
6873 return c + b
6874
6875 @torch.jit.script
6876 def foo(c, b):
6877 return baz(c, b)
6878
6879 @torch.jit.script
6880 def bar(c, b):
6881 return foo(c, b)
6882
6883 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00006884 class Bar:
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006885 def one(self, x, y):
6886 return bar(x, y)
6887
6888 @torch.jit.interface
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +00006889 class IFace:
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006890 def one(self, x, y):
6891 # type: (Tensor, Tensor) -> Tensor
6892 pass
6893
Nikita Shulga00adc7b2021-01-27 10:49:10 -08006894 make_global(IFace)
6895
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006896 @torch.jit.script
6897 def as_interface(x):
6898 # type: (IFace) -> IFace
6899 return x
6900
6901 f = as_interface(Bar())
6902
6903 with self.assertRaises(RuntimeError) as cm:
6904 x = f.one(torch.rand(10), torch.rand(9))
6905 bar(torch.rand(10), torch.rand(9))
Shen Li10224432021-08-12 11:39:31 -07006906 FileCheck().check("The following operation failed in the TorchScript interpreter") \
6907 .check("Traceback") \
6908 .check("in foo").check("in baz").run(str(cm.exception))
Mikhail Zolotukhin2c8dce92019-11-19 17:55:42 -08006909
Tugrul Incec9023e32020-03-13 12:49:41 -07006910 def test_operator_precedence(self):
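        # Descriptive note (added): evaluate one expression mixing slicing, arithmetic,
        # floor division, modulo, shifts, and bitwise operators in both Python and
        # TorchScript; checkScript asserts the results agree, so a precedence mismatch in
        # the compiler would surface here.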
6911 def double(x):
6912 # type: (int) -> int
6913 return 2 * x
6914
6915 def complicated_arithmetic_operation():
6916 # TODO we need to test exponent operator '**' and bitwise not
6917 # operator '~' once they are properly supported.
6918 list = [0, 1, 2, 3]
Shen Li10224432021-08-12 11:39:31 -07006919 result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4
Tugrul Incec9023e32020-03-13 12:49:41 -07006920 return result
6921
6922 self.checkScript(complicated_arithmetic_operation, ())
Richard Zou1807bac2018-03-28 13:41:45 -04006923
Ansley Ussery58fe6792020-12-28 10:21:47 -08006924 def test_in_operator_with_two_strings(self):
6925 def fn() -> bool:
6926 return "a" in "abcd"
6927 self.checkScript(fn, ())
6928
Elias Ellisonf649d8b2018-11-13 16:30:02 -08006929 def test_bitwise_ops(self):
Shen Li10224432021-08-12 11:39:31 -07006930
Elias Ellisonf649d8b2018-11-13 16:30:02 -08006931 def int_test():
Tugrul Incec9023e32020-03-13 12:49:41 -07006932 return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3
Elias Ellisonf649d8b2018-11-13 16:30:02 -08006933
6934 self.checkScript(int_test, ())
6935
6936 def bool_test(x, y):
6937 # type: (bool, bool) -> Tuple[bool, bool, bool]
6938 return x & y, x ^ y, x | y
6939
6940 self.checkScript(bool_test, (True, False))
6941 self.checkScript(bool_test, (True, True))
6942
6943 def tensor_test(x, y):
6944 return x & y, x ^ y, x | y
6945
Tugrul Incec9023e32020-03-13 12:49:41 -07006946 def tensor_with_int_test(x, y):
6947 # type: (Tensor, int) -> Tuple[Tensor, Tensor]
6948 return x << y, x >> y
6949
Elias Ellisonf649d8b2018-11-13 16:30:02 -08006950 x = torch.tensor(2)
6951 y = torch.tensor(3)
6952
6953 self.checkScript(tensor_test, (x, y))
Tugrul Incec9023e32020-03-13 12:49:41 -07006954 self.checkScript(tensor_with_int_test, (x, 2))
James Reed6b099ed2019-05-29 21:47:15 -07006955
Horace Hef3f83cc2019-09-03 11:09:52 -07006956 def not_test(x):
6957 return ~x
6958
Shen Li10224432021-08-12 11:39:31 -07006959 self.checkScript(not_test, (torch.tensor([2, 4]), ))
Horace Hef3f83cc2019-09-03 11:09:52 -07006960
Michaelf3b8a472020-03-18 11:58:19 -07006961 def test_all(self):
6962 @torch.jit.script
6963 def test_all_tensor(x):
6964 return all(x)
6965 self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8)))
6966 self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8)))
6967 self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8)))
Shen Li10224432021-08-12 11:39:31 -07006968 self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8)))
James Reed6b099ed2019-05-29 21:47:15 -07006969
Michaelf3b8a472020-03-18 11:58:19 -07006970 @torch.jit.script
Elias Ellisonbcbde492020-03-18 16:41:53 -07006971 def test_all_bool_list(x):
6972 # type: (List[bool]) -> bool
Michaelf3b8a472020-03-18 11:58:19 -07006973 return all(x)
6974 self.assertTrue(test_all_bool_list([True, True]))
6975 self.assertTrue(test_all_bool_list([True, 1]))
6976 self.assertFalse(test_all_bool_list([True, False]))
6977 self.assertFalse(test_all_bool_list([True, 0]))
6978 self.assertFalse(test_all_bool_list([False, 0]))
6979 self.assertTrue(test_all_bool_list([]))
Elias Ellisonf649d8b2018-11-13 16:30:02 -08006980
Michaelf3b8a472020-03-18 11:58:19 -07006981 @torch.jit.script
Elias Ellisonbcbde492020-03-18 16:41:53 -07006982 def test_all_int_list(x):
6983 # type: (List[int]) -> bool
Michaelf3b8a472020-03-18 11:58:19 -07006984 return all(x)
6985 self.assertTrue(test_all_int_list([3, 6]))
6986 self.assertFalse(test_all_int_list([2, 0]))
6987
6988 @torch.jit.script
Elias Ellisonbcbde492020-03-18 16:41:53 -07006989 def test_all_float_list(x):
6990 # type: (List[float]) -> bool
Michaelf3b8a472020-03-18 11:58:19 -07006991 return all(x)
6992 self.assertTrue(test_all_float_list([3.14, 8.1]))
6993 self.assertFalse(test_all_float_list([3.14, 0, 8.9]))
6994
Shen Li10224432021-08-12 11:39:31 -07006995
Richard Zou8489c4c2018-06-21 15:43:38 -04006996 def test_number_math(self):
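        # Descriptive note (added): generate a small function for every
        # (scalar, op, scalar) combination of the listed operators and for min/max,
        # compile it, and check the scripted result against the Python result. The second
        # loop repeats the checks with Scalar values produced via torch.tensor(x).item().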
Shen Li10224432021-08-12 11:39:31 -07006997 ops_template = dedent('''
Elias Ellison686e8322018-11-12 14:02:06 -08006998 def func():
6999 return {scalar1} {op} {scalar2}
Shen Li10224432021-08-12 11:39:31 -07007000 ''')
7001 ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
7002 funcs_template = dedent('''
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007003 def func():
7004 return {func}({scalar1}, {scalar2})
Shen Li10224432021-08-12 11:39:31 -07007005 ''')
7006 funcs = ['min', 'max']
7007 scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']
7008 scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars]
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007009
7010 def run_test(code):
7011 scope = {}
Mikhail Zolotukhin159c2f32019-01-09 16:57:22 -08007012 execWrapper(code, globals(), scope)
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007013 cu = torch.jit.CompilationUnit(code)
7014
Shen Li10224432021-08-12 11:39:31 -07007015 self.assertEqual(cu.func(), scope['func']())
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007016
Elias Ellison686e8322018-11-12 14:02:06 -08007017 for scalar1, scalar2 in scalar_pairs:
7018 for op in ops:
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007019 code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
7020 run_test(code)
7021 for func in funcs:
Shen Li10224432021-08-12 11:39:31 -07007022 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
Zachary DeVitob0cf7802019-01-02 20:07:55 -08007023 run_test(code)
Richard Zou8489c4c2018-06-21 15:43:38 -04007024
Basil Hosmer167722d2019-09-25 13:46:44 -07007025 # test Scalar overloads
7026 for scalar1, scalar2 in scalar_pairs:
Shen Li10224432021-08-12 11:39:31 -07007027 item1 = 'torch.tensor(' + scalar1 + ').item()'
7028 item2 = 'torch.tensor(' + scalar2 + ').item()'
Basil Hosmer167722d2019-09-25 13:46:44 -07007029 for op in ops:
7030 code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2)
7031 run_test(code)
7032 code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2)
7033 run_test(code)
7034 code = ops_template.format(op=op, scalar1=item1, scalar2=item2)
7035 run_test(code)
7036 for func in funcs:
7037 code = funcs_template.format(func=func, scalar1=item1, scalar2=scalar2)
7038 run_test(code)
7039 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2)
7040 run_test(code)
7041 code = funcs_template.format(func=func, scalar1=item1, scalar2=item2)
7042 run_test(code)
7043
Michael Kösel46a68c12019-04-03 10:11:33 -07007044 def test_number_abs(self):
7045 def func1(x):
7046 # type: (float) -> float
7047 return abs(x)
7048
7049 def func2(x):
7050 # type: (int) -> int
7051 return abs(x)
7052
7053 def func3(x):
7054 return abs(x)
7055
7056 self.checkScript(func1, (-3.14,))
7057 self.checkScript(func1, (3.14,))
7058 self.checkScript(func2, (-10,))
7059 self.checkScript(func2, (10,))
7060 self.checkScript(func3, (torch.tensor([-5, -10, -20]),))
7061 self.checkScript(func3, (torch.tensor([5, 10, 20]),))
7062 self.checkScript(func3, (torch.tensor([-5, 10, -20]),))
7063
Richard Zou68c2e012018-09-05 14:51:18 -07007064 def test_number_div(self):
David Riazatidefd23b2019-06-25 16:17:49 -07007065 self.assertEqual(div_int_future(), torch.jit.script(div_int_future)())
7066 self.checkScript(div_float_future, ())
Richard Zou68c2e012018-09-05 14:51:18 -07007067
David Reisse75fb432020-04-22 09:20:13 -07007068 self.checkScript(div_int_nofuture, ())
7069 self.checkScript(div_float_nofuture, ())
Richard Zou68c2e012018-09-05 14:51:18 -07007070
Nikitha Malgie17f0fd2020-12-18 12:05:52 -08007071    # Test bitwise augmented (shorthand) assignment operators
7072 def test_bool_augassign_bitwise_or(self):
7073 def func(a: bool, b: bool) -> bool:
7074 a |= b
7075 return a
7076
7077 self.checkScript(func, (True, False), optimize=True)
7078 self.checkScript(func, (True, True), optimize=True)
7079 self.checkScript(func, (False, False), optimize=True)
7080 self.checkScript(func, (False, True), optimize=True)
7081
7082 def test_bool_augassign_bitwise_and(self):
7083 def func(a: bool, b: bool) -> bool:
7084 a &= b
7085 return a
7086
7087 self.checkScript(func, (True, False), optimize=True)
7088 self.checkScript(func, (True, True), optimize=True)
7089 self.checkScript(func, (False, False), optimize=True)
7090 self.checkScript(func, (False, True), optimize=True)
7091
7092 def test_bool_augassign_bitwise_xor(self):
7093 def func(a: bool, b: bool) -> bool:
7094 a ^= b
7095 return a
7096
7097 self.checkScript(func, (True, False), optimize=True)
7098 self.checkScript(func, (True, True), optimize=True)
7099 self.checkScript(func, (False, False), optimize=True)
7100 self.checkScript(func, (False, True), optimize=True)
7101
7102 def test_number_augassign_bitwise_lshift(self):
7103 def func() -> int:
7104 z = 8
7105 z <<= 2
7106 return z
7107
7108 self.checkScript(func, (), optimize=True)
7109
7110 def test_number_augassign_bitwise_rshift(self):
7111 def func() -> int:
7112 z = 8
7113 z >>= 2
7114 return z
7115
7116 self.checkScript(func, (), optimize=True)
7117
7118 def test_number_augassign_bitwise_pow(self):
7119 def func() -> float:
7120 z = 8
7121 z **= 2
7122 return z
7123
7124 self.checkScript(func, (), optimize=True)
7125
Richard Zou6c84f7f2018-08-22 15:05:27 -07007126 def test_number_augassign(self):
7127 def func():
7128 z = 1
7129 z += 2
7130 return z
7131
7132 self.checkScript(func, (), optimize=True)
7133
Elias Ellison10bd21d2020-01-31 16:53:25 -08007134 def test_nested_select_assign(self):
7135 class SubSubModule(torch.nn.Module):
7136 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007137 super().__init__()
Elias Ellison10bd21d2020-01-31 16:53:25 -08007138 self.abc = 11
7139
7140 def forward(self, x):
7141 return self.abc
7142
7143 class SubModule(torch.nn.Module):
7144 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007145 super().__init__()
Elias Ellison10bd21d2020-01-31 16:53:25 -08007146 self.a = 11
7147 self.nested = SubSubModule()
7148
7149 def forward(self, x):
7150 return self.a
7151
7152 class TestModule(torch.nn.Module):
7153 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007154 super().__init__()
Elias Ellison10bd21d2020-01-31 16:53:25 -08007155 self.sub = SubModule()
7156 self.hi = 1
7157
7158 def forward(self):
7159 self.hi = 5
7160 self.sub.a = 1
7161 self.sub.nested.abc = 5
7162 return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi
7163
7164 self.checkModule(TestModule(), ())
7165
Richard Zou8489c4c2018-06-21 15:43:38 -04007166 def test_number_neg(self):
7167 # int -> int
7168 def func1():
Richard Zou67f6f932018-08-27 08:53:56 -07007169 return -8
Richard Zou8489c4c2018-06-21 15:43:38 -04007170
7171 # float -> float
7172 def func2():
Richard Zou67f6f932018-08-27 08:53:56 -07007173 return -3.14
Richard Zou8489c4c2018-06-21 15:43:38 -04007174
7175 self.checkScript(func1, (), optimize=True)
7176 self.checkScript(func2, (), optimize=True)
7177
nikithamalgi9c0caf02021-02-10 02:09:15 -08007178 def test_compare_two_bool_inputs(self):
7179 def compare_eq(a: bool, b: bool):
7180 return a == b
7181
7182 def compare_ne(a: bool, b: bool):
7183 return a != b
7184
7185 scripted_fn_eq = torch.jit.script(compare_eq)
7186 scripted_fn_ne = torch.jit.script(compare_ne)
7187 self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False))
7188 self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True))
7189 self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True))
7190 self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False))
7191
7192 self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False))
7193 self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True))
7194 self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True))
7195 self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False))
7196
Shen Li10224432021-08-12 11:39:31 -07007197
7198 def _test_tensor_number_math(self, device='cpu'):
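        # Descriptive note (added): compare `tensor op const` and `const op tensor`
        # between Python and TorchScript for float, double, and long tensors, checking
        # both values and result dtype. Known-unsupported combinations
        # (`/` on the long tensor, `const % tensor`) are skipped below.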
7199 template = dedent('''
Richard Zou67f6f932018-08-27 08:53:56 -07007200 def func(t):
7201 return {lhs} {op} {rhs}
Shen Li10224432021-08-12 11:39:31 -07007202 ''')
Richard Zou8489c4c2018-06-21 15:43:38 -04007203
Brian Vaughan88e4cee2019-09-05 18:24:09 -07007204 def test(op, tensor, const, swap_args, template=template):
Shen Li10224432021-08-12 11:39:31 -07007205 args = ('t', const)
Richard Zou8489c4c2018-06-21 15:43:38 -04007206 if swap_args:
Shen Li10224432021-08-12 11:39:31 -07007207 args = (const, 't')
Richard Zou8489c4c2018-06-21 15:43:38 -04007208
7209 code = template.format(lhs=args[0], rhs=args[1], op=op)
7210 scope = {}
Mikhail Zolotukhin159c2f32019-01-09 16:57:22 -08007211 execWrapper(code, globals(), scope)
Richard Zou8489c4c2018-06-21 15:43:38 -04007212 cu = torch.jit.CompilationUnit(code)
Shen Li10224432021-08-12 11:39:31 -07007213 message = 'with code `{} {} {}` and t={}'.format(args[0], op, args[1], tensor)
Brian Vaughan88e4cee2019-09-05 18:24:09 -07007214 res1 = cu.func(tensor)
Shen Li10224432021-08-12 11:39:31 -07007215 res2 = scope['func'](tensor)
7216 self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
7217 self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
Richard Zou8489c4c2018-06-21 15:43:38 -04007218
Wanchao Liang739e6af2018-09-12 12:28:31 -07007219 var_int = [2, -2]
7220 var_float = [1.4321, -1.2]
Richard Zou8489c4c2018-06-21 15:43:38 -04007221
Shen Li10224432021-08-12 11:39:31 -07007222 ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
Richard Zou8489c4c2018-06-21 15:43:38 -04007223
7224 float_tensor = torch.randn(5, 5, device=device)
7225 double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
7226 long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
7227 long_tensor[long_tensor == 0] = 2
7228
7229 tensors = [float_tensor, double_tensor, long_tensor]
Wanchao Liang739e6af2018-09-12 12:28:31 -07007230 consts = var_int + var_float
Richard Zou8489c4c2018-06-21 15:43:38 -04007231
Shen Li10224432021-08-12 11:39:31 -07007232 for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
Richard Zou8489c4c2018-06-21 15:43:38 -04007233 # FIXME: things like 2 / long_tensor are not implemented correctly
Philip Meierb0afe942021-03-09 11:28:03 -08007234 # Look in torch/_tensor.py to see how pytorch implements it.
Shen Li10224432021-08-12 11:39:31 -07007235 if op == '/' and tensor.data_ptr() == long_tensor.data_ptr():
Richard Zou8489c4c2018-06-21 15:43:38 -04007236 continue
7237
Wanchao Liang739e6af2018-09-12 12:28:31 -07007238 # % operator does not take: const % tensor
Shen Li10224432021-08-12 11:39:31 -07007239 if op == '%' and swap_args is True:
Wanchao Liang739e6af2018-09-12 12:28:31 -07007240 continue
7241
Brian Vaughan88e4cee2019-09-05 18:24:09 -07007242 test(op, tensor, const, swap_args)
Richard Zou8489c4c2018-06-21 15:43:38 -04007243
7244 def test_tensor_number_math(self):
7245 self._test_tensor_number_math()
7246
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007247 def test_torch_tensor_bad_input(self):
Shen Li10224432021-08-12 11:39:31 -07007248 with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, "
7249 "or bools, got None"):
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007250 @torch.jit.script
7251 def test():
7252 return torch.tensor([None])
Zachary DeVito99349de2020-02-12 14:45:44 -08007253 test()
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007254
Shen Li10224432021-08-12 11:39:31 -07007255 with self.assertRaisesRegex(RuntimeError, r"Empty lists default to List\[Tensor\]"):
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007256 @torch.jit.script
7257 def tmp():
7258 return torch.tensor([])
Zachary DeVito99349de2020-02-12 14:45:44 -08007259 tmp()
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007260
7261 @torch.jit.script
7262 def foo():
7263 return torch.tensor([[2, 2], [1]])
7264 with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
7265 foo()
7266
7267 @suppress_warnings
Wanchao Liangc779eff2019-07-29 14:26:07 -07007268 def test_torch_tensor_as_tensor_empty_list(self):
Shen Li10224432021-08-12 11:39:31 -07007269 tensor_template = dedent('''
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007270 def func():
Wanchao Liangc779eff2019-07-29 14:26:07 -07007271 empty_list = torch.jit.annotate(List[int], [])
7272 ten1 = torch.{tensor_op}({input})
7273 return ten1
Shen Li10224432021-08-12 11:39:31 -07007274 ''')
7275 ops = ['tensor', 'as_tensor']
7276 inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]']
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007277
Wanchao Liangc779eff2019-07-29 14:26:07 -07007278 for op in ops:
7279 for inp in inputs:
7280 code = tensor_template.format(tensor_op=op, input=inp)
7281 scope = {}
7282 exec(code, globals(), scope)
7283 cu = torch.jit.CompilationUnit(code)
7284 t1 = cu.func()
Shen Li10224432021-08-12 11:39:31 -07007285 t2 = scope['func']()
7286 if inp == 'empty_list':
Wanchao Liangc779eff2019-07-29 14:26:07 -07007287 # torchscript returns int tensor, python returns float tensor
7288 self.assertNotEqual(t1.dtype, t2.dtype)
Sergii Dymchenko58d1cf72022-08-03 22:45:39 +00007289 self.assertEqual(t1, t2, exact_dtype=False)
Wanchao Liangc779eff2019-07-29 14:26:07 -07007290 self.assertEqual(t1.device, t2.device)
7291
Shen Li10224432021-08-12 11:39:31 -07007292 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate")
Wanchao Liangc779eff2019-07-29 14:26:07 -07007293 def test_tensor_as_tensor_shape_prop(self):
Shen Li10224432021-08-12 11:39:31 -07007294 tensor_template = dedent('''
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007295 def func():
Wanchao Liangc779eff2019-07-29 14:26:07 -07007296 return torch.{tensor_op}({input})
Shen Li10224432021-08-12 11:39:31 -07007297 ''')
7298 ops = ['tensor', 'as_tensor']
7299 inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])']
7300 expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)",
7301 "Double(*, device=cpu)", "Double(device=cpu)",
7302 "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"]
Elias Ellison0922a642019-04-23 12:21:32 -07007303
Wanchao Liangc779eff2019-07-29 14:26:07 -07007304 for op in ops:
7305 for inp, expect in zip(inputs, expected_shape):
7306 code = tensor_template.format(tensor_op=op, input=inp)
7307 scope = {}
7308 exec(code, globals(), scope)
Nikolay Korovaiko5375cea2019-12-20 10:47:03 -08007309 cu = torch.jit.CompilationUnit(code)
7310 torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
Shen Li10224432021-08-12 11:39:31 -07007311 FileCheck().check(expect).check("aten::{tensor_op}".format(tensor_op=op)).run(cu.func.graph)
Elias Ellison0922a642019-04-23 12:21:32 -07007312
7313 @torch.jit.script
Elias Ellison428bc902020-06-09 14:56:19 -07007314 def test_dtype(inp_dtype: torch.dtype):
Elias Ellison0922a642019-04-23 12:21:32 -07007315 a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
Sam Estepe3900d22021-04-19 13:14:27 -07007316 return a, torch.tensor(1.0, dtype=inp_dtype)
Elias Ellison0922a642019-04-23 12:21:32 -07007317
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08007318 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07007319 g = test_dtype.graph_for(5, profile_and_replay=True)
7320 # both should have completed shapes
Shen Li10224432021-08-12 11:39:31 -07007321 FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \
7322 .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07007323 else:
7324 g = test_dtype.graph_for(5)
7325 # first should have type set second should not
Shen Li10224432021-08-12 11:39:31 -07007326 FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \
7327 .check("Tensor(requires_grad=0) = aten::tensor").run(g)
Elias Ellison0922a642019-04-23 12:21:32 -07007328
Wanchao Liangc779eff2019-07-29 14:26:07 -07007329 @torch.jit.script
7330 def test_as_tensor_tensor_input(input):
7331 a = torch.as_tensor(input, dtype=input.dtype)
7332 return a, torch.as_tensor(input, dtype=torch.float)
7333
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -08007334 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
Shen Li10224432021-08-12 11:39:31 -07007335 g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True)
7336 FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \
7337 .check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07007338 else:
7339 g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
Shen Li10224432021-08-12 11:39:31 -07007340 FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
7341
Elias Ellison6694fda2022-03-29 11:32:31 -07007342 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
Elias Ellison0922a642019-04-23 12:21:32 -07007343 def test_tensor_requires_grad(self):
7344 @torch.jit.script
7345 def test(b):
7346 # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
Shen Li10224432021-08-12 11:39:31 -07007347 a = torch.tensor(1., requires_grad=b)
7348 b = torch.tensor(1., requires_grad=True)
7349 c = torch.tensor(1., requires_grad=False)
Sam Estepe3900d22021-04-19 13:14:27 -07007350 return a, b, c
Elias Ellison0922a642019-04-23 12:21:32 -07007351
7352 g = test.graph_for(True)
7353 out = next(g.outputs())
7354 out_inp = list(out.node().inputs())
7355
7356 self.assertTrue(out_inp[0].requires_grad())
7357 self.assertTrue(out_inp[1].requires_grad())
7358 self.assertFalse(out_inp[2].requires_grad())
7359
7360 def test_grad_from_script(self):
7361 def test():
7362 a = torch.tensor(2.5, requires_grad=True)
7363 b = a * 2
7364 return a, b
7365
7366 a, b = test()
7367 b.backward()
7368
7369 a_script, b_script = torch.jit.script(test)()
7370 b_script.backward()
7371 self.assertEqual(a.grad, a_script.grad)
7372
Wanchao Liangc779eff2019-07-29 14:26:07 -07007373 def test_torch_tensor_as_tensor(self):
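        # Descriptive note (added): build tensors from the scalar/list/tuple literals
        # below with every combination of the listed dtype and device options, for both
        # torch.tensor and torch.as_tensor, and compare value, dtype, and device between
        # the scripted and eager versions (half tensors are compared via str() because
        # equality is not implemented for them here).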
Shen Li10224432021-08-12 11:39:31 -07007374 tensor_template = dedent('''
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007375 def func():
7376 li = {list_create}
Wanchao Liangc779eff2019-07-29 14:26:07 -07007377 ten1 = torch.{tensor_op}(li {options})
7378 return ten1
Shen Li10224432021-08-12 11:39:31 -07007379 ''')
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007380
Shen Li10224432021-08-12 11:39:31 -07007381 lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)",
7382 "torch.jit.annotate(List[List[int]], [])",
7383 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007384
Shen Li10224432021-08-12 11:39:31 -07007385 dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
7386 ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
7387 ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat",
7388 ", dtype=torch.cdouble"]
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007389
Shen Li10224432021-08-12 11:39:31 -07007390 ops = ['tensor', 'as_tensor']
7391 devices = ['', ", device='cpu'"]
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007392 if RUN_CUDA:
7393 devices.append(", device='cuda'")
7394
7395 option_pairs = [dtype + device for dtype in dtypes for device in devices]
Wanchao Liangc779eff2019-07-29 14:26:07 -07007396 for op in ops:
7397 for li in lists:
7398 for option in option_pairs:
7399 # tensor from empty list is type float in python and annotated type in torchscript
7400 if "annotate" in li and "dtype" not in option:
7401 continue
Nikita Shulgad80fe492022-07-27 20:22:47 +00007402                    # Skip unsigned tensor initialization for signed values on Python 3.10+
7403 if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li:
7404 continue
Shen Li10224432021-08-12 11:39:31 -07007405 code = tensor_template.format(list_create=li, tensor_op=op, options=option)
Wanchao Liangc779eff2019-07-29 14:26:07 -07007406 scope = {}
7407 exec(code, globals(), scope)
7408 cu = torch.jit.CompilationUnit(code)
7409 t1 = cu.func()
Shen Li10224432021-08-12 11:39:31 -07007410 t2 = scope['func']()
Wanchao Liangc779eff2019-07-29 14:26:07 -07007411 if t1.dtype == torch.float16: # equality NYI for half tensor
7412 self.assertTrue(str(t1) == str(t2))
7413 else:
7414 self.assertEqual(t1, t2)
7415 self.assertEqual(t1.dtype, t2.dtype)
7416 self.assertEqual(t1.device, t2.device)
7417
7418 def test_as_tensor_tensor_input(input):
anjali411f9ca0d82021-03-24 08:10:58 -07007419 # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
Shen Li10224432021-08-12 11:39:31 -07007420 return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \
7421 torch.as_tensor(input, dtype=torch.int32)
Wanchao Liangc779eff2019-07-29 14:26:07 -07007422
anjali411f9ca0d82021-03-24 08:10:58 -07007423 inp = torch.randn(3, 4, dtype=torch.cfloat)
Wanchao Liangc779eff2019-07-29 14:26:07 -07007424 self.checkScript(test_as_tensor_tensor_input, (inp,))
Elias Ellisonbebf1f72019-01-03 17:31:56 -08007425
Elias Ellison54a575c2020-04-16 10:53:34 -07007426 def test_torch_tensor_dtype(self):
7427 def foo(s: float):
7428 return torch.tensor(s), torch.tensor([s, s])
7429
Elias Ellison4af84242020-07-06 17:07:44 -07007430 # need to clear function cache so we re run shape analysis
Elias Ellison54a575c2020-04-16 10:53:34 -07007431 with set_default_dtype(torch.double):
Shen Li10224432021-08-12 11:39:31 -07007432 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
Elias Ellison4af84242020-07-06 17:07:44 -07007433 if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
Shen Li10224432021-08-12 11:39:31 -07007434 FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
Elias Ellison4af84242020-07-06 17:07:44 -07007435 with set_default_dtype(torch.float):
Michael Suoc93e96f2020-07-08 11:35:52 -07007436 del torch.jit._state._jit_caching_layer[foo]
Shen Li10224432021-08-12 11:39:31 -07007437 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
Elias Ellison4af84242020-07-06 17:07:44 -07007438 if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
Shen Li10224432021-08-12 11:39:31 -07007439 FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
Elias Ellison54a575c2020-04-16 10:53:34 -07007440 with set_default_dtype(torch.half):
Michael Suoc93e96f2020-07-08 11:35:52 -07007441 del torch.jit._state._jit_caching_layer[foo]
Shen Li10224432021-08-12 11:39:31 -07007442 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
Elias Ellison4af84242020-07-06 17:07:44 -07007443 if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
Shen Li10224432021-08-12 11:39:31 -07007444 FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
Elias Ellison54a575c2020-04-16 10:53:34 -07007445
Elias Ellison37a572f2020-07-06 17:07:44 -07007446 def test_shape_analysis_grad_property(self):
7447 @torch.jit.script
7448 def foo(x):
7449 return torch.sub(x, torch.tanh(x))
7450
Shen Li10224432021-08-12 11:39:31 -07007451 torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False)
Elias Ellison37a572f2020-07-06 17:07:44 -07007452
7453 # requires_grad property shouldn't be accidentally set by shape analysis
7454 self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None)
7455
Basil Hosmerad769d72020-03-01 19:37:25 -08007456 def test_empty_like_memory_format_bc(self):
7457 def f(x):
7458 # type: (Tensor) -> Tensor
7459 return torch.zeros_like(x, memory_format=None)
7460
7461 scripted_f = torch.jit.script(f)
7462 x = torch.rand(3, 4)
7463 self.assertEqual(scripted_f(x), f(x))
7464
Malgi Nikitha Vivekananda85a70ce2020-09-30 16:05:55 -07007465 def test_multiline_string_dedents(self):
7466 def foo() -> None:
7467 multiline_string_dedent_1 = """
7468This is a string dedent """
7469 multiline_string_dedent_2 = """ This is a
7470 string dedent """
7471 multiline_string_dedent_3 = """
7472 This is a string
7473dedent """
7474 multiline_string_dedent_4 = """ This is a string dedent """
7475
7476 scripted_foo = torch.jit.script(foo)
7477 self.assertEqual(scripted_foo(), foo())
7478
Ansley Ussery475b4e32020-10-21 13:47:29 -07007479 def test_class_with_comment_at_lower_indentation(self):
7480 class Foo(torch.nn.Module):
7481 def forward(self, x):
7482 x = torch.neg(x)
Shen Li10224432021-08-12 11:39:31 -07007483 # This comment is at the wrong indent
Ansley Ussery475b4e32020-10-21 13:47:29 -07007484 return x
7485
7486 torch.jit.script(Foo())
7487
Elias Ellison4fb39312019-01-14 15:44:50 -08007488 # adapted from test in test_torch
7489 def test_tensor_to(self):
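        # Descriptive note (added): the helper s() compiles a one-line wrapper around the
        # given t.to(...) expression and runs it. test_copy_behavior checks the aliasing
        # contract of Tensor.to (same tensor returned when no conversion is needed, a copy
        # when copy=True); the rest covers dtype/device conversions and autodiff for the
        # different aten::to overloads.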
Shen Li10224432021-08-12 11:39:31 -07007490 template = dedent('''
Elias Ellison4fb39312019-01-14 15:44:50 -08007491 def func(t):
7492 cuda = "{cuda}"
7493 device = "{device}"
7494 non_blocking = {non_blocking}
7495 return {to_str}
Shen Li10224432021-08-12 11:39:31 -07007496 ''')
Elias Ellison4fb39312019-01-14 15:44:50 -08007497
7498 def s(t, to_str, non_blocking=None, device=None, cuda=None):
7499 device = device if device is not None else str(t.device)
7500 non_blocking = non_blocking if non_blocking is not None else False
7501 cuda = "cuda" if cuda is None else cuda
Shen Li10224432021-08-12 11:39:31 -07007502 code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
Elias Ellison4fb39312019-01-14 15:44:50 -08007503 scope = {}
7504 cu = torch.jit.CompilationUnit(code)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -07007505 return cu.func(t, profile_and_replay=True)
Elias Ellison4fb39312019-01-14 15:44:50 -08007506
7507 def test_copy_behavior(t, non_blocking=False):
Shen Li10224432021-08-12 11:39:31 -07007508 self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
7509 self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
7510 self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
7511 self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
7512 self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
7513 self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
Elias Ellison4fb39312019-01-14 15:44:50 -08007514
7515 devices = [t.device]
Shen Li10224432021-08-12 11:39:31 -07007516 if t.device.type == 'cuda':
Elias Ellison4fb39312019-01-14 15:44:50 -08007517 if t.device.index == -1:
Shen Li10224432021-08-12 11:39:31 -07007518 devices.append('cuda:{}'.format(torch.cuda.current_device()))
Elias Ellison4fb39312019-01-14 15:44:50 -08007519 elif t.device.index == torch.cuda.current_device():
Shen Li10224432021-08-12 11:39:31 -07007520 devices.append('cuda')
Elias Ellison4fb39312019-01-14 15:44:50 -08007521 for device in devices:
Shen Li10224432021-08-12 11:39:31 -07007522 self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
7523 self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
7524 self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
7525 self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
7526 non_blocking, device))
Elias Ellison4fb39312019-01-14 15:44:50 -08007527
7528 t = torch.tensor(5)
7529 test_copy_behavior(t)
7530
7531 self.assertEqual(t.device, s(t, "t.to('cpu')").device)
7532 self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
7533 self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
7534 self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
7535 self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
7536 self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
Shen Li10224432021-08-12 11:39:31 -07007537 self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
Elias Ellison4fb39312019-01-14 15:44:50 -08007538 self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
7539 self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
7540
7541 a = torch.tensor(5)
7542 if torch.cuda.is_available():
7543 for non_blocking in [True, False]:
Shen Li10224432021-08-12 11:39:31 -07007544 for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
7545 b = torch.tensor(5., device=cuda)
Elias Ellison4fb39312019-01-14 15:44:50 -08007546 test_copy_behavior(b, non_blocking)
Shen Li10224432021-08-12 11:39:31 -07007547 self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7548 self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
7549 self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7550 self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
7551 self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
Elias Ellison4fb39312019-01-14 15:44:50 -08007552 self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
7553 self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
7554
Ailing Zhangb0545aa2019-02-14 14:55:44 -08007555 # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
7556 t = torch.tensor(5).float().requires_grad_()
7557 out_ref = t.to(torch.float32)
7558 out = s(t, "t.to(torch.float32)")
7559 self.assertEqual(out_ref, out)
7560
7561 grad_ref = torch.autograd.grad(out_ref.sum(), t)
7562 grad = torch.autograd.grad(out.sum(), t)
7563 self.assertEqual(grad_ref, grad)
7564
7565 # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
Shen Li10224432021-08-12 11:39:31 -07007566 out_ref = t.to('cpu')
Ailing Zhangb0545aa2019-02-14 14:55:44 -08007567 out = s(t, "t.to('cpu')")
7568 self.assertEqual(out_ref, out)
7569
7570 grad_ref = torch.autograd.grad(out_ref.sum(), t)
7571 grad = torch.autograd.grad(out.sum(), t)
7572 self.assertEqual(grad_ref, grad)
7573
7574 # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
7575 @torch.jit.script
7576 def func2(t, t_ref):
7577 return t.to(t_ref)
7578
Zachary DeVito1827ca42019-04-13 08:28:11 -07007579 with disable_autodiff_subgraph_inlining():
7580 t_ref = torch.tensor(4).double()
7581 out_ref = t.to(t_ref)
7582 out = func2(t, t_ref)
7583 grad_ref = torch.autograd.grad(out_ref.sum(), t)
7584 grad = torch.autograd.grad(out.sum(), t)
7585 self.assertEqual(grad_ref, grad)
Ailing Zhangb0545aa2019-02-14 14:55:44 -08007586
Richard Zou8489c4c2018-06-21 15:43:38 -04007587 @unittest.skipIf(not RUN_CUDA, "No CUDA")
7588 def test_tensor_number_math_cuda(self):
Shen Li10224432021-08-12 11:39:31 -07007589 self._test_tensor_number_math(device='cuda')
Richard Zou8489c4c2018-06-21 15:43:38 -04007590
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007591 def test_not(self):
7592 # test not operator in python
7593 # TODO: add more tests when bool conversions are ready
7594 def test_not_op(a):
7595 return not bool(a > 1)
7596
Shen Li10224432021-08-12 11:39:31 -07007597 self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007598
7599 def test_is_isnot(self):
7600 # test the is and is not operators in python
Shen Li10224432021-08-12 11:39:31 -07007601 template = dedent('''
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007602 def func():
7603 # type: () -> bool
7604 return {lhs} {op} {rhs}
Shen Li10224432021-08-12 11:39:31 -07007605 ''')
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007606
7607 def test(op, args):
7608 code = template.format(lhs=args[0], rhs=args[1], op=op)
7609 scope = {}
Mikhail Zolotukhin159c2f32019-01-09 16:57:22 -08007610 execWrapper(code, globals(), scope)
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007611 cu = torch.jit.CompilationUnit(code)
7612 self.assertEqual(
7613 cu.func(),
Shen Li10224432021-08-12 11:39:31 -07007614 scope['func'](),
7615 msg="Failed with op: {}, lhs: {}, rhs: {}"
7616 .format(op, args[0], args[1])
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007617 )
7618
Shen Li10224432021-08-12 11:39:31 -07007619 ops = ['is', 'is not']
7620 type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5]
Wanchao Liang0fd176f2018-11-01 16:52:19 -07007621
7622 # take the product of the literals to try all type combinations
7623 for op, lhs, rhs in product(ops, type_literals, type_literals):
7624 test(op, [lhs, rhs])
7625
Zachary DeVitocf43aa32019-10-15 15:58:05 -07007626 def test_isinstance_refinement(self):
7627 @torch.jit.script
7628 def foo(a):
7629 # type: (Optional[int]) -> int
7630 if isinstance(a, int):
7631 return a + 3
7632 else:
7633 return 4
7634 self.assertEqual(foo(4), 7)
7635 self.assertEqual(foo(None), 4)
Edward Yangda2004e2020-06-04 12:53:53 -07007636
Zachary DeVitocf43aa32019-10-15 15:58:05 -07007637 @torch.jit.script
7638 def foo2(a, b):
7639 # type: (Optional[int], Optional[int]) -> int
7640 if not isinstance(a, int) or not isinstance(b, int):
7641 return 0
7642 else:
7643 return a + b
7644 self.assertEqual(foo2(3, 4), 7)
7645 self.assertEqual(foo2(None, 4), 0)
7646 self.assertEqual(foo2(4, None), 0)
7647
Zachary DeVitofb451712019-10-16 11:05:32 -07007648 @torch.jit.script
7649 def any_refinement(a, b):
7650 # type: (Any, Any) -> int
7651 if isinstance(a, int) and isinstance(b, int):
7652 return a + b
7653 return 0
7654
7655 self.assertEqual(any_refinement(3, 4), 7)
7656 self.assertEqual(any_refinement(3, "hi"), 0)
7657
Elias Ellison8cb19502020-04-09 18:23:03 -07007658 @torch.jit.script
7659 def any_refinement2(a):
7660 # type: (Any) -> Tensor
7661 if isinstance(a, Tensor):
7662 return a
7663 return torch.tensor(3)
7664
7665 self.assertEqual(any_refinement2(3), torch.tensor(3))
7666 self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5))
7667
Shen Li10224432021-08-12 11:39:31 -07007668 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor")
Elias Ellison43809342020-11-13 18:31:08 -08007669 def test_unspecialized_any_binding(self):
7670 # an Any binding will infer the type; if it infers
7671 # a specialized tensor type for `x`, the Dict isinstance check will fail
7672
7673 @torch.jit.script
7674 def foo(x: Any):
7675 assert isinstance(x, Dict[str, torch.Tensor])
7676
7677 foo({"1": torch.tensor(3)})
7678 with self.assertRaises(Exception):
7679 foo(2)
7680
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007681 def test_isinstance(self):
7682 # test isinstance operator for static type checking
Shen Li10224432021-08-12 11:39:31 -07007683 template = dedent('''
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007684 def func(x):
7685 # type: ({type_hint}) -> bool
7686 return isinstance(x, {typ})
Shen Li10224432021-08-12 11:39:31 -07007687 ''')
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007688
7689 def test(inp, typ, type_hint):
7690 code = template.format(typ=typ, type_hint=type_hint)
7691 scope = {}
Mikhail Zolotukhin159c2f32019-01-09 16:57:22 -08007692 execWrapper(code, globals(), scope)
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007693 cu = torch.jit.CompilationUnit(code)
7694 self.assertEqual(
Shen Li10224432021-08-12 11:39:31 -07007695 cu.func(inp),
7696 scope['func'](inp),
7697 msg="Failed with typ: {}"
7698 .format(typ)
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007699 )
7700
7701 inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
Shen Li10224432021-08-12 11:39:31 -07007702 type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
7703 '(list, tuple)', '(int, float, bool)']
7704 type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
7705 'List[int]', 'int']
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007706
7707 # zip inputs with type literals and annotations to try different type combinations
7708 for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
7709 test(inp, typ, type_hint)
7710
Horace He33d35f52019-05-29 17:36:04 -07007711 # test optional isinstance check
Zachary DeVitobecf0802019-10-01 16:37:34 -07007712 @torch.jit.script
7713 def opt_func(x):
7714 # type: (Optional[int]) -> bool
7715 return isinstance(x, int)
7716 self.assertTrue(opt_func(3))
7717 self.assertFalse(opt_func(None))
Wanchao Liangc5dd91c2018-12-17 15:18:51 -08007718
Michael Suo62af37a2019-05-23 18:06:19 -07007719 def test_dropout_eval(self):
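# Build the same conv/bn/relu + dropout stack as a ScriptModule and as an eager
# nn.Module; with the RNG state frozen, outputs must match in both eval mode
# (dropout is a no-op) and train mode (identical dropout masks).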
7720 class ScriptedConv2d(torch.jit.ScriptModule):
7721 def __init__(self, in_channels, out_channels, **kwargs):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007722 super().__init__()
Michael Suo62af37a2019-05-23 18:06:19 -07007723 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7724 self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7725
7726 @torch.jit.script_method
7727 def forward(self, x):
7728 x = self.conv(x)
Michael Suo62af37a2019-05-23 18:06:19 -07007729 x = self.bn(x)
7730 return F.relu(x, inplace=True)
7731
7732 class ScriptMod(torch.jit.ScriptModule):
7733 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007734 super().__init__()
Michael Suo62af37a2019-05-23 18:06:19 -07007735 self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)
7736
7737 @torch.jit.script_method
7738 def forward(self, x):
7739 x = self.Conv2d_1a_3x3(x)
Michael Suo62af37a2019-05-23 18:06:19 -07007740 return F.dropout(x, training=self.training)
7741
7742 class EagerConv2d(torch.nn.Module):
7743 def __init__(self, in_channels, out_channels, **kwargs):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007744 super().__init__()
Michael Suo62af37a2019-05-23 18:06:19 -07007745 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7746 self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7747
7748 def forward(self, x):
7749 x = self.conv(x)
Michael Suo62af37a2019-05-23 18:06:19 -07007750 x = self.bn(x)
7751 return F.relu(x, inplace=True)
7752
7753 class EagerMod(torch.nn.Module):
7754 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00007755 super().__init__()
Michael Suo62af37a2019-05-23 18:06:19 -07007756 self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)
7757
7758 def forward(self, x):
7759 x = self.Conv2d_1a_3x3(x)
Michael Suo62af37a2019-05-23 18:06:19 -07007760 return F.dropout(x, training=self.training)
7761
7762 script_input = torch.rand(4, 3, 299, 299)
7763 eager_input = script_input.clone()
7764
7765 with freeze_rng_state():
7766 script_mod = ScriptMod()
7767 script_mod.eval()
7768 script_output = script_mod(script_input)
7769
7770 with freeze_rng_state():
7771 eager_mod = EagerMod()
7772 eager_mod.eval()
7773 eager_output = eager_mod(eager_input)
7774
7775 self.assertEqual(script_output, eager_output)
7776
7777 with freeze_rng_state():
7778 script_mod = ScriptMod()
7779 script_mod.train()
7780 script_output = script_mod(script_input)
7781
7782 with freeze_rng_state():
7783 eager_mod = EagerMod()
7784 eager_mod.train()
7785 eager_output = eager_mod(eager_input)
7786
7787 self.assertEqual(script_output, eager_output)
7788
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007789 def test_nested_breaks(self):
7790 def no_bool_loop_outputs(g):
7791 # testing that the "did exit" transform values are not loop block
7792 # outputs (and thus do not affect one loop from another)
7793 loops = g.findAllNodes("prim::Loop")
7794 for loop in loops:
7795 for out in loop.outputs():
7796 self.assertTrue(out.type() != BoolType.get())
7797
7798 def test(y):
7799 # type: (int)
7800 ret = 0
7801 tensor = torch.tensor(0)
7802 while int(tensor.add_(1)) < 4:
7803 if y == 1:
7804 continue
7805 for i in range(y):
7806 continue
7807 ret += 1
7808 ret += 1
7809 return ret, int(tensor)
7810
Elias Ellisond38f9112019-12-04 12:43:38 -08007811 self.assertEqual(torch.jit.script(test)(1), test(1))
7812 self.assertEqual(torch.jit.script(test)(2), test(2))
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007813 no_bool_loop_outputs(torch.jit.script(test).graph)
7814
7815 def foo():
7816 y = torch.tensor(0)
7817 z = 0
7818 while int(y.add_(1)) < 20:
7819 if int(y) < 10:
7820 for i in range(6):
7821 if i == 3:
7822 continue
7823 else:
7824 if i > 3:
7825 break
7826 z += 2
7827 if int(y) == 18:
7828 break
7829 if int(y) == 15:
7830 continue
7831 z += 1
7832 return int(y), z
7833
7834 no_bool_loop_outputs(torch.jit.script(foo).graph)
7835 self.checkScript(foo, ())
7836
7837 def test_nested_two():
7838 i = 0
7839 k = 0
7840 while i < 5:
7841 for j in range(5):
7842 k += 1
7843 if j == 3:
7844 continue
7845 i += 1
7846 k += 1
7847 if i == 4:
7848 break
7849 return i, k
7850
7851 self.checkScript(test_nested_two, ())
7852 no_bool_loop_outputs(torch.jit.script(test_nested_two).graph)
7853
7854 def test_breaks_continues(self):
7855 def foo_continue(cond):
7856 # type: (int)
7857 j = 1
7858 for i in range(5):
7859 if i == cond:
7860 continue
7861 j += 1
7862 return j
7863
7864 def foo_break(cond):
7865 # type: (int)
7866 j = 1
7867 for i in range(5):
7868 if i == cond:
7869 break
7870 j += 1
7871 return j
7872
7873 for i in range(1, 4):
7874 self.checkScript(foo_continue, (i,))
7875 self.checkScript(foo_break, (i,))
7876
7877 def test_refine_outside_loop():
Elias Ellisond1b8da72020-11-20 11:14:59 -08007878 if 1 == 1:
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007879 x = None
7880 else:
7881 x = 1
7882 i = 0
7883 j = 0
Shen Li10224432021-08-12 11:39:31 -07007884 while (x is None or torch.jit._unwrap_optional(x) > 3):
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007885 if i < 3:
7886 if i < 3:
7887 x = torch.jit.annotate(Optional[int], None)
7888 i += 1
7889 continue
7890 x = 1
7891 else:
7892 x = 1 if x is None else x
7893 x = x + 1
7894 j = x + x
7895
7896 return x, j
7897
7898 self.checkScript(test_refine_outside_loop, ())
7899
7900 def assign_after_break(y):
7901 # type: (int)
7902 x = 0
7903 for i in range(y):
7904 x = y * 2 + i
7905 break
7906 x = 4
7907 return x
7908
7909 self.checkScript(assign_after_break, (1,))
7910 self.checkScript(assign_after_break, (2,))
7911 self.checkScript(assign_after_break, (3,))
7912
7913 def assign_after_break_nested(y):
7914 # type: (int)
7915 x = 0
7916 for i in range(y):
7917 if y == 1:
7918 x = 5
7919 break
7920 assert 1 == 2
7921 else:
7922 x = x + 1
7923 break
7924 assert 1 == 2
7925 x = -30
7926 assert 1 == 2
7927 return x
7928
7929 self.checkScript(assign_after_break_nested, (1,))
7930 self.checkScript(assign_after_break_nested, (2,))
7931 self.checkScript(assign_after_break_nested, (3,))
7932
7933 def may_break(y):
7934 # type: (int)
7935 x = 0
7936 for i in range(y):
7937 if y == 1:
7938 x = 5
7939 else:
7940 x = x + 1
7941 break
7942 x = -30
7943 return x
7944
7945 self.checkScript(may_break, (1,))
7946 self.checkScript(may_break, (2,))
7947 self.checkScript(may_break, (3,))
7948
7949 def test(x, y):
7950 # type: (int, int)
7951 a = 1
Shen Li10224432021-08-12 11:39:31 -07007952 while (x > 0):
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007953 if y == 3:
7954 for i in range(y):
Shen Li10224432021-08-12 11:39:31 -07007955 a += (1 % (i + 1))
Elias Ellisoncf2889a2019-07-12 14:58:49 -07007956 x -= 1
7957 if x == 3:
7958 a = x * 3
7959 break
7960 if x < 3:
7961 if x == 1:
7962 a -= 2
7963 x -= 1
7964 break
7965 a -= 1
7966 x -= 3
7967 return a, x
7968
7969 self.checkScript(test, (10, 3))
7970 self.checkScript(test, (10, 2))
7971 self.checkScript(test, (3, 2))
7972 self.checkScript(test, (5, 3))
7973 self.checkScript(test, (2, 3))
7974
7975 def test_delete_after_break(x):
7976 # type: (int)
7977 a = 1
7978 b = 1
7979 for i in range(x):
7980 a = i * 3
7981 break
7982 b = i * 5
7983 return a, b
7984
7985 self.checkScript(test_delete_after_break, (0,))
7986 self.checkScript(test_delete_after_break, (1,))
7987
7988 def test_will_break_after_guard(x):
7989 # type: (int)
7990 a = 1
7991 for i in range(x):
7992 if i == 4:
7993 a = 3
7994 break
7995 a -= 1
7996 break
7997 assert 1 == 2
7998 a -= -100
7999 return a
8000
8001 self.checkScript(test_will_break_after_guard, (0,))
8002 self.checkScript(test_will_break_after_guard, (2,))
8003 self.checkScript(test_will_break_after_guard, (4,))
8004
8005 def test_varexit(cond):
8006 # type: (int)
8007 m = 0
8008 for i in range(3):
8009 if cond == 2:
8010 if cond == 2:
8011 m = 2
8012 break
8013 k = 1
8014 else:
8015 k = 2
8016 m += k
8017 return m
8018
8019 # use of k tests the pathway where we have to insert uninitialized values
8020 self.checkScript(test_varexit, (3,))
8021 self.checkScript(test_varexit, (2,))
8022
8023 def test_break_true():
8024 i = 0
8025 while True:
8026 i += 1
8027 if i == 3:
8028 break
8029 while False:
8030 i += 1
8031 return i
8032
8033 self.checkScript(test_break_true, ())
8034
8035 def test_break_continue_error(self):
8036 with self.assertRaisesRegex(RuntimeError, "Syntax"):
Shen Li10224432021-08-12 11:39:31 -07008037 cu = torch.jit.CompilationUnit('''
Elias Ellisoncf2889a2019-07-12 14:58:49 -07008038 def other_func(a):
8039 break
Shen Li10224432021-08-12 11:39:31 -07008040 ''')
Elias Ellisoncf2889a2019-07-12 14:58:49 -07008041
8042 with self.assertRaisesRegex(RuntimeError, "Syntax"):
Shen Li10224432021-08-12 11:39:31 -07008043 cu = torch.jit.CompilationUnit('''
Elias Ellisoncf2889a2019-07-12 14:58:49 -07008044 def other_func(a):
8045 for i in range(5):
8046 def foo():
8047 break
Shen Li10224432021-08-12 11:39:31 -07008048 ''')
Elias Ellisoncf2889a2019-07-12 14:58:49 -07008049
Shen Li10224432021-08-12 11:39:31 -07008050 with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"):
Elias Ellison91e1f072019-11-08 15:49:21 -08008051 @torch.jit.script
8052 def foo(x):
8053 i = 0
8054 for a in (1, "2", 1.5):
8055 b = a
8056 if x:
8057 break
8058 return b
8059
Adam Paszkea58f2d22018-03-22 16:58:36 +01008060 def test_python_call(self):
8061 def pyfunc(a):
8062 return a * 3.0
8063
Shen Li10224432021-08-12 11:39:31 -07008064 cu = torch.jit.CompilationUnit('''
Adam Paszkea58f2d22018-03-22 16:58:36 +01008065 def other_func(a):
8066 return a + a
8067
8068 def test_call_python(a):
8069 b = pyfunc(a)
8070 b = other_func(b)
8071 i = 0
8072 step = 1
8073 while i < 10:
8074 b = pyfunc(b)
David Riazati6f53b4e2018-09-13 11:10:00 -07008075 if bool(b > 3.0):
Adam Paszkea58f2d22018-03-22 16:58:36 +01008076 b = pyfunc(b)
8077 i = 11
8078 return b
Shen Li10224432021-08-12 11:39:31 -07008079 ''')
Adam Paszkea58f2d22018-03-22 16:58:36 +01008080 inputs = self._make_scalar_vars([1], torch.float)
8081 outputs = self._make_scalar_vars([54], torch.float)
8082
8083 self.assertEqual(cu.test_call_python(*inputs), outputs[0])
8084
8085 def test_python_call_failure(self):
8086 with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8087 def pyfunc(a):
8088 return a * 3.0
8089
Shen Li10224432021-08-12 11:39:31 -07008090 cu = torch.jit.CompilationUnit('''
Adam Paszkea58f2d22018-03-22 16:58:36 +01008091 def other_func(a):
8092 return a + a
8093
8094 def test_call_python(a):
8095 b = pyfunc(a)
8096 b = other_func(b)
8097 i = 0
8098 step = 1
8099 while i < 10:
8100 b = pyfunc2(b)
8101 if b > 3.0:
8102 b = pyfunc(b)
8103 i = 11
8104 return b
Shen Li10224432021-08-12 11:39:31 -07008105 ''')
Adam Paszkea58f2d22018-03-22 16:58:36 +01008106 inputs = self._make_scalar_vars([1], torch.float)
8107 outputs = self._make_scalar_vars([54], torch.float)
8108
8109 self.assertEqual(cu.test_call_python(*inputs), outputs)
8110
Elias Ellisona5aeb372019-11-05 15:11:02 -08008111 def test_type_call_in_script(self):
8112 @torch.jit.script
8113 def fn(x):
8114 return type(x)
8115
Edward Yang4d725382021-04-28 09:23:07 -07008116 with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"):
Shen Li10224432021-08-12 11:39:31 -07008117 fn(torch.tensor(.5))
Elias Ellisona5aeb372019-11-05 15:11:02 -08008118
Adam Paszkea58f2d22018-03-22 16:58:36 +01008119 def test_python_call_annotation(self):
8120 def pyfunc(a):
8121 return a * 3.0
8122
8123 @torch.jit.script
8124 def foo(a):
8125 return pyfunc(a) + pyfunc(a)
8126
8127 inputs = self._make_scalar_vars([1], torch.float)
8128 outputs = self._make_scalar_vars([6], torch.float)
8129 self.assertEqual(foo(*inputs), outputs[0])
8130
8131 def test_python_call_annotation_failure(self):
8132 with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8133 def pyfunc(a):
8134 return a * 3.0
8135
8136 @torch.jit.script
8137 def foo(a):
8138 return pyfunc2(a) + pyfunc(a)
8139
8140 inputs = self._make_scalar_vars([1], torch.float)
8141 outputs = self._make_scalar_vars([6], torch.float)
8142
8143 self.assertEqual(foo(*inputs), outputs[0])
8144
8145 def test_desugar_module(self):
8146 import torch.nn.functional as F
8147
8148 def fn(x, slope):
8149 a = torch.abs(x)
8150 b = torch.nn.functional.prelu(x, slope)
8151 c = F.prelu(x, slope)
8152 return a, b, c
8153
Shen Li10224432021-08-12 11:39:31 -07008154 x = torch.arange(-3., 4)
Adam Paszkea58f2d22018-03-22 16:58:36 +01008155 slope = torch.tensor([0.5])
8156 self.checkScript(fn, [x, slope], optimize=True)
James Reed60415cf2018-03-02 15:03:44 -08008157
Gao, Xiangfe805792018-06-05 22:36:08 +08008158 def test_script_docstring(self):
8159 @torch.jit.script
8160 def with_docstring(x):
8161 """test str"""
8162 y = x
8163 """y is the same as x"""
8164 return y
Shen Li10224432021-08-12 11:39:31 -07008165 self.assertEqual(with_docstring.__doc__, 'test str')
Gao, Xiangfe805792018-06-05 22:36:08 +08008166
8167 def test_script_method_docstring(self):
8168 class A(torch.jit.ScriptModule):
8169 @torch.jit.script_method
8170 def with_docstring(self, x):
8171 """test str"""
8172 y = x
8173 """y is the same as x"""
8174 return y
8175 a = A()
Shen Li10224432021-08-12 11:39:31 -07008176 self.assertEqual(a.with_docstring.__doc__, 'test str')
Gao, Xiangfe805792018-06-05 22:36:08 +08008177
Zachary DeVito41285ed2018-03-12 06:52:40 -07008178 def test_script_module(self):
8179 class M1(torch.jit.ScriptModule):
8180 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008181 super().__init__()
Zachary DeVito41285ed2018-03-12 06:52:40 -07008182 self.weight = nn.Parameter(torch.randn(2))
8183
8184 @torch.jit.script_method
8185 def forward(self, thing):
8186 return self.weight + thing
8187
Zachary DeVitoc8d1ec02018-03-22 00:17:49 -04008188 class PModule(nn.Module):
8189 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008190 super().__init__()
Zachary DeVitoc8d1ec02018-03-22 00:17:49 -04008191 self.a = nn.Parameter(torch.randn(2, 3))
8192
8193 def forward(self, a):
8194 return self.a.mm(a)
8195
Zachary DeVito41285ed2018-03-12 06:52:40 -07008196 class M2(torch.jit.ScriptModule):
8197 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008198 super().__init__()
Zachary DeVito41285ed2018-03-12 06:52:40 -07008199 # test submodule
8200 self.sub = M1()
Zachary DeVitoc8d1ec02018-03-22 00:17:49 -04008201 self.sub2 = PModule()
Zachary DeVito41285ed2018-03-12 06:52:40 -07008202 # test parameters
8203 self.weight = nn.Parameter(torch.randn(2, 3))
8204 self.bias = nn.Parameter(torch.randn(2))
8205 # test defining a method from a string
Shen Li10224432021-08-12 11:39:31 -07008206 self.define("""
Zachary DeVito41285ed2018-03-12 06:52:40 -07008207 def hi(self, a):
8208 return self.weight.mm(a)
Shen Li10224432021-08-12 11:39:31 -07008209 """)
Zachary DeVito41285ed2018-03-12 06:52:40 -07008210 # test script methods
8211
8212 @torch.jit.script_method
8213 def doit(self, input):
8214 # test use of parameter
8215 return self.weight.mm(input)
8216
8217 @torch.jit.script_method
8218 def doit2(self, input):
8219 return self.weight.mm(input)
8220
8221 @torch.jit.script_method
8222 def forward(self, input):
8223 a = self.doit(input)
8224 b = self.doit2(input)
8225 c = self.hi(input)
Zachary DeVitoc8d1ec02018-03-22 00:17:49 -04008226 d = self.sub2(input)
8227 return a + b + self.bias + self.sub(a) + c + d
Michael Suo711be822019-07-24 23:05:48 -07008228 with torch.jit.optimized_execution(False):
8229 m2 = M2()
8230 input = torch.randn(3, 2)
8231 a = m2.weight.mm(input)
8232 b = m2.weight.mm(input)
8233 c = m2.weight.mm(input)
8234 d = m2.sub2.a.mm(input)
8235 ref = a + b + m2.bias + m2.sub.weight + a + c + d
8236 self.assertEqual(ref, m2.forward(input))
8237 m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
8238 m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
8239 m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
8240 m2.sub2.a.data.zero_()
8241 self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
Adam Paszke4afd62d2018-03-01 07:50:50 +01008242
eellisonaf933542019-04-01 11:48:39 -07008243 def test_irparser(self):
8244 graph_str = """graph(%0 : Double(5, 5)):
8245 # CHECK: aten::relu
8246 %1 : Double(5, 5) = aten::relu(%0)
8247 return (%1)
8248 """
8249 FileCheck().run(graph_str, parse_ir(graph_str))
8250
Elias Ellison43b56b32022-04-06 10:42:57 -07008251 def test_parse_tensor_constants(self):
8252 def foo():
8253 return torch.zeros([4, 4])
8254
8255 foo_s = torch.jit.script(foo)
8256 torch._C._jit_pass_constant_propagation(foo_s.graph)
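# constant propagation folds torch.zeros([4, 4]) into a tensor constant in the
# graph, so the printed IR below can only be re-parsed with parse_tensor_constants=True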
8257
8258 g = str(foo_s.graph)
8259 g_parsed = parse_ir(g, parse_tensor_constants=True)
8260 self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph)))
8261 func = torch._C._create_function_from_graph("forward", g_parsed)
8262
8263 out_parsed = func()
8264 out_func = foo()
8265 # not checking data here, just dtype, size, etc.
8266 out_parsed[:] = 0
8267 out_func[:] = 0
8268 self.assertEqual(out_func, out_parsed)
8269
8270 with self.assertRaises(RuntimeError):
8271 parse_ir(g, parse_tensor_constants=False)
8272
Elias Ellisonb72b5b22022-04-06 10:42:57 -07008273 def test_parse_nested_names(self):
8274 g_str = """
8275 graph(%x.1 : Tensor):
8276 %3 : int = prim::Constant[value=1]()
8277 %2 : int = prim::Constant[value=2]()
8278 %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3)
8279 return (%hi.submod.value.5)
8280 """
8281 g = parse_ir(g_str)
8282 round_trip_g = parse_ir(str(g))
8283 self.assertEqual(canonical(g), canonical(round_trip_g))
8284
8285 func1 = torch._C._create_function_from_graph("forward", g)
8286 func2 = torch._C._create_function_from_graph("forward", round_trip_g)
8287 self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2])))
8288
Elias Ellison2285a2f2020-07-31 15:09:46 -07008289 def test_is_after_use(self):
8290 def sorted_input_use(g):
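# collect every use of the graph's first input and order them via Use.isAfter,
# so later uses sort toward the end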
8291 uses = list(next(g.inputs()).uses())
8292 return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter))
8293
8294 @torch.jit.script
8295 def foo(x):
8296 a = x + 1
8297 return (x, x, a)
8298
8299 uses_sorted = sorted_input_use(foo.graph)
8300 # sorts last use to the end
8301 self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1]))
8302 self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8303 self.assertEqual(uses_sorted[1].offset, 0)
8304
8305 @torch.jit.script
8306 def foo(x, cond: bool):
8307 if cond:
8308 return x + 3
8309 else:
8310 return x - 3
8311
8312 uses_sorted = sorted_input_use(foo.graph)
8313 self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8314 self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8315
8316 @torch.jit.script
8317 def foo(x, cond: bool, cond2: bool):
8318 if cond:
8319 return x + 3
Shen Li10224432021-08-12 11:39:31 -07008320 elif cond2 :
Elias Ellison2285a2f2020-07-31 15:09:46 -07008321 return x - 3
8322
8323 return x / 3
8324
8325 graph1 = foo.graph
8326
8327 @torch.jit.script
8328 def foo(x, cond: bool, cond2: bool):
8329 if cond:
8330 return x + 3
8331 else:
Shen Li10224432021-08-12 11:39:31 -07008332 if cond2 :
Elias Ellison2285a2f2020-07-31 15:09:46 -07008333 return x - 3
8334 return x / 3
8335
8336 graph2 = foo.graph
8337
8338 for graph in [graph1, graph2]:
8339 uses_sorted = sorted_input_use(graph)
8340 self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8341 self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8342 self.assertTrue(uses_sorted[2].user.kind() == "aten::div")
8343
Elias Ellison35de90e2019-05-08 14:45:53 -07008344 def test_canonicalize_control_outputs(self):
8345 def test_all_outputs(g):
8346 ifs = g.findAllNodes("prim::If")
8347 loops = g.findAllNodes("prim::Loop")
8348
8349 def contained_blocks(node):
Shen Li10224432021-08-12 11:39:31 -07008350 return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop"))
Elias Ellison35de90e2019-05-08 14:45:53 -07008351 for node in ifs + loops:
8352 outs = list(node.outputs())
Alexander Grund5b0f4002020-10-19 18:40:28 -07008353 out_name = [x.debugName() for x in outs]
Elias Ellison35de90e2019-05-08 14:45:53 -07008354 if len(out_name) == 0:
8355 continue
8356 fc = FileCheck()
8357 # find the last output, then all subsequent uses
8358 fc.check(out_name[-1] + " : ")
8359 # skip past node body
8360 for i in range(contained_blocks(node)):
8361 fc.check("->")
Shen Li10224432021-08-12 11:39:31 -07008362 if (node.kind() == "prim::If"):
Elias Ellison35de90e2019-05-08 14:45:53 -07008363 fc.check("->").check("->").check("\n")
8364 else:
8365 fc.check("->").check("\n")
8366 # the canonical order is the order in which each output's first use
8367 # appears in the text
8368 for name in out_name:
8369 fc.check(name)
8370 fc.run(g)
8371
8372 @torch.jit.script
8373 def test(x):
8374 # type: (bool) -> Tuple[int, int]
8375 b = 2
8376 a = 1
8377 if x:
8378 a = 1
8379 b = 2
8380 x = False
8381 if x:
8382 b = a
8383 else:
8384 a = b
8385
8386 return a, b
8387 test_all_outputs(test.graph)
8388
8389 @torch.jit.script
8390 def test2(x):
8391 # type: (bool) -> Tuple[int, int]
8392 b = 2
8393 a = 1
8394 if x:
8395 a = 1
8396 b = 2
8397 x = False
8398 if x:
8399 print(a)
8400 else:
8401 if x:
8402 print(b)
8403
8404 return a, b
8405 test_all_outputs(test2.graph)
8406
8407 @torch.jit.script
8408 def test_loop(x, iter):
8409 # type: (bool, int) -> (None)
8410 a = 1
8411 b = 2
8412 c = 3
8413 for i in range(iter):
8414 a = 4
8415 b = 5
8416 c = 6
8417 x = True
8418 print(c)
8419 if x:
8420 print(a, b)
8421 test_all_outputs(test_loop.graph)
8422
8423 @torch.jit.script
8424 def loop_unused(iter):
8425 # type: (int) -> (None)
8426 a = 1
8427 b = 2
8428 c = 3
8429 for i in range(iter):
8430 c = c + 1
8431 b = b + 1
8432 a = a + 1
8433 print(a, b)
8434 print(c)
8435
8436 # c is used first; the remaining unused outputs are ordered alphabetically
8437 FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph)
8438
Elias Ellison89df22e2019-02-19 12:25:30 -08008439 def test_filecheck(self):
Elias Ellison89df22e2019-02-19 12:25:30 -08008440 def test_check():
8441 file = "232"
8442 FileCheck().check("2").check("3").check("2").run(file)
8443 FileCheck().check("232").run(file)
8444
8445 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8446 FileCheck().check("22").run(file)
8447 with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
8448 FileCheck().check("3").check("3").run(file)
8449
8450 test_check()
8451
8452 def test_check_count():
8453 file = "22222"
8454 FileCheck().check_count("2", 5).run(file)
8455 FileCheck().check_count("22", 2).run(file)
8456 FileCheck().check_count("222", 1).run(file)
8457
Shen Li10224432021-08-12 11:39:31 -07008458 with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
eellisond8d83712019-02-22 17:54:09 -08008459 FileCheck().check_count("2", 4, exactly=True).run(file)
8460
Elias Ellison89df22e2019-02-19 12:25:30 -08008461 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8462 FileCheck().check_count("22", 3).run(file)
8463
8464 with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
8465 FileCheck().check_count("2", 6).run(file)
8466
8467 test_check_count()
8468
8469 def test_check_same():
8470 file = "22\n33"
eellisond8d83712019-02-22 17:54:09 -08008471 FileCheck().check_same("22").run(file)
Elias Ellison89df22e2019-02-19 12:25:30 -08008472
8473 with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8474 FileCheck().check_same("33").run(file)
8475
8476 file = "22 1 3"
8477
8478 FileCheck().check("2").check_same("3").run(file)
8479 FileCheck().check_count("2", 2).check_same("3").run(file)
8480
8481 test_check_same()
8482
8483 def test_check_next():
8484 file = "\n1\n2\n3"
8485 FileCheck().check("1").check_next("2").check_next("3").run(file)
8486 FileCheck().check_next("1").check_next("2").check_next("3").run(file)
8487
8488 with self.assertRaisesRegex(RuntimeError, "Expected to find"):
8489 FileCheck().check("1").check_next("2").run("12")
8490
8491 with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8492 FileCheck().check("1").check_next("2").run("1\n\n2")
8493
8494 test_check_next()
8495
8496 def test_check_dag():
8497 fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
8498 fc.run("12")
8499 fc.run("21")
8500
8501 fc = FileCheck()
8502 fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
8503 fc.run("1 3 2")
8504 fc.run("2 3 1")
8505
8506 fc = FileCheck().check_dag("1").check_dag("2").check("3")
Shen Li10224432021-08-12 11:39:31 -07008507 with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
Elias Ellison89df22e2019-02-19 12:25:30 -08008508 fc.run("1 3 2")
8509
8510 test_check_dag()
8511
8512 def test_check_not():
8513 FileCheck().check_not("2").check("1").run("12")
8514 FileCheck().check("2").check_not("2").run("12")
8515
8516 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8517 FileCheck().check_not("2").check("1").run("21")
8518
8519 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8520 FileCheck().check("2").check_not("1").run("21")
8521
8522 # checks with distinct range matchings
8523 fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
8524 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8525 fb.run("22 2 22")
8526
8527 fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
8528 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8529 fb.run("22 1 22")
8530
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008531 def _dtype_to_jit_name(self, dtype):
Shen Li10224432021-08-12 11:39:31 -07008532 if(dtype == torch.float32):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008533 return "Float"
Shen Li10224432021-08-12 11:39:31 -07008534 if(dtype == torch.float64):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008535 return "Double"
Shen Li10224432021-08-12 11:39:31 -07008536 if(dtype == torch.int64):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008537 return "Long"
Shen Li10224432021-08-12 11:39:31 -07008538 if(dtype == torch.int32):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008539 return "Int"
Shen Li10224432021-08-12 11:39:31 -07008540 if(dtype == torch.bool):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008541 return "Bool"
Shen Li10224432021-08-12 11:39:31 -07008542 raise RuntimeError('dtype not handled')
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008543
Brian Vaughan97a604e2019-07-03 19:29:08 -07008544 def _dtype_to_expect(self, dtype, dim=0):
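# builds the textual type that shape analysis is expected to produce, e.g.
# dim=2 with torch.float32 -> "Float(*, *, device=cpu)"; a negative dim marks
# a wrapped number and yields the lowercase scalar name instead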
Shen Li10224432021-08-12 11:39:31 -07008545 param = ', '.join(['*'] * dim + ['device=cpu'])
8546 param = '(' + param + ')'
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008547 jit_type = self._dtype_to_jit_name(dtype)
8548 if dim >= 0:
8549 return jit_type + param
8550 # special case representing wrapped number
8551 else:
8552 return jit_type.lower()
8553
Shen Li10224432021-08-12 11:39:31 -07008554
Brian Vaughan97a604e2019-07-03 19:29:08 -07008555 def _test_dtype_op_shape(self, ops, args, input_dims=1):
8556 if input_dims < 1:
Ilia Cherniavskiif7a8bf22020-11-25 04:30:15 -08008557 raise RuntimeError("input dims must be at least 1")
Brian Vaughan97a604e2019-07-03 19:29:08 -07008558 dtypes = [torch.float32, torch.float64, torch.int64, torch.int32]
Shen Li10224432021-08-12 11:39:31 -07008559 str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '')
8560 tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']')
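# e.g. input_dims=2 produces the literal '[[1, 2, 3]]'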
8561 template = dedent('''
Brian Vaughan97a604e2019-07-03 19:29:08 -07008562 def func():
8563 return {return_line}
Shen Li10224432021-08-12 11:39:31 -07008564 ''')
Brian Vaughan97a604e2019-07-03 19:29:08 -07008565
8566 for op in ops:
Shen Li10224432021-08-12 11:39:31 -07008567 for dtype in (dtypes + [None]):
Brian Vaughan97a604e2019-07-03 19:29:08 -07008568 for tensor_type in dtypes:
8569 # a couple of ops aren't implemented for non-floating types
Shen Li10224432021-08-12 11:39:31 -07008570 if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
8571 if op in ['mean', 'softmax', 'log_softmax']:
Brian Vaughan97a604e2019-07-03 19:29:08 -07008572 continue
Shen Li10224432021-08-12 11:39:31 -07008573 return_line = "torch.tensor({}, dtype={}).{}({}dtype={})".format(tensor_data, tensor_type, op, str_args, dtype)
Brian Vaughan97a604e2019-07-03 19:29:08 -07008574 # uncomment for debugging a failed test:
8575 # print("testing {}".format(return_line))
8576 code = template.format(return_line=return_line)
8577 scope = {}
8578 exec(code, globals(), scope)
8579 cu = torch.jit.CompilationUnit(code)
8580 graph = cu.func.graph
8581 torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8582 input_array = [1, 2, 3]
8583 for _ in range(1, input_dims):
8584 input_array = [input_array]
8585 t = torch.tensor(input_array, dtype=tensor_type)
8586 attr = getattr(t, op)
Shen Li10224432021-08-12 11:39:31 -07008587 kwargs = {'dtype': dtype}
Brian Vaughan97a604e2019-07-03 19:29:08 -07008588 result = attr(*args, **kwargs)
8589 expect = self._dtype_to_expect(result.dtype, result.dim())
8590 FileCheck().check("aten::tensor").check(expect).run(graph)
8591
8592 def test_dtype_op_shape(self):
Shen Li10224432021-08-12 11:39:31 -07008593 ops = ['prod']
Brian Vaughan97a604e2019-07-03 19:29:08 -07008594 self._test_dtype_op_shape(ops, args=[])
8595 self._test_dtype_op_shape(ops, args=[0, False])
Brian Vaughan97a604e2019-07-03 19:29:08 -07008596 self._test_dtype_op_shape(ops, args=[0, False])
8597 self._test_dtype_op_shape(ops, args=[0, True])
8598
8599 def test_dtype_op_shape2(self):
Shen Li10224432021-08-12 11:39:31 -07008600 ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax']
Brian Vaughan97a604e2019-07-03 19:29:08 -07008601 self._test_dtype_op_shape(ops, args=[0])
8602
8603 self._test_dtype_op_shape(ops, args=[1], input_dims=4)
8604
Shen Li10224432021-08-12 11:39:31 -07008605
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008606 def _test_binary_op_shape(self, ops, input_dims=1):
8607
8608 dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool]
8609
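# build a nested tensor literal with 4 elements per extra dim,
# e.g. input_dims=1 -> '[1,1,1,1,]'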
8610 if input_dims == 0:
Shen Li10224432021-08-12 11:39:31 -07008611 shape = '1'
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008612 else:
Shen Li10224432021-08-12 11:39:31 -07008613 shape = '[' + ('1,' * 4) + ']'
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008614 for _ in range(1, input_dims):
Shen Li10224432021-08-12 11:39:31 -07008615 shape = '[' + ",".join([shape] * 4) + ']'
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008616
Shen Li10224432021-08-12 11:39:31 -07008617 template = dedent('''
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008618 def func():
8619 arg1 = {}
8620 arg2 = {}
8621 return torch.{}(arg1, arg2)
Shen Li10224432021-08-12 11:39:31 -07008622 ''')
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008623
8624 args = []
8625 for dtype in dtypes:
8626 args = args + ["torch.tensor({}, dtype={})".format(shape, dtype)]
8627 args = args + [1, 1.5]
8628
8629 def isBool(arg):
8630 return type(arg) == bool or (type(arg) == str and "torch.bool" in arg)
8631
8632 for op in ops:
8633 for first_arg in args:
8634 for second_arg in args:
8635 # sub and div are not supported for bool
Shen Li10224432021-08-12 11:39:31 -07008636 if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008637 continue
Mike Ruberry64584572020-05-19 19:25:35 -07008638 # div is not implemented correctly for mixed-type or int params
Shen Li10224432021-08-12 11:39:31 -07008639 if (op == 'div' and (type(first_arg) != type(second_arg) or
8640 isinstance(first_arg, int) or
8641 (isinstance(first_arg, str) and 'int' in first_arg))):
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008642 continue
8643 return_line = "torch.{}({}, {})".format(op, first_arg, second_arg)
8644 # uncomment for debugging a failed test:
8645 # print("testing {}".format(return_line))
8646 code = template.format(first_arg, second_arg, op)
8647 scope = {}
8648 exec(code, globals(), scope)
Shen Li10224432021-08-12 11:39:31 -07008649 non_jit_result = scope['func']()
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008650
8651 cu = torch.jit.CompilationUnit(code)
8652 graph = cu.func.graph
8653 torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8654 # use dim=-1 to represent a python/jit scalar.
Shen Li10224432021-08-12 11:39:31 -07008655 dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim()
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008656 dtype = non_jit_result.dtype
8657 # jit only supports int/float scalars.
8658 if dim < 0:
8659 if dtype == torch.int64:
8660 dtype = torch.int32
8661 if dtype == torch.float64:
8662 dtype = torch.float32
8663 expect = self._dtype_to_expect(dtype, dim)
8664 jit_output = next(graph.outputs())
8665
8666 check = FileCheck()
8667 check.check(expect).run(str(jit_output))
8668
8669 def test_binary_op_shape(self):
Shen Li10224432021-08-12 11:39:31 -07008670 self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0)
8671 self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3)
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008672
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008673 def test_no_dtype_shape(self):
Shen Li10224432021-08-12 11:39:31 -07008674
Brian Vaughan88e4cee2019-09-05 18:24:09 -07008675 @torch.jit.script
8676 def foo(x):
8677 scalar_number = x.item()
8678 return x.add(scalar_number)
8679
8680 @torch.jit.script
8681 def foo2(x):
8682 scalar_number = x.item()
8683 return torch.tensor(1).add(scalar_number)
8684
8685 t = torch.tensor(5)
8686 g = foo.graph_for(t)
8687 type = next(g.outputs())
8688 self.assertTrue(type.type() == torch._C.TensorType.get())
8689 g2 = foo2.graph_for(t)
8690 type = next(g.outputs())
8691 self.assertTrue(type.type() == torch._C.TensorType.get())
8692
Shen Li10224432021-08-12 11:39:31 -07008693
eellisonaf933542019-04-01 11:48:39 -07008694 def test_filecheck_parse(self):
8695 def test_check():
8696 file = """
8697 # CHECK: 2
8698 # CHECK: 3
8699 # CHECK: 2
8700 232
8701 """
8702 FileCheck().run(checks_file=file, test_file=file)
8703 file = """
8704 # CHECK: 232
8705 232
8706 """
8707 FileCheck().run(file, "232")
8708 with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
8709 FileCheck().run(file, "22")
8710 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8711 FileCheck().run("# CHECK: 22", "23")
8712 test_check()
8713
8714 def test_check_count():
8715 file = "22222"
8716 FileCheck().run("# CHECK-COUNT-5: 2", file)
8717 FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
8718 FileCheck().run("# CHECK-COUNT-2: 22", file)
8719 FileCheck().run("# CHECK-COUNT-1: 222", file)
8720
Shen Li10224432021-08-12 11:39:31 -07008721 with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
eellisonaf933542019-04-01 11:48:39 -07008722 FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
8723 test_check_count()
8724
8725 def test_check_same():
8726 file = "22\n33"
8727 FileCheck().run("# CHECK-SAME: 22", file)
8728
8729 with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8730 FileCheck().run("# CHECK-SAME: 33", file)
8731
8732 file = "22 1 3"
8733
8734 FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
8735 FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
8736 test_check_same()
8737
8738 def test_bad_input():
8739 with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
8740 FileCheck().run("", "1")
8741
8742 with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
8743 FileCheck().run("# CHECK1", "")
8744
8745 test_bad_input()
8746
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008747 def test_script_module_call_noscript(self):
8748 class M(torch.jit.ScriptModule):
8749 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008750 super().__init__()
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008751 self.value = 1
8752
davidriazati7a370db2019-07-16 12:50:02 -07008753 @torch.jit.ignore
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008754 def foo(self):
8755 return torch.ones(2, 2) + self.value
8756
8757 @torch.jit.script_method
8758 def forward(self, input):
8759 return input + self.foo()
8760
Michael Suo711be822019-07-24 23:05:48 -07008761 with torch.jit.optimized_execution(False):
8762 m = M()
8763 input = torch.randn(2, 2)
8764 o = m(input)
8765 self.assertEqual(o, input + torch.ones(2, 2) + 1)
8766 # check that we can change python attributes
8767 # and that those changes are picked up in script methods
8768 m.value = 2
8769 o = m(input)
8770 self.assertEqual(o, input + torch.ones(2, 2) + 2)
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008771
8772 def test_script_module_nochange_submodule(self):
8773 class M(torch.jit.ScriptModule):
8774 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008775 super().__init__()
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008776 self.sub = nn.Linear(5, 5)
8777
8778 @torch.jit.script_method
8779 def forward(self, input):
8780 return self.sub(input)
Michael Suo711be822019-07-24 23:05:48 -07008781 with torch.jit.optimized_execution(False):
8782 m = M()
8783 input = torch.randn(1, 5, 5)
8784 o = m(input)
8785 self.assertEqual(o, m.sub(input))
8786 with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"):
8787 m.sub = nn.Linear(5, 5)
Zachary DeVito0f198fa2018-03-27 23:37:56 -07008788
Elias Ellison78aebbc2020-03-05 16:01:18 -08008789 def test_module_apis(self):
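# named_modules / named_children / modules / children should all be iterable
# from TorchScript and give the same results as the eager module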
8790 class Sub(torch.nn.Module):
Elias Ellison78aebbc2020-03-05 16:01:18 -08008791 def forward(self, thing):
8792 return thing - 2
8793
8794 class Double(torch.nn.Module):
Elias Ellison78aebbc2020-03-05 16:01:18 -08008795 def forward(self, thing):
8796 return thing * 2
8797
Elias Ellison057fd5e2020-02-26 18:04:42 -08008798 class MyMod(torch.nn.Module):
8799 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008800 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07008801 self.mod = (Sub())
8802 self.mod2 = (Sub())
Elias Ellison78aebbc2020-03-05 16:01:18 -08008803 self.mod3 = nn.Sequential(nn.Sequential(Sub()))
8804 self.mod4 = nn.Sequential(Sub(), Double())
Elias Ellison057fd5e2020-02-26 18:04:42 -08008805
8806 @torch.jit.export
Elias Ellison78aebbc2020-03-05 16:01:18 -08008807 def method(self, x, x1, y, y1):
Elias Ellison057fd5e2020-02-26 18:04:42 -08008808 mod_names = ""
8809 for name, mod in self.named_modules():
8810 mod_names = mod_names + " " + name
8811 x = mod(x)
Elias Ellison78aebbc2020-03-05 16:01:18 -08008812
8813 children_names = ""
8814 for name, mod in self.named_children():
8815 children_names = children_names + " " + name
8816 x1 = mod(x1)
8817
8818 for mod in self.modules():
8819 y = mod(y)
8820
8821 for mod in self.children():
8822 y1 = mod(y1)
8823
8824 return mod_names, children_names, x, x1, y, y1
Elias Ellison057fd5e2020-02-26 18:04:42 -08008825
8826 def forward(self, x):
8827 return x + 2
8828
8829 mod = torch.jit.script(MyMod())
Elias Ellison78aebbc2020-03-05 16:01:18 -08008830 inps = tuple([torch.tensor(i) for i in range(1, 5)])
8831 self.assertEqual(mod.method(*inps), MyMod().method(*inps))
Elias Ellison057fd5e2020-02-26 18:04:42 -08008832
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008833 def test_script_module_const(self):
8834 class M(torch.jit.ScriptModule):
8835
Shen Li10224432021-08-12 11:39:31 -07008836 __constants__ = ['b', 'i', 'c', 's']
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008837
8838 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008839 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008840 self.b = False
8841 self.i = 1
8842 self.c = 3.5
Michael Suo90c65b82020-01-15 17:33:25 -08008843 self.s = ["hello"]
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008844
8845 @torch.jit.script_method
8846 def forward(self):
8847 return self.b, self.i, self.c
8848
Michael Suo711be822019-07-24 23:05:48 -07008849 with torch.jit.optimized_execution(False):
8850 m = M()
8851 o0, o1, o2 = m()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008852 self.assertEqual(o0, 0)
8853 self.assertEqual(o1, 1)
8854 self.assertEqual(o2, 3.5)
8855
Horace He52d27892019-05-28 12:21:55 -07008856 def test_script_module_fail_exist(self):
8857 class M(torch.jit.ScriptModule):
Horace He52d27892019-05-28 12:21:55 -07008858 @torch.jit.script_method
8859 def forward(self, x):
8860 return x + self.whatisgoingon
Michael Suo34126272019-10-12 09:49:56 -07008861 with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"):
Horace He52d27892019-05-28 12:21:55 -07008862 M()
8863
Shen Li10224432021-08-12 11:39:31 -07008864 @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.")
Horace He52d27892019-05-28 12:21:55 -07008865 def test_script_module_none_exist_fail(self):
8866 class M(torch.jit.ScriptModule):
8867 def __init__(self, my_optional):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008868 super().__init__()
Horace He52d27892019-05-28 12:21:55 -07008869 self.my_optional = my_optional
8870
8871 @torch.jit.script_method
8872 def forward(self, x):
8873 if self.my_optional is not None:
8874 return torch.neg(x) + self.my_optional
8875 return torch.neg(x)
Michael Suo34126272019-10-12 09:49:56 -07008876 with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"):
Horace He52d27892019-05-28 12:21:55 -07008877 x = torch.rand(3, 4)
8878 fb = M(None)
8879 fb(x)
8880
Michael Suo34126272019-10-12 09:49:56 -07008881 def test_script_module_invalid_consts(self):
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008882 class Foo(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07008883 __constants__ = ['invalid']
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008884
8885 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008886 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07008887 self.invalid = [nn.Linear(3, 4)]
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008888
Michael Suo34126272019-10-12 09:49:56 -07008889 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -07008890 TypeError,
8891 "Linear' object in attribute 'Foo.invalid' is not a valid constant"):
Michael Suo34126272019-10-12 09:49:56 -07008892 Foo()
8893
8894 class Foo2(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07008895 __constants__ = ['invalid']
Michael Suo34126272019-10-12 09:49:56 -07008896
8897 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008898 super().__init__()
Aaron Gokaslan9171f7d2023-02-10 18:02:44 +00008899 self.invalid = int
Michael Suo34126272019-10-12 09:49:56 -07008900
8901 with self.assertRaisesRegex(TypeError, "not a valid constant"):
8902 Foo2()
8903
8904 class Foo3(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07008905 __constants__ = ['invalid']
Michael Suo34126272019-10-12 09:49:56 -07008906
8907 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008908 super().__init__()
Michael Suo34126272019-10-12 09:49:56 -07008909 self.invalid = (3, 4, {})
8910
8911 with self.assertRaisesRegex(TypeError, "not a valid constant"):
8912 Foo3()
David Riazatieac3e7a2018-10-25 10:41:26 -07008913
Dmytro Dzhulgakov8e284172020-07-08 21:47:49 -07008914 class Foo4(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07008915 __constants__ = ['invalid']
Dmytro Dzhulgakov8e284172020-07-08 21:47:49 -07008916
8917 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008918 super().__init__()
Dmytro Dzhulgakov8e284172020-07-08 21:47:49 -07008919 self.invalid = np.int64(5)
8920
8921 # verify that we capture human understandable class name
8922 with self.assertRaisesRegex(TypeError, "numpy.int64"):
8923 Foo4()
8924
Wanchao Liangd6bfc532018-11-20 14:09:27 -08008925 def test_script_module_param_buffer_mutation(self):
8926 # TODO: add a param mutation test case once the JIT supports it
8927 class ModuleBufferMutate(torch.jit.ScriptModule):
Wanchao Liangd6bfc532018-11-20 14:09:27 -08008928 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008929 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07008930 self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
Wanchao Liangd6bfc532018-11-20 14:09:27 -08008931
8932 @torch.jit.script_method
8933 def forward(self):
8934 if self.training:
8935 self.running_var += 1
8936 return self.running_var
8937
Michael Suo711be822019-07-24 23:05:48 -07008938 with torch.jit.optimized_execution(False):
8939 m = ModuleBufferMutate()
8940 self.assertEqual(m(), 1)
8941 m.eval()
8942 self.assertEqual(m(), 1)
Wanchao Liangd6bfc532018-11-20 14:09:27 -08008943
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008944 def test_script_module_for(self):
8945 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07008946 __constants__ = ['b']
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008947
8948 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008949 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008950 self.b = [1, 2, 3, 4]
8951
8952 @torch.jit.script_method
8953 def forward(self):
Richard Zou67f6f932018-08-27 08:53:56 -07008954 sum = 0
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008955 for i in self.b:
8956 sum += i
8957 return sum
8958
Michael Suo711be822019-07-24 23:05:48 -07008959 with torch.jit.optimized_execution(False):
8960 m = M()
8961 self.assertEqual(m(), 10)
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008962
Elias Ellisonfbe90b62019-11-12 14:08:07 -08008963 def test_override_magic(self):
8964 class OverrideMagic(nn.Module):
Elias Ellisonfbe90b62019-11-12 14:08:07 -08008965 @torch.jit.export
8966 def __len__(self):
8967 return 10
8968
8969 mod = OverrideMagic()
8970 self.assertEqual(len(mod), len(torch.jit.script(mod)))
8971
8972 class OverrideMagicSeq(nn.Sequential):
Elias Ellisonfbe90b62019-11-12 14:08:07 -08008973 @torch.jit.export
8974 def __len__(self):
8975 return 10
8976
8977 mod = OverrideMagicSeq()
8978 self.assertEqual(len(mod), len(torch.jit.script(mod)))
8979 self.assertTrue(torch.jit.script(mod))
8980
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008981 def test_script_module_for2(self):
8982 class Sub(torch.jit.ScriptModule):
8983 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008984 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008985 self.weight = nn.Parameter(torch.randn(2))
8986
8987 @torch.jit.script_method
8988 def forward(self, thing):
8989 return self.weight + thing
8990
8991 class M(torch.jit.ScriptModule):
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008992 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00008993 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07008994 self.mods = nn.ModuleList([Sub() for i in range(10)])
8995
8996 @torch.jit.script_method
8997 def forward(self, v):
8998 for m in self.mods:
8999 v = m(v)
9000 return v
9001
Michael Suo711be822019-07-24 23:05:48 -07009002 with torch.jit.optimized_execution(False):
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07009003 i = torch.empty(2)
Michael Suo711be822019-07-24 23:05:48 -07009004 m = M()
9005 o = m(i)
9006 v = i
9007 for sub in m.mods:
9008 v = sub(v)
9009 self.assertEqual(o, v)
Elias Ellisonfbe90b62019-11-12 14:08:07 -08009010 with self.assertRaisesRegex(Exception, "object is not iterable"):
Hong Xua6a72ac2020-02-21 08:29:32 -08009011 print(list(m))
Elias Ellisonfbe90b62019-11-12 14:08:07 -08009012
James Reed896e4b62019-08-20 13:07:45 -07009013 def test_attr_qscheme_script(self):
9014 class Foo(torch.nn.Module):
9015 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009016 super().__init__()
James Reed896e4b62019-08-20 13:07:45 -07009017 self.qscheme = torch.per_tensor_affine
9018
9019 def forward(self):
9020 if self.qscheme == torch.per_tensor_symmetric:
9021 return 3
9022 else:
9023 return 4
9024
9025 f = Foo()
9026 scripted = torch.jit.script(f)
9027 self.assertEqual(f(), scripted())
9028
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009029 def test_script_module_const_submodule_fail(self):
9030 class Sub(torch.jit.ScriptModule):
9031 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009032 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009033 self.weight = nn.Parameter(torch.randn(2))
9034
9035 @torch.jit.script_method
9036 def forward(self, thing):
9037 return self.weight + thing
9038
9039 class M(torch.jit.ScriptModule):
9040 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009041 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009042 self.mods = [Sub() for _ in range(10)]
9043
9044 @torch.jit.script_method
9045 def forward(self):
9046 for _ in self.mods:
9047 print(1)
9048 return 4
9049
Michael Suo34126272019-10-12 09:49:56 -07009050 with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"):
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009051 M()
9052
James Reed0a36fe52018-12-10 15:35:11 -08009053 class DerivedStateModule(torch.jit.ScriptModule):
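# 'derived' is a buffer computed from 'param': _pack scrambles it before
# serialization and _unpack recomputes torch.neg(param) after load, so the
# derived state never needs to be stored faithfully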
9054 def __init__(self):
9055 super(TestScript.DerivedStateModule, self).__init__()
9056 self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
Shen Li10224432021-08-12 11:39:31 -07009057 self.register_buffer('derived', torch.neg(self.param).detach().clone())
James Reed0a36fe52018-12-10 15:35:11 -08009058
9059 # This is a flag so we can test that the pack method was called
Shen Li10224432021-08-12 11:39:31 -07009060 self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
James Reed0a36fe52018-12-10 15:35:11 -08009061 # This is a flag so we can test that the unpack method was called
Shen Li10224432021-08-12 11:39:31 -07009062 self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
James Reed0a36fe52018-12-10 15:35:11 -08009063
9064 @torch.jit.script_method
9065 def _pack(self):
9066 self.pack_called.set_(torch.ones(1, dtype=torch.long))
9067 self.derived.set_(torch.rand(1, dtype=torch.float).detach())
9068
9069 @torch.jit.script_method
9070 def _unpack(self):
9071 self.unpack_called.set_(torch.ones(1, dtype=torch.long))
9072 self.derived.set_(torch.neg(self.param).detach())
9073
9074 @torch.jit.script_method
9075 def forward(self, x):
9076 return x + self.derived
9077
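# DerivedStateModule above models state that can be recomputed: _pack runs
# before serialization (so the derived buffer can be shrunk to a placeholder)
# and _unpack runs afterwards and on load to rebuild it. The pack_called and
# unpack_called buffers exist only so the tests below can verify the hooks ran.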
9078 def test_pack_unpack_state(self):
9079 sm = TestScript.DerivedStateModule()
9080 x = torch.rand(3, 4, dtype=torch.float)
Philip Meier99203582021-08-19 12:45:32 -07009081 torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
James Reed0a36fe52018-12-10 15:35:11 -08009082
9083 # Test save path
9084 self.assertFalse(sm.pack_called.item())
9085 self.assertFalse(sm.unpack_called.item())
James Reed7f1397a2019-01-15 10:07:18 -08009086 imported = self.getExportImportCopyWithPacking(sm)
James Reed0a36fe52018-12-10 15:35:11 -08009087 # ensure pack was called before serialization
9088 self.assertTrue(sm.pack_called.item())
9089 # ensure unpack was called after serialization so as to leave the module in an initialized state
9090 self.assertTrue(sm.unpack_called.item())
9091
Philip Meier99203582021-08-19 12:45:32 -07009092 torch.testing.assert_close(sm.derived, torch.neg(sm.param))
James Reed0a36fe52018-12-10 15:35:11 -08009093
9094 # Test load paths
9095 self.assertTrue(imported.unpack_called.item())
Philip Meier99203582021-08-19 12:45:32 -07009096 torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
James Reed0a36fe52018-12-10 15:35:11 -08009097
Elias Ellisonfddf7322020-02-26 18:28:47 -08009098 @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
Ansley Ussery6831d8e2021-09-03 06:10:37 -07009099 @unittest.skipIf(True, "Skipping while landing PR stack")
Elias Ellisonfddf7322020-02-26 18:28:47 -08009100 def test_torch_functional(self):
moto5a27ec02020-04-24 12:12:33 -07009101 def stft(input, n_fft):
Elias Ellisonfddf7322020-02-26 18:28:47 -08009102 # type: (Tensor, int) -> Tensor
Peter Bell5c25f8f2020-12-20 14:40:16 -08009103 return torch.stft(input, n_fft, return_complex=True)
Elias Ellisonfddf7322020-02-26 18:28:47 -08009104
9105 inps = (torch.randn(10), 7)
moto5a27ec02020-04-24 12:12:33 -07009106 self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
9107
9108 def istft(input, n_fft):
9109 # type: (Tensor, int) -> Tensor
9110 return torch.istft(input, n_fft)
9111
Peter Bell5c25f8f2020-12-20 14:40:16 -08009112 inps2 = (stft(*inps), inps[1])
9113 self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
Elias Ellisonfddf7322020-02-26 18:28:47 -08009114
Elias Ellisonf31b1d32020-02-26 18:28:47 -08009115 def lu_unpack(x):
lezcanof7b9a462022-06-07 15:51:05 +00009116 A_LU, pivots = torch.linalg.lu_factor(x)
Elias Ellisonf31b1d32020-02-26 18:28:47 -08009117 return torch.lu_unpack(A_LU, pivots)
9118
9119 for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
9120 a = torch.randn(*shape)
9121 self.checkScript(lu_unpack, (a,))
9122
Elias Ellison857eb412020-02-26 18:28:47 -08009123 def cdist_fn():
9124 a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
9125 b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
9126 return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")
9127
9128 self.checkScript(cdist_fn, ())
9129
Elias Ellison479c3b02020-03-05 14:42:56 -08009130 def norm():
9131 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
Shen Li10224432021-08-12 11:39:31 -07009132 return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5)
Elias Ellison479c3b02020-03-05 14:42:56 -08009133
9134 self.checkScript(norm, ())
9135
Elias Ellisoneb3e9872020-05-12 18:29:18 -07009136 def torch_unique(dim: Optional[int]):
9137 ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long))
9138 a = torch.unique(ten, dim=dim)
9139 b = torch.unique(ten, return_counts=True, dim=dim)
9140 c = torch.unique(ten, return_inverse=True, dim=dim)
9141 d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim)
9142 return a, b, c, d
9143
9144 self.checkScript(torch_unique, (None,))
9145 self.checkScript(torch_unique, (0,))
9146
Xiang Gaoebd41252020-06-02 14:52:29 -07009147 def torch_unique_consecutive(dim: Optional[int]):
Shen Li10224432021-08-12 11:39:31 -07009148 ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long))
Xiang Gaoebd41252020-06-02 14:52:29 -07009149 a = torch.unique_consecutive(ten, dim=dim)
9150 b = torch.unique_consecutive(ten, return_counts=True, dim=dim)
9151 c = torch.unique_consecutive(ten, return_inverse=True, dim=dim)
Shen Li10224432021-08-12 11:39:31 -07009152 d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim)
Xiang Gaoebd41252020-06-02 14:52:29 -07009153 return a, b, c, d
9154
9155 self.checkScript(torch_unique_consecutive, (None,))
9156 self.checkScript(torch_unique_consecutive, (0,))
9157
Yanan Caof48a9712021-03-20 23:01:40 -07009158 def test_torch_functional_tensordot_int(self):
9159 def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int):
9160 return torch.tensordot(a, b, dims=dims)
9161
Shen Li10224432021-08-12 11:39:31 -07009162 a = torch.arange(120.).reshape(2, 3, 4, 5)
9163 b = torch.arange(840.).reshape(4, 5, 6, 7)
Yanan Caof48a9712021-03-20 23:01:40 -07009164 dims = 2
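# an integer `dims` contracts the last `dims` dimensions of `a` with the
# first `dims` dimensions of `b`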
9165 self.checkScript(tensordot_dims_int, (a, b, dims))
9166
9167 def test_torch_functional_tensordot_tensor(self):
9168 def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor):
9169 return torch.tensordot(a, b, dims=dims)
9170
Shen Li10224432021-08-12 11:39:31 -07009171 a = torch.arange(120.).reshape(2, 3, 4, 5)
9172 b = torch.arange(840.).reshape(4, 5, 6, 7)
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07009173 dims = torch.tensor([2])
Yanan Caof48a9712021-03-20 23:01:40 -07009174 self.checkScript(tensordot_dims_tensor, (a, b, dims))
9175
Shen Li10224432021-08-12 11:39:31 -07009176 a = torch.arange(60.).reshape(3, 4, 5)
9177 b = torch.arange(24.).reshape(4, 3, 2)
Yanan Caof48a9712021-03-20 23:01:40 -07009178 dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long)
9179 self.checkScript(tensordot_dims_tensor, (a, b, dims))
9180
9181 def test_torch_functional_tensordot_list(self):
Shen Li10224432021-08-12 11:39:31 -07009182 def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]):
Yanan Caof48a9712021-03-20 23:01:40 -07009183 return torch.tensordot(a, b, dims=dims)
9184
Shen Li10224432021-08-12 11:39:31 -07009185 a = torch.arange(60.).reshape(3, 4, 5)
9186 b = torch.arange(24.).reshape(4, 3, 2)
Yanan Caof48a9712021-03-20 23:01:40 -07009187 dims = [[1, 0], [0, 1]]
9188 self.checkScript(tensordot_dims_list, (a, b, dims))
9189
9190 def test_torch_functional_tensordot_tuple(self):
Shen Li10224432021-08-12 11:39:31 -07009191 def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]):
Yanan Caof48a9712021-03-20 23:01:40 -07009192 return torch.tensordot(a, b, dims=dims)
9193
Shen Li10224432021-08-12 11:39:31 -07009194 a = torch.arange(60.).reshape(3, 4, 5)
9195 b = torch.arange(24.).reshape(4, 3, 2)
Yanan Caof48a9712021-03-20 23:01:40 -07009196 dims = ([1, 0], [0, 1])
9197 self.checkScript(tensordot_dims_tuple, (a, b, dims))
9198
Michael Suo63170432020-01-28 01:23:37 -08009199 def test_missing_getstate(self):
9200 class Foo(torch.nn.Module):
9201 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009202 super().__init__()
Michael Suo63170432020-01-28 01:23:37 -08009203 self.x = 1
9204
9205 def forward(self, x):
9206 return x * self.x
9207
9208 @torch.jit.export
9209 def __setstate__(self, state):
9210 self.x = state[0]
9211 self.training = state[1]
9212
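# scripting requires a matching __getstate__ whenever __setstate__ is
# defined, so compiling this module is expected to fail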
9213 with self.assertRaisesRegex(RuntimeError, "getstate"):
9214 scripted = torch.jit.script(Foo())
9215
Elias Ellison6bc8ffe2020-04-07 09:39:56 -07009216 def test_inlining_cleanup(self):
9217 def foo(x):
9218 return F.linear(x, x)
9219
9220 @torch.jit.script
9221 def fee(x):
9222 return foo(x)
9223
9224            # inlining optimizations should have cleaned up the if statement inside F.linear
9225 self.run_pass("inline", fee.graph)
9226 FileCheck().check_not("prim::If").run(fee.graph)
9227
Yanbo Liang490c1cf2022-12-19 04:14:11 +00009228 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
James Reed0a36fe52018-12-10 15:35:11 -08009229 def test_pack_unpack_nested(self):
9230 class SubSubMod(torch.jit.ScriptModule):
9231 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009232 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009233 self.register_buffer('buf', torch.ones(3, 4) * 3)
James Reed0a36fe52018-12-10 15:35:11 -08009234
9235 @torch.jit.script_method
9236 def _pack(self):
9237 self.buf.set_(torch.zeros(1, dtype=torch.double))
9238
9239 @torch.jit.script_method
9240 def _unpack(self):
9241 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
9242
9243 @torch.jit.script_method
9244 def forward(self, x):
9245 return x + self.buf
9246
9247 class SubMod(torch.jit.ScriptModule):
9248 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009249 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009250 self.register_buffer('buf', torch.ones(3, 4) * 2)
James Reed0a36fe52018-12-10 15:35:11 -08009251 self.ssm = SubSubMod()
9252
9253 @torch.jit.script_method
9254 def _pack(self):
9255 self.buf.set_(torch.zeros(1, dtype=torch.double))
9256
9257 @torch.jit.script_method
9258 def _unpack(self):
9259 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
9260
9261 @torch.jit.script_method
9262 def forward(self, x):
9263 return self.ssm(x + self.buf)
9264
9265 class Mod(torch.jit.ScriptModule):
9266 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009267 super().__init__()
James Reed0a36fe52018-12-10 15:35:11 -08009268 self.submod = SubMod()
Shen Li10224432021-08-12 11:39:31 -07009269 self.register_buffer('buf', torch.ones(3, 4) * 1)
James Reed0a36fe52018-12-10 15:35:11 -08009270
9271 @torch.jit.script_method
9272 def _pack(self):
9273 self.buf.set_(torch.zeros(1, dtype=torch.double))
9274
9275 @torch.jit.script_method
9276 def _unpack(self):
9277 self.buf.set_(torch.ones(3, 4, dtype=torch.double))
9278
9279 @torch.jit.script_method
9280 def forward(self, x):
9281 return self.submod(x + self.buf)
9282
9283 m = Mod()
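# the three buffers hold 1, 2 and 3 respectively and are added along the
# nested forward calls, so the expected output is x + 6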
Philip Meier99203582021-08-19 12:45:32 -07009284 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
James Reed0a36fe52018-12-10 15:35:11 -08009285 m.apply(lambda s: s._pack())
Philip Meier99203582021-08-19 12:45:32 -07009286 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4))
James Reed0a36fe52018-12-10 15:35:11 -08009287 m.apply(lambda s: s._unpack())
Philip Meier99203582021-08-19 12:45:32 -07009288 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
James Reed0a36fe52018-12-10 15:35:11 -08009289
nikithamalgid819a212021-02-21 15:46:57 -08009290 def test_torch_any(self):
9291 def fn(x):
9292 return torch.any(x)
9293
9294 def fn1(x, dim: int):
9295 return torch.any(x, dim)
9296
Shen Li10224432021-08-12 11:39:31 -07009297 self.checkScript(fn, (torch.randn(3, 4), ))
9298 self.checkScript(fn, (torch.empty(3), ))
9299 self.checkScript(fn, (torch.empty(1), ))
nikithamalgid819a212021-02-21 15:46:57 -08009300 self.checkScript(fn, (torch.ones(3, 4),))
9301 self.checkScript(fn, (torch.zeros(5, 7, 1),))
9302 self.checkScript(fn1, (torch.empty(3, 4), -2))
9303 self.checkScript(fn1, (torch.randn(3, 8), 1))
9304 self.checkScript(fn1, (torch.zeros(3, 6, 9), -3))
9305 self.checkScript(fn1, (torch.empty(5), 0))
9306
9307 def test_any(self):
9308 def fn(x: List[int]):
9309 return any(x)
9310
9311 def fn1(x: List[float]):
9312 return any(x)
9313
9314 def fn2(x: List[bool]):
9315 return any(x)
9316
9317 def fn3(x: List[str]):
9318 return any(x)
9319
Shen Li10224432021-08-12 11:39:31 -07009320 self.checkScript(fn, ([0, 0, 0, 0], ))
9321 self.checkScript(fn, ([0, 3, 0], ))
9322 self.checkScript(fn, ([], ))
9323 self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
9324 self.checkScript(fn1, ([0.0, 0.0, 0.0], ))
9325 self.checkScript(fn1, ([0, 0, 0], ))
9326 self.checkScript(fn1, ([], ))
9327 self.checkScript(fn2, ([True, False, False], ))
9328 self.checkScript(fn2, ([False, False, False], ))
9329 self.checkScript(fn2, ([True, True, True, True], ))
9330 self.checkScript(fn2, ([], ))
9331 self.checkScript(fn3, (["", "", ""], ))
9332 self.checkScript(fn3, (["", "", "", "-1"], ))
9333 self.checkScript(fn3, ([], ))
nikithamalgid819a212021-02-21 15:46:57 -08009334
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009335 def test_script_module_not_tuple(self):
9336 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07009337 __constants__ = ['mods']
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009338
9339 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009340 super().__init__()
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009341 self.mods = 1
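# `mods` is a constant holding an int, so the loop in forward() below is
# rejected because an int is not iterable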
9342
9343 @torch.jit.script_method
9344 def forward(self, v):
9345 for m in self.mods:
9346 print(m)
9347 return v
Wanchao Liange0f5ab22019-06-22 00:57:24 -07009348 with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
Zachary DeVito5ab30ee2018-04-05 11:31:43 -07009349 M()
9350
David Riazati10c4b982019-07-03 17:22:22 -07009351 def test_attr_module_constants(self):
Elias Ellison7c2290e2019-03-22 20:13:02 -07009352 class M2(torch.jit.ScriptModule):
9353 def __init__(self, mod_list):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009354 super().__init__()
Elias Ellison7c2290e2019-03-22 20:13:02 -07009355 self.mods = mod_list
9356
9357 @torch.jit.script_method
David Riazati10c4b982019-07-03 17:22:22 -07009358 def forward(self, x):
Elias Ellison7c2290e2019-03-22 20:13:02 -07009359 return self.mods.forward(x)
9360
Michael Suo711be822019-07-24 23:05:48 -07009361 with torch.jit.optimized_execution(False):
9362 m = M2(nn.Sequential(nn.ReLU()))
9363 self.assertExportImportModule(m, (torch.randn(2, 2),))
Elias Ellison7c2290e2019-03-22 20:13:02 -07009364
Chunli14f8cd72018-05-25 13:38:24 -07009365 def test_script_sequential_for(self):
9366 class Sub(torch.jit.ScriptModule):
9367 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009368 super().__init__()
Chunli14f8cd72018-05-25 13:38:24 -07009369 self.weight = nn.Parameter(torch.randn(2))
9370
9371 @torch.jit.script_method
9372 def forward(self, thing):
9373 return self.weight + thing
9374
9375 class M(torch.jit.ScriptModule):
Chunli14f8cd72018-05-25 13:38:24 -07009376 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009377 super().__init__()
Chunli14f8cd72018-05-25 13:38:24 -07009378 self.mods = nn.Sequential(Sub(), Sub(), Sub())
9379
9380 @torch.jit.script_method
9381 def forward(self, v):
9382 for m in self.mods:
9383 v = m(v)
9384 return v
9385
9386 @torch.jit.script_method
9387 def forward2(self, v):
9388 return self.mods(v)
9389
Michael Suo711be822019-07-24 23:05:48 -07009390 with torch.jit.optimized_execution(False):
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07009391 i = torch.empty(2)
Michael Suo711be822019-07-24 23:05:48 -07009392 m = M()
9393 o = m(i)
9394 v = i
Michael Suo34126272019-10-12 09:49:56 -07009395 for sub in m.mods._modules.values():
Michael Suo711be822019-07-24 23:05:48 -07009396 v = sub(v)
9397 self.assertEqual(o, v)
Chunli14f8cd72018-05-25 13:38:24 -07009398
Michael Suo711be822019-07-24 23:05:48 -07009399 o2 = m.forward2(i)
9400 self.assertEqual(o2, v)
Chunli14f8cd72018-05-25 13:38:24 -07009401
Will Constabled8555282020-06-24 09:00:42 -07009402 def test_script_sequential_sliced_iteration(self):
9403 class seq_mod(nn.Module):
9404 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009405 super().__init__()
Will Constabled8555282020-06-24 09:00:42 -07009406 self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()]
9407 self.layers = nn.Sequential(*self.layers)
9408
9409 def forward(self, input):
9410 x = self.layers[0].forward(input)
9411 for layer in self.layers[1:3]:
9412 x = layer.forward(x)
9413 for layer in self.layers[2:]:
9414 x = layer.forward(x)
9415 return x
9416
9417 seq = seq_mod()
9418 self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])])
9419
Wanchao Liangc7e07222019-05-24 20:25:49 -07009420 def test_script_sequential_orderdict(self):
Wanchao Liangc7e07222019-05-24 20:25:49 -07009421 class M(torch.jit.ScriptModule):
Wanchao Liangc7e07222019-05-24 20:25:49 -07009422 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009423 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009424 self.mods = nn.Sequential(OrderedDict([
9425 ("conv", nn.Conv2d(1, 20, 5)),
9426 ("relu", nn.ReLU())
9427 ]))
Wanchao Liangc7e07222019-05-24 20:25:49 -07009428
9429 @torch.jit.script_method
9430 def forward(self, input):
9431 return self.mods(input)
9432
9433 m = M()
Shen Li10224432021-08-12 11:39:31 -07009434 self.assertTrue('mods.conv.weight' in m.state_dict().keys())
Wanchao Liangc7e07222019-05-24 20:25:49 -07009435
Chunli14f8cd72018-05-25 13:38:24 -07009436 def test_script_sequential_multi_output_fail(self):
9437 class Sub(torch.jit.ScriptModule):
9438 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009439 super().__init__()
Chunli14f8cd72018-05-25 13:38:24 -07009440 self.weight = nn.Parameter(torch.randn(2))
9441
9442 @torch.jit.script_method
9443 def forward(self, thing):
9444 return self.weight + thing
9445
9446 class ReturnMulti(torch.jit.ScriptModule):
Chunli14f8cd72018-05-25 13:38:24 -07009447 @torch.jit.script_method
9448 def forward(self, x):
9449 return x, x, x
9450
9451 class HaveSequential(torch.jit.ScriptModule):
Chunli14f8cd72018-05-25 13:38:24 -07009452 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009453 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009454 self.someseq = nn.Sequential(
9455 Sub(),
9456 ReturnMulti(),
9457 Sub()
9458 )
Chunli14f8cd72018-05-25 13:38:24 -07009459
9460 @torch.jit.script_method
9461 def forward(self, x):
9462 return self.someseq(x)
9463
9464 with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
Michael Suo711be822019-07-24 23:05:48 -07009465 with torch.jit.optimized_execution(False):
9466 hs = HaveSequential()
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -07009467 i = torch.empty(2)
Michael Suo711be822019-07-24 23:05:48 -07009468 hs(i)
Chunli14f8cd72018-05-25 13:38:24 -07009469
Michael Suo755f91b2019-08-19 18:41:08 -07009470 @_tmp_donotuse_dont_inline_everything
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009471 def test_script_sequential_in_mod_list(self):
9472 class Sub(torch.jit.ScriptModule):
9473 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009474 super().__init__()
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009475 self.weight = nn.Parameter(torch.randn(2))
9476
9477 @torch.jit.script_method
9478 def forward(self, thing):
9479 return self.weight + thing
9480
9481 class M(torch.jit.ScriptModule):
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009482 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009483 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009484 self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009485
9486 @torch.jit.script_method
9487 def forward(self, v):
9488 for mod in self.mods:
9489 v = mod(v)
9490 return v
9491
9492 m = M()
9493 graph = str(m.graph)
Michael Suo755f91b2019-08-19 18:41:08 -07009494 self.assertTrue(graph.count("prim::CallMethod") == 2)
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009495 self.assertTrue("python" not in graph)
9496
Michael Suo755f91b2019-08-19 18:41:08 -07009497 @_tmp_donotuse_dont_inline_everything
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009498 def test_script_nested_mod_list(self):
9499 class Sub(torch.jit.ScriptModule):
9500 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009501 super().__init__()
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009502 self.weight = nn.Parameter(torch.randn(2))
9503
9504 @torch.jit.script_method
9505 def forward(self, thing):
9506 return self.weight + thing
9507
9508 class M(torch.jit.ScriptModule):
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009509 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009510 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009511 self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009512
9513 @torch.jit.script_method
9514 def forward(self, v):
9515 for mod in self.mods:
9516 for m in mod:
9517 v = m(v)
9518 return v
9519
9520 m = M()
9521 graph = str(m.graph)
Michael Suo755f91b2019-08-19 18:41:08 -07009522 self.assertTrue(graph.count("prim::CallMethod") == 4)
Elias Ellisoncd2dca32019-02-08 11:34:40 -08009523 self.assertTrue("python" not in graph)
9524
Zachary DeVito733e2962018-04-27 17:44:17 -07009525 def test_constant_as_attr(self):
9526 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -07009527 __constants__ = ['dim']
Zachary DeVito733e2962018-04-27 17:44:17 -07009528
9529 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009530 super().__init__()
Zachary DeVito733e2962018-04-27 17:44:17 -07009531 self.dim = 1
9532
9533 @torch.jit.script_method
9534 def forward(self, v):
9535 return torch.cat([v, v, v], dim=self.dim)
9536 v = torch.zeros(1, 1)
Michael Suo711be822019-07-24 23:05:48 -07009537 with torch.jit.optimized_execution(False):
9538 self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
Zachary DeVito733e2962018-04-27 17:44:17 -07009539
Sam Estepe3900d22021-04-19 13:14:27 -07009540 class StarTestSumStarred(torch.nn.Module):
James Reed15331552018-04-09 19:34:51 -07009541 def __init__(self):
9542 super(TestScript.StarTestSumStarred, self).__init__()
9543
9544 def forward(self, *inputs):
9545 output = inputs[0]
9546 for i in range(1, len(inputs)):
9547 output += inputs[i]
9548 return output
9549
Sam Estepe3900d22021-04-19 13:14:27 -07009550 class StarTestReturnThree(torch.nn.Module):
James Reed15331552018-04-09 19:34:51 -07009551 def __init__(self):
9552 super(TestScript.StarTestReturnThree, self).__init__()
9553
9554 def forward(self, rep):
9555 return rep, rep, rep
9556
9557 def test_script_star_expr(self):
Shen Li10224432021-08-12 11:39:31 -07009558
James Reed15331552018-04-09 19:34:51 -07009559 class M2(torch.jit.ScriptModule):
9560 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009561 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009562 self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
9563 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
9564 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
James Reed15331552018-04-09 19:34:51 -07009565
9566 @torch.jit.script_method
9567 def forward(self, rep):
9568 tup = self.g(rep)
9569 return self.m(*tup)
9570
9571 m = M2()
9572 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9573
9574 def test_script_star_expr_string(self):
9575 class M2(torch.jit.ScriptModule):
9576 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009577 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009578 self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
9579 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
9580 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
James Reed15331552018-04-09 19:34:51 -07009581
Shen Li10224432021-08-12 11:39:31 -07009582 self.define('''
James Reed15331552018-04-09 19:34:51 -07009583 def forward(self, rep):
9584 tup = self.g(rep)
9585 return self.m(*tup)
Shen Li10224432021-08-12 11:39:31 -07009586 ''')
James Reed15331552018-04-09 19:34:51 -07009587
9588 m = M2()
9589 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9590
Sam Estepe3900d22021-04-19 13:14:27 -07009591 class StarTestSumAndReturnThree(torch.nn.Module):
James Reed15331552018-04-09 19:34:51 -07009592 def __init__(self):
9593 super(TestScript.StarTestSumAndReturnThree, self).__init__()
9594
9595 def forward(self, *inputs):
9596 output = inputs[0]
9597 for i in range(1, len(inputs)):
9598 output += inputs[i]
9599 return output, output, output
9600
9601 def test_script_star_assign(self):
9602 class M2(torch.jit.ScriptModule):
9603 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009604 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07009605 self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
9606 self.define('''
James Reed15331552018-04-09 19:34:51 -07009607 def forward(self, rep):
9608 head, *tail = self.g(rep)
9609 return head
Shen Li10224432021-08-12 11:39:31 -07009610 ''')
James Reed15331552018-04-09 19:34:51 -07009611
9612 m = M2()
9613 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9614
9615 def test_script_module_star_assign2(self):
9616 class M2(torch.jit.ScriptModule):
9617 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009618 super().__init__()
James Reed15331552018-04-09 19:34:51 -07009619 self.g = torch.jit.trace(
Zachary DeVito93bd2912018-08-30 13:51:45 -07009620 TestScript.StarTestSumAndReturnThree(),
Michael Suo3fca4bd2018-11-27 12:38:28 -08009621 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
Shen Li10224432021-08-12 11:39:31 -07009622 _force_outplace=True)
9623 self.define('''
James Reed15331552018-04-09 19:34:51 -07009624 def forward(self, rep):
9625 *head, tail = self.g(rep, rep, rep)
9626 return tail
Shen Li10224432021-08-12 11:39:31 -07009627 ''')
James Reed15331552018-04-09 19:34:51 -07009628
9629 m = M2()
9630 self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
9631
Michael Suo3fca4bd2018-11-27 12:38:28 -08009632 def test_script_module_star_assign2_inplace(self):
9633 class M2(torch.jit.ScriptModule):
9634 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009635 super().__init__()
Michael Suo3fca4bd2018-11-27 12:38:28 -08009636 self.g = torch.jit.trace(
9637 TestScript.StarTestSumAndReturnThree(),
9638 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
Shen Li10224432021-08-12 11:39:31 -07009639 _force_outplace=False)
9640 self.define('''
Michael Suo3fca4bd2018-11-27 12:38:28 -08009641 def forward(self, rep):
9642 *head, tail = self.g(rep, rep, rep)
9643 return tail
Shen Li10224432021-08-12 11:39:31 -07009644 ''')
Michael Suo3fca4bd2018-11-27 12:38:28 -08009645
9646 m = M2()
9647            # since forward() makes three aliases to the input `rep` before passing
9648            # them to StarTestSumAndReturnThree(), the in-place behavior differs
9649            # from the out-of-place case above.
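# each `output += inputs[i]` in the traced module mutates the shared tensor,
# so starting from ones the value doubles twice: 1 -> 2 -> 4, which is why
# the assertion below expects 4 * torch.ones(4, 3)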
9650 self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
9651
James Reed15331552018-04-09 19:34:51 -07009652 def test_script_module_star_assign_fail_pythonop(self):
9653
Zachary DeVitoce69d312018-05-14 14:46:36 -07009654 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
James Reed15331552018-04-09 19:34:51 -07009655 class M2(torch.jit.ScriptModule):
9656 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009657 super().__init__()
James Reed15331552018-04-09 19:34:51 -07009658
davidriazati7a370db2019-07-16 12:50:02 -07009659 @torch.jit.ignore
James Reed15331552018-04-09 19:34:51 -07009660 def myfunc():
9661 return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
9662
Shen Li10224432021-08-12 11:39:31 -07009663 self.define('''
James Reed15331552018-04-09 19:34:51 -07009664 def forward(self, rep):
9665 a, *b = myfunc()
9666 return a
Shen Li10224432021-08-12 11:39:31 -07009667 ''')
James Reed15331552018-04-09 19:34:51 -07009668
9669 m = M2()
9670 m(torch.zeros(4, 3))
9671
9672 def test_script_module_star_assign_fail_builtin(self):
Zachary DeVitoce69d312018-05-14 14:46:36 -07009673 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
James Reed15331552018-04-09 19:34:51 -07009674 class M2(torch.jit.ScriptModule):
9675 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009676 super().__init__()
James Reed15331552018-04-09 19:34:51 -07009677
Shen Li10224432021-08-12 11:39:31 -07009678 self.define('''
James Reed15331552018-04-09 19:34:51 -07009679 def forward(self, rep):
9680 a, *b = torch.neg(rep)
9681 return a
Shen Li10224432021-08-12 11:39:31 -07009682 ''')
James Reed15331552018-04-09 19:34:51 -07009683
9684 m = M2()
9685 m(torch.zeros(4, 3))
9686
Wanchao Liangf4eb93f2019-07-29 15:31:03 -07009687 def test_script_pack_padded_sequence(self):
9688 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
9689
9690 def pack_padded_pad_packed_script(x, seq_lens):
9691 x = pack_padded_sequence(x, seq_lens)
9692 x, lengths = pad_packed_sequence(x)
9693 return x, lengths
9694
9695 T, B, C = 3, 5, 7
9696 x = torch.ones((T, B, C))
9697 seq_lens = torch.tensor([3, 3, 2, 2, 1])
9698 # set padding value so we can test equivalence
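# (pad_packed_sequence re-pads with zeros by default, so the padded
# positions of the eager input are zeroed to make the outputs comparable)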
9699 for b in range(B):
9700 if seq_lens[b] < T:
Shen Li10224432021-08-12 11:39:31 -07009701 x[seq_lens[b]:, b, :] = 0
Wanchao Liangf4eb93f2019-07-29 15:31:03 -07009702
9703 eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
Michael Suoca1b8eb2020-07-13 16:57:41 -07009704 with torch._jit_internal._disable_emit_hooks():
davidriazati11843042020-03-02 13:48:06 -08009705 scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
Wanchao Liangf4eb93f2019-07-29 15:31:03 -07009706 script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
9707 self.assertEqual(eager_seq, script_seq)
9708 self.assertEqual(eager_lengths, script_lengths)
9709
Wanchao Liang4b028a82020-06-19 19:01:26 -07009710 class ExperimentalLSTM(torch.nn.Module):
9711 def __init__(self, input_dim, hidden_dim):
9712 super().__init__()
9713
9714 def forward(self, input):
9715 # type: (Tensor)
9716 packed = pack_padded_sequence(
9717 input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
9718 )
Shen Li10224432021-08-12 11:39:31 -07009719 output, lengths = pad_packed_sequence(
9720 sequence=packed, total_length=2
9721 )
Wanchao Liang4b028a82020-06-19 19:01:26 -07009722 # lengths is flipped, so is output
9723 return output[0]
9724
9725 lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
9726
Michael Suoca1b8eb2020-07-13 16:57:41 -07009727 with torch._jit_internal._disable_emit_hooks():
Wanchao Liang4b028a82020-06-19 19:01:26 -07009728 self.checkModule(lstm, [torch.ones(2, 2)])
9729
9730 def test_script_pad_sequence_pack_sequence(self):
9731 from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
9732
9733 def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0):
9734 # type: (List[Tensor], bool, float) -> Tensor
9735 return pad_sequence(tensor_list, batch_first, padding_value)
9736
9737 def pack_sequence_func(tensor_list, enforce_sorted=True):
9738 # type: (List[Tensor], bool) -> Tensor
9739 return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]
9740
9741 ones3 = torch.ones(3, 5)
9742 ones4 = torch.ones(4, 5)
9743 ones5 = torch.ones(5, 5)
9744 tensor1 = torch.tensor([1, 2, 3])
9745 tensor2 = torch.tensor([4, 5])
9746 tensor3 = torch.tensor([6])
Michael Suoca1b8eb2020-07-13 16:57:41 -07009747 with torch._jit_internal._disable_emit_hooks():
Shen Li10224432021-08-12 11:39:31 -07009748 self.checkScript(pad_sequence_func,
9749 ([ones3, ones4, ones5],))
9750 self.checkScript(pad_sequence_func,
9751 ([ones3, ones4, ones5], True))
9752 self.checkScript(pad_sequence_func,
9753 ([ones3, ones4, ones5], True, 2.5))
9754 self.checkScript(pack_sequence_func,
9755 ([tensor1, tensor2, tensor3],))
9756 self.checkScript(pack_sequence_func,
9757 ([tensor1, tensor2, tensor3], False))
Wanchao Liang4b028a82020-06-19 19:01:26 -07009758
Wanchao Liangc384fbf2019-07-29 15:01:56 -07009759 def test_script_get_tracing_state(self):
9760 def test_if_tracing(x):
9761 if torch._C._get_tracing_state():
9762 return x + 1
9763 else:
9764 return x - 1
9765
9766 inp = torch.randn(3, 3)
Wanchao Liangc384fbf2019-07-29 15:01:56 -07009767 self.checkScript(test_if_tracing, (inp,))
9768
Yuxin Wua62b0de2021-02-20 02:51:18 -08009769 def test_script_is_tracing(self):
9770 def test_is_tracing(x):
9771 if torch.jit.is_tracing():
9772 return x + 1
9773 else:
9774 return x - 1
9775
9776 inp = torch.randn(3, 3)
9777 self.checkScript(test_is_tracing, (inp,))
9778
Elias Ellison18974402019-09-10 17:26:38 -07009779 def test_is_scripting(self):
9780 def foo():
9781 return torch.jit.is_scripting()
9782
9783 self.assertFalse(foo())
9784 scripted = torch.jit.script(foo)
Elias Ellison18974402019-09-10 17:26:38 -07009785 self.assertTrue(scripted())
9786
Tugsbayasgalan (Tugsuu) Manlaibaatar70b18b92022-01-06 19:13:00 -08009787 def test_comment_ignore_indent(self):
9788 class Model(torch.nn.Module):
9789 def __init__(self):
9790 # useless comment that is not indented correctly # noqa: E115
9791 super().__init__()
9792
9793 def forward(self):
9794 return 5
9795
9796 # should compile without an error
9797 self.checkModule(Model(), ())
9798
Zachary DeVito8995ddd2018-04-12 10:32:49 -07009799 def test_script_outputs(self):
Zachary DeVitoce69d312018-05-14 14:46:36 -07009800 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
Zachary DeVito8995ddd2018-04-12 10:32:49 -07009801 @torch.jit.script
9802 def foo(a):
9803 c, d = a + a
9804 return c + d
9805
9806 @torch.jit.script
9807 def return3():
Richard Zou67f6f932018-08-27 08:53:56 -07009808 return 1, 2, 3
Zachary DeVito8995ddd2018-04-12 10:32:49 -07009809
9810 with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
9811 @torch.jit.script
9812 def bind2():
9813 a, b = return3()
9814 print(a)
9815 print(b)
9816
Richard Zouefab8e82018-10-25 19:55:56 -07009817 @unittest.skipIf(not RUN_CUDA, "requires CUDA")
9818 def test_script_get_device_cuda(self):
9819 @torch.jit.script
9820 def foo(a):
9821 return a.get_device()
9822
Shen Li10224432021-08-12 11:39:31 -07009823 v = torch.randn(1, device='cuda')
Richard Zouefab8e82018-10-25 19:55:56 -07009824 self.assertEqual(foo(v), 0)
9825
Zachary DeVito8995ddd2018-04-12 10:32:49 -07009826 def test_script_chunk(self):
9827 @torch.jit.script
9828 def foo(a):
9829 b, c = torch.chunk(a, dim=0, chunks=2)
9830 return b
9831 v = torch.rand(10, 3)
9832 self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
9833
Jerry Zhangde9a5442019-11-07 15:23:37 -08009834 def test_script_copy(self):
9835 class M(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -07009836 __annotations__ = {
9837 "val": Optional[torch.Tensor]
9838 }
Jerry Zhangde9a5442019-11-07 15:23:37 -08009839
9840 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009841 super().__init__()
Jerry Zhangde9a5442019-11-07 15:23:37 -08009842 self.val = None
9843
9844 def some_method(self):
9845 return 3
9846
9847 def forward(self, x):
9848 # type: (Tensor) -> Tensor
9849 self.val = x + self.some_method()
9850 return x
9851
9852 m = torch.jit.script(M())
9853 # test copy
Jerry Zhangf652abc2020-06-23 09:15:35 -07009854 copy.copy(m)
9855 copy.deepcopy(m)
Jerry Zhangde9a5442019-11-07 15:23:37 -08009856
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009857 def test_script_forward_method_replacement(self):
Shihao Xu00651b82020-06-15 19:05:42 -07009858 # We want to support the use case of attaching a different `forward` method
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009859 class LowLevelModule(torch.nn.Module):
Shihao Xu00651b82020-06-15 19:05:42 -07009860 def forward(self, input: torch.Tensor):
9861 # Generic forward dispatch
9862 return self.forward_pytorch(input) * 2
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009863
9864 class TestModule(LowLevelModule):
9865 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +00009866 super().__init__()
Shihao Xu00651b82020-06-15 19:05:42 -07009867 # Replace the forward method
9868 self.forward = types.MethodType(LowLevelModule.forward, self)
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009869
9870 def forward_pytorch(self, input: torch.Tensor):
9871 return torch.tensor(123)
9872
9873 def forward(self, input: torch.Tensor):
Shihao Xu00651b82020-06-15 19:05:42 -07009874 # Should not use this forward method
9875 raise AssertionError("This method should not be used")
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009876 return self.forward_pytorch(input)
9877
9878 m = TestModule()
9879 self.assertEqual(m(torch.tensor(1)), torch.tensor(246))
9880
9881 m_scripted = torch.jit.script(m)
Shihao Xu00651b82020-06-15 19:05:42 -07009882 self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246))
Will Feng (FAIAR)38d141e2020-05-14 12:10:59 -07009883
Zachary DeVito289a8c92018-09-11 15:01:48 -07009884 def test_python_call_non_tensor(self):
9885 def foo(a, b, c):
9886 # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
9887 d, e = c
9888 return b + e, a + d
9889
9890 @torch.jit.script
9891 def bar():
9892 x = torch.ones(3, 4)
9893 a, b = foo(x, 3, (x, 3))
9894 return a, b
9895
9896 self.assertEqual((6, torch.ones(3, 4) + 1), bar())
9897
9898 def test_python_call_non_tensor_wrong(self):
Shen Li10224432021-08-12 11:39:31 -07009899 with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
davidriazati7a370db2019-07-16 12:50:02 -07009900 @torch.jit.ignore
Zachary DeVito289a8c92018-09-11 15:01:48 -07009901 def foo():
9902 # type: () -> Tensor
Elias Ellison561037a2019-03-07 09:12:35 -08009903 return ((3, 4),) # noqa: T484
Zachary DeVito289a8c92018-09-11 15:01:48 -07009904
9905 @torch.jit.script
9906 def bar():
9907 return foo()
9908
9909 bar()
9910
Elias Ellison5c0eece2018-08-22 19:41:56 -07009911 def test_if_different_type(self):
Ansley Ussery6831d8e2021-09-03 06:10:37 -07009912 with self.assertRaisesRegex(RuntimeError, "c0 is set to type "
9913 "int in the true branch and type "
9914 "float in the false branch"):
Elias Ellison5c0eece2018-08-22 19:41:56 -07009915 @torch.jit.script
9916 def diff_type_used():
Elias Ellisond1b8da72020-11-20 11:14:59 -08009917 if 1 == 2:
Elias Ellison5c0eece2018-08-22 19:41:56 -07009918 c0 = 1
9919 else:
9920 c0 = 1.0
9921 return c0
9922
Ansley Ussery6831d8e2021-09-03 06:10:37 -07009923 with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"):
Elias Ellison5c0eece2018-08-22 19:41:56 -07009924 @torch.jit.script
9925 def diff_existing_type(x):
9926 c0 = 1.0
Elias Ellisond1b8da72020-11-20 11:14:59 -08009927 if 1 == 2:
Elias Ellison5c0eece2018-08-22 19:41:56 -07009928 c0 = 1
9929 print(x)
9930 return x
9931
9932 @torch.jit.script
9933 def diff_type_unused():
Elias Ellisond1b8da72020-11-20 11:14:59 -08009934 if 1 == 1:
Elias Ellison5c0eece2018-08-22 19:41:56 -07009935 c0 = 1
9936 print(c0)
9937 else:
9938 c0 = 1.0
9939 print(c0)
9940 return 1
9941
Elias Ellison0aeb9712019-05-31 14:26:46 -07009942 def test_if_not_defined_error(self):
Shen Li10224432021-08-12 11:39:31 -07009943 with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"):
Elias Ellison0aeb9712019-05-31 14:26:46 -07009944 @torch.jit.script
9945 def test():
Elias Ellisond1b8da72020-11-20 11:14:59 -08009946 if 1 == 1:
Elias Ellison0aeb9712019-05-31 14:26:46 -07009947 c0 = 1
9948 return c0
Shen Li10224432021-08-12 11:39:31 -07009949 with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"):
Elias Ellison0aeb9712019-05-31 14:26:46 -07009950 @torch.jit.script
9951 def test2():
Elias Ellisond1b8da72020-11-20 11:14:59 -08009952 if 1 == 1:
Elias Ellison0aeb9712019-05-31 14:26:46 -07009953 pass
9954 else:
9955 c0 = 1
9956 return c0
9957
Elias Ellison411cf432019-02-25 16:11:47 -08009958 def test_if_list_cat(self):
9959 # testing that different length lists don't throw error on cat in shape prop
Elias Ellison5c0eece2018-08-22 19:41:56 -07009960 @torch.jit.script
9961 def test_list(x):
Elias Ellison411cf432019-02-25 16:11:47 -08009962 if bool(x.sum() < 1):
Elias Ellison5c0eece2018-08-22 19:41:56 -07009963 c = [x, x]
9964 else:
9965 c = [x, x, x]
9966 return torch.cat(c)
9967
9968 b = torch.zeros(2, 4)
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07009969 _propagate_shapes(test_list.graph, (b,), False)
Elias Ellison5c0eece2018-08-22 19:41:56 -07009970
9971 def test_if_supertype(self):
9972 @torch.jit.script
9973 def tensor_unifying(x, y, z):
Elias Ellison5c0eece2018-08-22 19:41:56 -07009974            # testing that the dynamic (unspecialized) Tensor type is appropriately set for y and z
Elias Ellison3eefc062019-12-09 14:17:58 -08009975 if bool(x):
9976 x, y, z = x + 1, y, z
Elias Ellison5c0eece2018-08-22 19:41:56 -07009977 else:
Elias Ellison3eefc062019-12-09 14:17:58 -08009978 x, y, z = x + 1, x, y
Elias Ellison5c0eece2018-08-22 19:41:56 -07009979
9980 return x, y, z
9981
9982 a = torch.zeros(2, 2, dtype=torch.float)
9983 b = torch.zeros(2, 4, dtype=torch.long)
9984 c = torch.zeros(2, 4, dtype=torch.float)
9985
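# x stays a float tensor in both branches, while y and z receive tensors of
# different dtypes across the branches and therefore unify to the dynamic
# Tensor type asserted below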
Zachary DeVitodcb5fd32019-04-13 08:28:11 -07009986 graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
Zachary DeVito2d079932019-04-02 17:33:06 -07009987 if_outputs = list(graph.findNode("prim::If").outputs())
Shen Li10224432021-08-12 11:39:31 -07009988 self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)")
9989 self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
9990 self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
Elias Ellison5c0eece2018-08-22 19:41:56 -07009991
Elias Ellisonb0b75412019-02-25 13:26:49 -08009992 def test_list_unify(self):
9993 # allowing a unififed int?[] would cause a runtime error b/c
9994 # the index operation expects int?[] to be a generic list,
9995 # but in the true branch the IValue will be a int list
Shen Li10224432021-08-12 11:39:31 -07009996 with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
Elias Ellisonb0b75412019-02-25 13:26:49 -08009997 @torch.jit.script
9998 def list_optional_fails(x):
9999 # type: (bool) -> Optional[int]
10000 if x:
10001 y = [1]
10002 else:
Elias Ellison561037a2019-03-07 09:12:35 -080010003 y = [None] # noqa: T484
Elias Ellisonb0b75412019-02-25 13:26:49 -080010004 return y[0]
10005
10006 @torch.jit.script
10007 def list_tensors(x):
10008 # type: (bool) -> Tuple[Tensor, List[Tensor]]
10009 if x:
10010 a = torch.zeros([1, 1])
10011 y = [a]
10012 else:
10013 a = torch.zeros([1, 2])
10014 y = [a]
10015 return a, y
10016
Shen Li10224432021-08-12 11:39:31 -070010017 self.run_pass('constant_propagation', list_tensors.graph)
Zachary DeVito6cb1b992019-04-25 15:43:53 -070010018 m = self.createFunctionFromGraph(list_tensors.graph)
Elias Ellisonb0b75412019-02-25 13:26:49 -080010019        # testing that the tensor type of the lists is unified
10020 self.getExportImportCopy(m)
10021
Yanbo Liang490c1cf2022-12-19 04:14:11 +000010022 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
Michael Suo755f91b2019-08-19 18:41:08 -070010023 @_inline_everything
Elias Ellisonf2f3e8a2019-07-19 14:08:11 -070010024 def test_import_constants_not_specialized(self):
10025 class Mod(torch.nn.Module):
10026 def forward(self, x):
10027 return torch.cat(2 * [x], dim=0)
10028
10029 class ScriptMod(torch.jit.ScriptModule):
10030 def __init__(self, mod):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010031 super().__init__()
Elias Ellisonf2f3e8a2019-07-19 14:08:11 -070010032 x = torch.zeros(1, 3)
Shen Li10224432021-08-12 11:39:31 -070010033 mod_fn = lambda : mod(x) # noqa: E731
Elias Ellisonf2f3e8a2019-07-19 14:08:11 -070010034 self.mod = torch.jit.trace(mod_fn, tuple())
10035
10036 @torch.jit.script_method
10037 def forward(self):
10038 return self.mod()
10039
10040 cm = ScriptMod(Mod())
10041 # specialized tensor in graph
Shen Li10224432021-08-12 11:39:31 -070010042 FileCheck().check("Double(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph)
Elias Ellisonf2f3e8a2019-07-19 14:08:11 -070010043 buffer = io.BytesIO()
10044 torch.jit.save(cm, buffer)
10045 buffer.seek(0)
10046        # when the tensor is loaded as a constant it isn't specialized
10047 cm_load = torch.jit.load(buffer)
10048 FileCheck().check_not("Double(1, 3)").run(cm_load.forward.graph)
10049
Animesh Jain1d90d6e2022-07-07 18:57:31 +000010050 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Elias Ellison70db5362018-11-01 10:30:33 -070010051 def test_type_annotations_repeated_list(self):
10052 @torch.jit.script
10053 def float_fn(x, y):
David Riazati556ff8e2018-11-08 11:24:36 -080010054 # type: (float, BroadcastingList3[float]) -> List[float]
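# BroadcastingList3 accepts either a single float or a sequence of three
# floats; a scalar argument is expanded to [y, y, y]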
Elias Ellison70db5362018-11-01 10:30:33 -070010055 return y
10056 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
10057 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
10058
10059 @torch.jit.script
10060 def float_fn_call():
10061 print(float_fn(1.0, 1.0))
10062 print(float_fn(1.0, (1.0, 1.0, 1.0)))
10063
10064 @torch.jit.script
10065 def int_fn(x):
David Riazati556ff8e2018-11-08 11:24:36 -080010066 # type: (BroadcastingList3[int]) -> List[int]
Elias Ellison70db5362018-11-01 10:30:33 -070010067 return x
10068 self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
10069 self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
10070
10071 @torch.jit.script
10072 def int_fn_call():
10073 print(int_fn(1))
10074 print(int_fn((1, 1, 1)))
10075
Zachary DeVitof1185682018-12-14 19:29:19 -080010076 with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
Elias Ellison561037a2019-03-07 09:12:35 -080010077 @torch.jit.script # noqa: T484
Elias Ellison70db5362018-11-01 10:30:33 -070010078 def fn(x):
Elias Ellison561037a2019-03-07 09:12:35 -080010079 # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
Elias Ellison70db5362018-11-01 10:30:33 -070010080 return x
10081
Elias Ellison561037a2019-03-07 09:12:35 -080010082        # using a CompilationUnit so that the flake8 error on int[2] is not raised (noqa not working)
Elias Ellison70db5362018-11-01 10:30:33 -070010083 with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
Shen Li10224432021-08-12 11:39:31 -070010084 cu = torch.jit.CompilationUnit('''
Elias Ellison561037a2019-03-07 09:12:35 -080010085 def nested(x, y):
10086 # type: (int, Tuple[int, int[2]]) -> List[int]
10087 return x # noqa: T484
Shen Li10224432021-08-12 11:39:31 -070010088 ''')
Elias Ellison70db5362018-11-01 10:30:33 -070010089
Elias Ellison49b69b22020-06-04 12:13:33 -070010090 @torch.jit.script
10091 def f(x: BroadcastingList2[int]):
10092 return x
10093
10094 out = f(1)
10095 self.assertTrue(isinstance(out[0], int))
10096 self.assertEqual(out, [1, 1])
10097
Elias Ellison421f3f32018-11-01 15:40:40 -070010098 def test_ntuple_builtins(self):
10099 from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
10100
10101 def test_ints():
10102 return _single(1), _pair(2), _triple(3), _quadruple(4)
10103
10104 def test_floats():
10105 return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
10106
10107 self.checkScript(test_ints, ())
10108 self.checkScript(test_floats, ())
10109
Elias Ellison404ad932018-11-30 16:53:55 -080010110 def test_embedding_renorm_grad_error(self):
10111        # Testing that the builtin call to embedding_renorm_ correctly throws
10112        # an error when .backward() is called on its input
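# (F.embedding with max_norm renormalizes the weight tensor in place, so a
# later backward() through a graph that saved that tensor reports it was modified)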
10113
10114 def embedding_norm(input, embedding_matrix, max_norm):
10115 F.embedding(input, embedding_matrix, max_norm=0.01)
10116
10117 @torch.jit.script
10118 def embedding_norm_script(input, embedding_matrix, max_norm):
James Reedde6bb3f2019-01-26 17:39:34 -080010119 # type: (Tensor, Tensor, float) -> None
Elias Ellison404ad932018-11-30 16:53:55 -080010120 F.embedding(input, embedding_matrix, max_norm=0.01)
10121
James Reedde6bb3f2019-01-26 17:39:34 -080010122 for _ in [embedding_norm, embedding_norm_script]:
Elias Ellison404ad932018-11-30 16:53:55 -080010123 input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
10124 embedding_matrix = torch.randn(10, 3)
10125
10126 var1 = torch.randn(10, 3, requires_grad=True)
10127 var2 = var1.detach().requires_grad_()
10128 output1 = var1 * embedding_matrix
10129 output2 = var2 * embedding_matrix
10130
10131 output1.sum().backward()
10132
10133 ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
10134 with self.assertRaisesRegex(RuntimeError, "modified"):
10135 output2.sum().backward()
10136
Adam Paszkeda654332018-05-04 10:54:19 +020010137 def test_type_annotations(self):
10138 def fn(x, y):
10139 # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
10140 return x, x * 2, x * 3
10141
10142 with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10143 @torch.jit.script
10144 def script_fn(x):
10145 x, y, z, w = fn(x, x)
10146
Shen Li10224432021-08-12 11:39:31 -070010147 with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
Adam Paszkeda654332018-05-04 10:54:19 +020010148 @torch.jit.script
10149 def script_fn2(x):
10150 x, y = fn(x, x)
10151
10152 def fn_unpack(x):
10153 y, z, w = fn(x, x)
10154 return y
10155
10156 def fn_index(x):
10157 q = fn(x, x)
10158 return x
10159
Elias Ellisonf3e1fe52018-10-19 11:11:32 -070010160 def fn_string(str, strpair):
10161 # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
10162 str1, str2 = strpair
10163 return str, 2, str1, str2
10164
Adam Paszkeda654332018-05-04 10:54:19 +020010165 x = torch.ones(2, 2)
10166 self.checkScript(fn_unpack, (x,), optimize=True)
10167 self.checkScript(fn_index, (x,), optimize=True)
Elias Ellisonf3e1fe52018-10-19 11:11:32 -070010168 self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
Adam Paszkeda654332018-05-04 10:54:19 +020010169
10170 def test_type_annotations_varargs(self):
davidriazati7a370db2019-07-16 12:50:02 -070010171 @torch.jit.ignore
Adam Paszkeda654332018-05-04 10:54:19 +020010172 def fn_varargs(x, *args):
10173 return args[0] if args else x
10174
10175 def fn1(x, y, z):
10176 return fn_varargs(x)
10177
10178 def fn2(x, y, z):
10179 return fn_varargs(x, y)
10180
10181 def fn3(x, y, z):
10182 return fn_varargs(x, y, z)
10183
10184 x, y, z = [torch.randn(2, 2) for _ in range(3)]
10185 self.checkScript(fn1, (x, y, z), optimize=True)
10186 self.checkScript(fn2, (x, y, z), optimize=True)
10187 self.checkScript(fn3, (x, y, z), optimize=True)
10188
Adam Paszkeda654332018-05-04 10:54:19 +020010189 def test_type_annotation_py3(self):
Shen Li10224432021-08-12 11:39:31 -070010190 code = dedent("""
Adam Paszkeda654332018-05-04 10:54:19 +020010191 import torch
10192 from torch import Tensor
10193 from typing import Tuple
10194
10195 def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
10196 return (x, y + z, z)
Shen Li10224432021-08-12 11:39:31 -070010197 """)
Adam Paszkeda654332018-05-04 10:54:19 +020010198
10199 with tempfile.TemporaryDirectory() as tmp_dir:
Shen Li10224432021-08-12 11:39:31 -070010200 script_path = os.path.join(tmp_dir, 'script.py')
10201 with open(script_path, 'w') as f:
Adam Paszkeda654332018-05-04 10:54:19 +020010202 f.write(code)
Shen Li10224432021-08-12 11:39:31 -070010203 fn = get_fn('test_type_annotation_py3', script_path)
davidriazati7a370db2019-07-16 12:50:02 -070010204 fn = torch.jit.ignore(fn)
Adam Paszkeda654332018-05-04 10:54:19 +020010205
Shen Li10224432021-08-12 11:39:31 -070010206 with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
10207 r" 'x' but instead found type 'Tuple\[Tensor,"):
Adam Paszkeda654332018-05-04 10:54:19 +020010208 @torch.jit.script
10209 def bad_fn(x):
10210 x, y = fn((x, x), x, x)
10211 return y
10212
Shen Li10224432021-08-12 11:39:31 -070010213 with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
Adam Paszkeda654332018-05-04 10:54:19 +020010214 @torch.jit.script
10215 def bad_fn2(x):
10216 x, y = fn(x, x, x)
10217 return y
10218
10219 with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10220 @torch.jit.script
10221 def bad_fn3(x):
10222 x, y, z, w = fn(x, x, x)
10223 return y
10224
10225 def good_fn(x):
10226 y, z, w = fn(x, x, x)
10227 return y, z, w
10228
10229 self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
10230
10231 def test_type_annotation_module(self):
10232 class BaseModule(torch.jit.ScriptModule):
davidriazati7a370db2019-07-16 12:50:02 -070010233 @torch.jit.ignore
Adam Paszkeda654332018-05-04 10:54:19 +020010234 def foo(self, x):
10235 # type: (Tensor) -> Tensor
10236 return x + 1
10237
davidriazati7a370db2019-07-16 12:50:02 -070010238 @torch.jit.ignore
Adam Paszkeda654332018-05-04 10:54:19 +020010239 def bar(self, x, y):
10240 # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
10241 return x + y, y
10242
davidriazati7a370db2019-07-16 12:50:02 -070010243 @torch.jit.ignore
Adam Paszkeda654332018-05-04 10:54:19 +020010244 def baz(self, x, y):
10245 return x
10246
10247 class ModuleTooMany(BaseModule):
10248 @torch.jit.script_method
10249 def method(self, x):
10250 return self.foo(x, x)
10251
10252 class ModuleTooFew(BaseModule):
10253 @torch.jit.script_method
10254 def method(self, x):
10255 return self.bar(x)
10256
10257 class ModuleTooManyAssign(BaseModule):
10258 @torch.jit.script_method
10259 def method(self, x):
10260 y, z, w = self.bar(x, x)
10261 return x
10262
10263 class ModuleDefault(BaseModule):
10264 @torch.jit.script_method
10265 def method(self, x):
10266 y = self.baz(x)
10267 return x
10268
Shen Li10224432021-08-12 11:39:31 -070010269 with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"):
Adam Paszkeda654332018-05-04 10:54:19 +020010270 ModuleTooMany()
davidriazati883fb542020-01-08 15:39:32 -080010271 with self.assertRaisesRegex(RuntimeError, "Argument y not provided"):
Adam Paszkeda654332018-05-04 10:54:19 +020010272 ModuleTooFew()
10273 with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
10274 ModuleTooManyAssign()
davidriazati883fb542020-01-08 15:39:32 -080010275 with self.assertRaisesRegex(RuntimeError, "Argument y not provided."):
Adam Paszkeda654332018-05-04 10:54:19 +020010276 ModuleDefault()
10277
Ansley Usseryf18cc9c2020-10-05 15:07:11 -070010278 def test_type_inferred_from_empty_annotation(self):
10279 """
10280 Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true`
10281 """
10282 @torch.jit.script
10283 def fn(x):
10284 return x
10285
10286 graph = fn.graph
10287 n = next(graph.inputs())
10288 self.assertTrue(n.type() == torch._C.TensorType.getInferred())
10289
Shen Li10224432021-08-12 11:39:31 -070010290 with self.assertRaisesRegex(RuntimeError, "Inferred \'x\' to be of type \'Tensor"):
Edward Z. Yangf2eed942022-05-06 09:24:42 -070010291 fn("1")
Ansley Usseryf18cc9c2020-10-05 15:07:11 -070010292
Zachary DeVitoee240aa2018-04-16 15:19:05 -070010293 def test_script_define_order(self):
10294 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010295
Zachary DeVitoee240aa2018-04-16 15:19:05 -070010296 @torch.jit.script_method
10297 def call_foo(self, input):
10298 return self.foo(input)
10299
10300 @torch.jit.script_method
10301 def foo(self, input):
10302 return input + 1
10303 m = M()
gchanane1f5d802018-04-18 23:37:54 -040010304 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
Zachary DeVitoee240aa2018-04-16 15:19:05 -070010305
10306 def test_script_define_order_recursive_fail(self):
10307 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010308
Zachary DeVitoee240aa2018-04-16 15:19:05 -070010309 @torch.jit.script_method
10310 def call_foo(self, input):
10311 return self.foo(input)
10312
10313 @torch.jit.script_method
10314 def foo(self, input):
10315 self.call_foo(input)
10316
Shen Li10224432021-08-12 11:39:31 -070010317 with self.assertRaisesRegex(RuntimeError, 'called recursively'):
Zachary DeVitoee240aa2018-04-16 15:19:05 -070010318 M()
10319
Zachary DeVitoce69d312018-05-14 14:46:36 -070010320 def test_script_kwargs_fn_call(self):
10321 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010322
Zachary DeVitoce69d312018-05-14 14:46:36 -070010323 @torch.jit.script_method
10324 def call_foo(self, input):
10325 return self.foo(input=input, bar=1)
10326
10327 @torch.jit.script_method
10328 def foo(self, bar, input):
Richard Zou35beecf2018-08-27 12:37:20 -070010329 # type: (int, Tensor) -> Tensor
Zachary DeVitoce69d312018-05-14 14:46:36 -070010330 return input + bar
10331 m = M()
10332 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
10333
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010334 def test_if_define(self):
10335 @torch.jit.script
10336 def foo(a):
David Riazati6f53b4e2018-09-13 11:10:00 -070010337 if bool(a == 0):
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010338 b = 1
10339 else:
10340 b = 0
Richard Zou67f6f932018-08-27 08:53:56 -070010341 return b + 1
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010342
10343 @torch.jit.script
10344 def foo2(a):
10345 b = 0
David Riazati6f53b4e2018-09-13 11:10:00 -070010346 if bool(a == 0):
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010347 b = 1
Richard Zou67f6f932018-08-27 08:53:56 -070010348 return b + 1
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010349
10350 @torch.jit.script
10351 def foo3(a):
10352 b = 1
David Riazati6f53b4e2018-09-13 11:10:00 -070010353 if bool(a == 0):
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010354 c = 4
10355 else:
10356 b = 0
Richard Zou67f6f932018-08-27 08:53:56 -070010357 return b + 1
Zachary DeVitoa3f38172018-04-19 11:32:31 -070010358
10359 a = torch.ones(1, dtype=torch.long)
10360 b = torch.zeros(1, dtype=torch.long)
10361 self.assertEqual(1, foo(a))
10362 self.assertEqual(2, foo(b))
10363 self.assertEqual(1, foo2(a))
10364 self.assertEqual(2, foo2(b))
10365 self.assertEqual(1, foo3(a))
10366 self.assertEqual(2, foo3(b))
10367
Roy Li0e9c6892018-08-02 15:42:44 -070010368 def test_script_module_export_submodule(self):
10369 class M1(torch.jit.ScriptModule):
10370 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010371 super().__init__()
Roy Li0e9c6892018-08-02 15:42:44 -070010372 self.weight = nn.Parameter(torch.randn(2))
10373
10374 @torch.jit.script_method
10375 def forward(self, thing):
10376 return self.weight + thing
10377
10378 class M2(torch.jit.ScriptModule):
10379 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010380 super().__init__()
Roy Li0e9c6892018-08-02 15:42:44 -070010381 # test submodule
10382 self.sub = M1()
10383 self.weight = nn.Parameter(torch.randn(2, 3))
10384 self.bias = nn.Parameter(torch.randn(2))
Shen Li10224432021-08-12 11:39:31 -070010385 self.define("""
Roy Li0e9c6892018-08-02 15:42:44 -070010386 def hi(self, a):
10387 return self.weight.mm(a)
Shen Li10224432021-08-12 11:39:31 -070010388 """)
Roy Li0e9c6892018-08-02 15:42:44 -070010389
10390 @torch.jit.script_method
10391 def doit(self, input):
10392 return self.weight.mm(input)
10393
10394 @torch.jit.script_method
10395 def doit2(self, input):
10396 return self.weight.mm(input)
10397
10398 @torch.jit.script_method
10399 def doit3(self, input):
10400 return input + torch.ones([1], dtype=torch.double)
10401
10402 @torch.jit.script_method
10403 def forward(self, input):
10404 a = self.doit(input)
10405 b = self.doit2(input)
10406 c = self.hi(input)
10407 return a + b + self.bias + c
10408
Michael Suo711be822019-07-24 23:05:48 -070010409 with torch.jit.optimized_execution(False):
10410 m_orig = M2()
10411 m_import = self.getExportImportCopy(m_orig)
Roy Li0e9c6892018-08-02 15:42:44 -070010412
Michael Suo711be822019-07-24 23:05:48 -070010413 input = torch.randn(3, 2)
10414 self.assertEqual(m_orig.doit(input), m_import.doit(input))
10415 self.assertEqual(m_orig.hi(input), m_import.hi(input))
10416 self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
10417 self.assertEqual(m_orig.forward(input), m_import.forward(input))
Roy Li0e9c6892018-08-02 15:42:44 -070010418
Elias Ellison8faf0112019-03-27 19:21:32 -070010419 @slowTest
Elias Ellison2e446302019-08-16 13:27:37 -070010420 def test_compile_module_with_constant(self):
10421 class Double(nn.Module):
10422 def __init__(self, downsample=None):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010423 super().__init__()
Elias Ellison2e446302019-08-16 13:27:37 -070010424
10425 def forward(self, input):
10426 return input * 2
10427
10428 class Mod(nn.Module):
Shen Li10224432021-08-12 11:39:31 -070010429 __constants__ = ['downsample']
Elias Ellison2e446302019-08-16 13:27:37 -070010430
10431 def __init__(self, downsample=None):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010432 super().__init__()
Elias Ellison2e446302019-08-16 13:27:37 -070010433 self.downsample = downsample
10434
10435 def forward(self, input):
10436 if self.downsample is not None:
10437 return self.downsample(input)
10438 return input
10439
10440 none_mod = torch.jit.script(Mod(None))
10441 double_mod = torch.jit.script(Mod(Double()))
10442 self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1))
10443 self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2)
10444
Elias Ellison97e8dcb2022-01-11 22:09:58 -080010445 def test_device_kwarg(self):
10446 from torch import device
10447
10448 def f():
10449 return device(type='cuda'), torch.device(type='cpu')
10450 self.checkScript(f, ())
10451
Roy Li0e9c6892018-08-02 15:42:44 -070010452 def test_script_module_export_tensor_type(self):
10453 class M(torch.jit.ScriptModule):
Roy Li0e9c6892018-08-02 15:42:44 -070010454 def __init__(self, type):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010455 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070010456 self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
Roy Li0e9c6892018-08-02 15:42:44 -070010457
10458 @torch.jit.script_method
10459 def foo(self):
10460 return self.param
10461
Michael Suo711be822019-07-24 23:05:48 -070010462 with torch.jit.optimized_execution(False):
10463 for type in [torch.float, torch.double]:
10464 m_orig = M(type)
10465 m_import = self.getExportImportCopy(m_orig)
10466 # check to make sure the storage wasn't resized
10467 self.assertTrue(m_orig.param.storage().size() == 25)
10468 self.assertEqual(m_orig.foo(), m_import.foo())
10469 self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
Roy Li0e9c6892018-08-02 15:42:44 -070010470
10471 @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
10472 def test_script_module_export_tensor_cuda(self):
10473 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010474
Roy Li0e9c6892018-08-02 15:42:44 -070010475 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010476 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070010477 self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
Roy Li0e9c6892018-08-02 15:42:44 -070010478
10479 @torch.jit.script_method
10480 def foo(self):
10481 return self.param
10482
10483 m_orig = M()
Roy Lie9ad7432018-08-10 00:03:50 -070010484 m_import = self.getExportImportCopy(m_orig)
Zachary DeVitob937cbb2018-10-05 16:23:42 -070010485 # check to make sure the storage wasn't resized
10486 self.assertTrue(m_orig.param.storage().size() == 25)
Shen Li10224432021-08-12 11:39:31 -070010487 self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
Roy Li0e9c6892018-08-02 15:42:44 -070010488 self.assertEqual(m_orig.foo(), m_import.foo())
10489 self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
10490
Roy Li1a0d82e2018-10-08 22:12:54 -070010491 def test_script_module_export_blocks(self):
10492 class M(torch.jit.ScriptModule):
10493 def __init__(self, n, m):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010494 super().__init__()
Roy Li1a0d82e2018-10-08 22:12:54 -070010495 self.weight = torch.nn.Parameter(torch.rand(n, m))
10496
10497 @torch.jit.script_method
10498 def forward(self, input):
10499 if bool(input.sum() > 0):
10500 output = self.weight.mv(input)
10501 else:
10502 output = self.weight + input
10503 return output
10504
10505 m_orig = M(200, 200)
10506 m_import = self.getExportImportCopy(m_orig)
10507
10508 t = torch.rand(200)
10509 self.assertEqual(m_orig(t), m_import(t))
10510
Roy Li0e9c6892018-08-02 15:42:44 -070010511 def test_script_module_export_shared_storage(self):
10512 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010513
Roy Li0e9c6892018-08-02 15:42:44 -070010514 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010515 super().__init__()
Roy Li0e9c6892018-08-02 15:42:44 -070010516 self.param1 = torch.nn.Parameter(torch.rand(5, 5))
10517 self.param2 = torch.nn.Parameter(self.param1[3])
10518 self.param3 = torch.nn.Parameter(torch.rand(5, 5))
Roy Lia8615732018-09-13 10:59:11 -070010519 self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
Roy Li0e9c6892018-08-02 15:42:44 -070010520
10521 @torch.jit.script_method
10522 def foo(self):
Roy Lia8615732018-09-13 10:59:11 -070010523 return self.param1 + self.param2 + self.param3 + self.param4
Roy Li0e9c6892018-08-02 15:42:44 -070010524
Michael Suo711be822019-07-24 23:05:48 -070010525 with torch.jit.optimized_execution(False):
10526 m_orig = M()
10527 m_import = self.getExportImportCopy(m_orig)
Roy Lie9ad7432018-08-10 00:03:50 -070010528
Michael Suo711be822019-07-24 23:05:48 -070010529 self.assertEqual(m_orig.foo(), m_import.foo())
Zachary DeVitoe2cccce2019-08-22 11:44:53 -070010530
Shen Li10224432021-08-12 11:39:31 -070010531 self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
10532 self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
Roy Li0e9c6892018-08-02 15:42:44 -070010533
Elias Ellisonfdeef452019-11-04 09:18:09 -080010534 def test_sequential_intermediary_types(self):
10535 class A(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080010536 def forward(self, x):
10537 return x + 3
10538
10539 class B(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080010540 def forward(self, x):
10541 return {"1": x}
10542
10543 class C(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080010544 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010545 super().__init__()
Elias Ellisonfdeef452019-11-04 09:18:09 -080010546 self.foo = torch.nn.Sequential(A(), B())
10547
10548 def forward(self, x):
10549 return self.foo(x)
10550
10551 self.checkModule(C(), (torch.tensor(1),))
10552
Yanan Caobf88a4d2021-03-09 00:00:14 -080010553 def test_ellipsis_const_mid(self):
10554 def ellipsize(x):
10555 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010556 return x[2, Ellipsis, 0:4, 4:8].size()
Yanan Caobf88a4d2021-03-09 00:00:14 -080010557
10558 dummy = torch.zeros(8, 8, 8, 8, 8)
10559 self.checkScript(ellipsize, (dummy,), optimize=True)
10560
10561 def test_ellipsis_const_mid_select(self):
10562 def ellipsize(x):
10563 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010564 return x[2, Ellipsis, 4, 4, 4:8, 2].size()
Yanan Caobf88a4d2021-03-09 00:00:14 -080010565
10566 dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10567 self.checkScript(ellipsize, (dummy,), optimize=True)
10568
10569 def test_ellipsis_const_start(self):
10570 def ellipsize(x):
10571 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010572 return x[Ellipsis, 0:4, 4:8].size()
Yanan Caobf88a4d2021-03-09 00:00:14 -080010573 dummy = torch.zeros(8, 8, 8, 8, 8)
10574 self.checkScript(ellipsize, (dummy,), optimize=True)
10575
10576 def test_ellipsis_const_end(self):
10577 def ellipsize(x):
10578 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010579 return x[0:4, 2, Ellipsis].size()
Yanan Caobf88a4d2021-03-09 00:00:14 -080010580 dummy = torch.zeros(8, 8, 8, 8, 8)
10581 self.checkScript(ellipsize, (dummy,), optimize=True)
10582
Nikolay Korovaikoada10ad2019-04-15 22:05:20 -070010583 def test_ellipsis_mid(self):
10584 def ellipsize(x):
10585 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010586 return x[2, ..., 0:4, 4:8].size()
Nikolay Korovaikoada10ad2019-04-15 22:05:20 -070010587
10588 dummy = torch.zeros(8, 8, 8, 8, 8)
10589 self.checkScript(ellipsize, (dummy,), optimize=True)
10590
10591 def test_ellipsis_mid_select(self):
10592 def ellipsize(x):
10593 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010594 return x[2, ..., 4, 4, 4:8, 2].size()
Nikolay Korovaikoada10ad2019-04-15 22:05:20 -070010595
10596 dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10597 self.checkScript(ellipsize, (dummy,), optimize=True)
10598
10599 def test_ellipsis_start(self):
10600 def ellipsize(x):
10601 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010602 return x[..., 0:4, 4:8].size()
Nikolay Korovaikoada10ad2019-04-15 22:05:20 -070010603 dummy = torch.zeros(8, 8, 8, 8, 8)
10604 self.checkScript(ellipsize, (dummy,), optimize=True)
10605
10606 def test_ellipsis_end(self):
10607 def ellipsize(x):
10608 # type: (Tensor) -> List[int]
Sam Estepe3900d22021-04-19 13:14:27 -070010609 return x[0:4, 2, ...].size()
Nikolay Korovaikoada10ad2019-04-15 22:05:20 -070010610 dummy = torch.zeros(8, 8, 8, 8, 8)
10611 self.checkScript(ellipsize, (dummy,), optimize=True)
10612
Elias Ellison19f73182019-04-22 10:52:28 -070010613 def test_torch_manual_seed(self):
10614 with freeze_rng_state():
10615 def test():
10616 torch.manual_seed(2)
10617 return torch.rand(1)
10618
10619 script = torch.jit.script(test)
10620 self.assertEqual(test(), script())
10621 graph = script.graph_for()
10622 FileCheck().check("aten::manual_seed").run(graph)
10623
Yanbo Liang490c1cf2022-12-19 04:14:11 +000010624 @skipIfTorchDynamo("Not a TorchDynamo suitable test")
Zachary DeVito0b5910f2018-04-24 14:04:18 -070010625 def test_index_select_shape_prop(self):
Shen Li10224432021-08-12 11:39:31 -070010626
Zachary DeVito0b5910f2018-04-24 14:04:18 -070010627 @torch.jit.script
10628 def foo(x, y):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010629 return torch.index_select(x, index=y, dim=1)
Zachary DeVito0b5910f2018-04-24 14:04:18 -070010630
10631 a = torch.zeros(2, 2)
10632 b = torch.zeros(4, dtype=torch.long)
Adam Paszkec8b246a2018-08-26 09:40:58 -070010633 torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
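        # index_select along dim=1 with a 4-element index on the (2, 2) input above
        # should be propagated by shape analysis as a (2, 4) Double tensor, which the
        # FileCheck below verifies in the annotated graph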
Shen Li10224432021-08-12 11:39:31 -070010634 FileCheck().check("Double(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph))
Zachary DeVito0b5910f2018-04-24 14:04:18 -070010635
Zachary DeVito733e2962018-04-27 17:44:17 -070010636 def test_shape_analysis_loop(self):
10637 def foo(a, b, x):
10638 c = a
10639 # on the first iteration of the loop it appears that
10640            # c should have an expand to the size of b
10641 # but on the second+ iterations, there is no broadcast and the
10642 # sizes are different.
10643 # previously this would cause the compiler to (1) enter an infinite
10644 # loop trying to compute the shape, and (2) insert invalid
10645 # broadcasts.
10646            # this test ensures we don't regress on these issues
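            # a concrete illustration, assuming the (1,), (4,), (5,)-sized inputs used
            # in the checkScript call below: on iteration 0, a (size 1) broadcasts with
            # b (size 4); on iteration 1, c and b have both been rebound to x (size 5),
            # so the add produces a size-5 result with no broadcast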
10647 for _ in range(2):
10648 a = c + b
10649 c = x
10650 b = x
10651 return a
10652
Shen Li10224432021-08-12 11:39:31 -070010653 self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
Zachary DeVito733e2962018-04-27 17:44:17 -070010654
Marcin Elantkowskibc626452018-04-30 05:09:04 +020010655 def test_intlist_args(self):
10656 def func_1(x):
10657 return torch.nn.functional.adaptive_avg_pool1d(x, 1)
10658
10659 def func_2(x):
10660 return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
10661
10662 def func_3(x):
10663 return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
10664
10665 x = torch.randn(8, 8, 8)
10666 self.checkScript(func_1, [x], optimize=True)
10667 self.checkScript(func_2, [x], optimize=True)
10668 self.checkScript(func_3, [x], optimize=True)
10669
Zachary DeVito93eb50c2018-05-10 10:47:43 -070010670 def test_wrong_implicit_expand(self):
Shen Li10224432021-08-12 11:39:31 -070010671
Zachary DeVito93bd2912018-08-30 13:51:45 -070010672 @_trace(torch.zeros(3), torch.zeros(1))
Zachary DeVito93eb50c2018-05-10 10:47:43 -070010673 def foo(a, b):
10674 return a + b
10675
10676 a = torch.rand(4)
10677 b = torch.rand(4)
10678 self.assertEqual(a + b, foo(a, b))
10679
Zachary DeVitoce69d312018-05-14 14:46:36 -070010680 def test_builtin_args_fails(self):
10681
Shen Li10224432021-08-12 11:39:31 -070010682 with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010683 @torch.jit.script
10684 def f1(a):
10685 torch.sum(foo=4)
10686
Shen Li10224432021-08-12 11:39:31 -070010687 with self.assertRaisesRegex(RuntimeError, 'specified twice'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010688 @torch.jit.script
10689 def f2(a):
10690 torch.sum(a, self=a)
10691
Shen Li10224432021-08-12 11:39:31 -070010692 with self.assertRaisesRegex(RuntimeError, 'not provided'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010693 @torch.jit.script
10694 def f3(a):
10695 torch.sum(dim=4)
10696
Shen Li10224432021-08-12 11:39:31 -070010697 with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010698 @torch.jit.script
10699 def f4(a):
10700 torch.cat(a)
10701
Shen Li10224432021-08-12 11:39:31 -070010702 with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010703 @torch.jit.script
10704 def f5(a):
Zachary DeVitoab3a2d22018-09-13 11:05:09 -070010705 torch.cat([3])
Zachary DeVitoce69d312018-05-14 14:46:36 -070010706
Shen Li10224432021-08-12 11:39:31 -070010707 with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
10708 r' type \'List\[int\]\' for argument'
10709 r' \'size\' but instead found type '
Ansley Ussery6831d8e2021-09-03 06:10:37 -070010710 r'\'List\[Union\[List\[int\], int\]\]'):
Zachary DeVitoce69d312018-05-14 14:46:36 -070010711 @torch.jit.script
10712 def f6(a):
10713 a.expand(size=[3, [4]])
10714
Zachary DeVitoce69d312018-05-14 14:46:36 -070010715 def test_builtin_args(self):
Shen Li10224432021-08-12 11:39:31 -070010716
Zachary DeVitoce69d312018-05-14 14:46:36 -070010717 def t0(a):
10718 # default arg dim
10719 return torch.cat([a, a])
10720
James Reedc0d50e12018-05-30 11:43:22 -070010721 self.checkScript(t0, (torch.zeros(1, 1),))
Zachary DeVitoce69d312018-05-14 14:46:36 -070010722
10723 def t1(a):
10724 # keywords out of order
10725 return torch.cat(dim=1, tensors=[a, a])
10726
James Reedc0d50e12018-05-30 11:43:22 -070010727 self.checkScript(t1, (torch.zeros(1, 1, 2),))
Zachary DeVitoce69d312018-05-14 14:46:36 -070010728
10729 def t2(a):
10730 # mix const/non-const attributes
Elias Ellisond1b8da72020-11-20 11:14:59 -080010731 if 1 == 1:
Zachary DeVitoce69d312018-05-14 14:46:36 -070010732 b = 1
10733 else:
10734 b = 0
10735 return torch.sum(a, dim=b, keepdim=False)
10736
James Reedc0d50e12018-05-30 11:43:22 -070010737 self.checkScript(t2, (torch.zeros(1, 1, 2),))
10738
James Reed32bb4042018-08-14 17:56:52 -070010739 def test_parser_type_annotations(self):
Shen Li10224432021-08-12 11:39:31 -070010740 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010741 def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10742 return x, x
Shen Li10224432021-08-12 11:39:31 -070010743 ''')
James Reed32bb4042018-08-14 17:56:52 -070010744
Zachary DeVito31524bd2019-04-25 15:43:52 -070010745 self.assertExpected(str(cu.foo.schema))
James Reed32bb4042018-08-14 17:56:52 -070010746
10747 def test_parser_type_annotations_comment(self):
Shen Li10224432021-08-12 11:39:31 -070010748 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010749 def foo(x, y):
10750 # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
10751 return x, x
Shen Li10224432021-08-12 11:39:31 -070010752 ''')
James Reed32bb4042018-08-14 17:56:52 -070010753
Zachary DeVito31524bd2019-04-25 15:43:52 -070010754 self.assertExpected(str(cu.foo.schema))
James Reed32bb4042018-08-14 17:56:52 -070010755
10756 def test_parser_type_annotations_unknown_type(self):
davidriazatic267d0c2019-05-17 15:23:58 -070010757 with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"):
Shen Li10224432021-08-12 11:39:31 -070010758 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010759 def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10760 return x, x
Shen Li10224432021-08-12 11:39:31 -070010761 ''')
James Reed32bb4042018-08-14 17:56:52 -070010762
10763 def test_parser_type_annotations_subscript_non_ident(self):
Shen Li10224432021-08-12 11:39:31 -070010764 with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
10765 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010766 def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
10767 return x, x
Shen Li10224432021-08-12 11:39:31 -070010768 ''')
James Reed32bb4042018-08-14 17:56:52 -070010769
10770 def test_parser_type_annotations_subscript_tensor(self):
Shen Li10224432021-08-12 11:39:31 -070010771 with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
10772 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010773 def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
10774 return x, x
Shen Li10224432021-08-12 11:39:31 -070010775 ''')
James Reed32bb4042018-08-14 17:56:52 -070010776
10777 def test_parser_type_annotations_incompatible_expression(self):
Shen Li10224432021-08-12 11:39:31 -070010778 with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
10779 cu = torch.jit.CompilationUnit('''
James Reed32bb4042018-08-14 17:56:52 -070010780 def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
10781 return x, x
Shen Li10224432021-08-12 11:39:31 -070010782 ''')
James Reed32bb4042018-08-14 17:56:52 -070010783
James Reedc0d50e12018-05-30 11:43:22 -070010784 def test_gather_dynamic_index(self):
10785 def t(x):
10786 gather1 = x[0]
10787 idx = 0 + 1
10788 gather2 = x[idx]
10789 return gather1 + gather2
10790
10791 self.checkScript(t, (torch.zeros(3, 2, 3),))
10792
Tugsbayasgalan Manlaibaatarb405e2c2021-04-16 00:04:21 -070010793 def test_torch_ignore_conversion_to_none(self):
10794 class A(torch.nn.Module):
Tugsbayasgalan Manlaibaatarb405e2c2021-04-16 00:04:21 -070010795 @torch.jit.ignore
10796 def ignored(self, a: int) -> None:
10797 l: int = len([2 for i in range(a) if i > 2])
10798 return
10799
10800 def forward(self) -> int:
10801 a: int = 4
10802 b: int = 5
10803 self.ignored(a)
10804 return a + b
10805
10806 class B(torch.nn.Module):
Tugsbayasgalan Manlaibaatarb405e2c2021-04-16 00:04:21 -070010807 @torch.jit.ignore
10808 def ignored(self, a: int):
10809 l: int = len([2 for i in range(a) if i > 2])
10810 return
10811
10812 def forward(self) -> int:
10813 a: int = 4
10814 b: int = 5
10815 self.ignored(a)
10816 return a + b
10817
10818 modelA = torch.jit.script(A())
10819 self.assertEqual(modelA(), 9)
10820
Alban Desmaison28c51992021-12-02 07:45:35 -080010821 modelB = torch.jit.script(B())
10822 self.assertEqual(modelB(), 9)
Tugsbayasgalan Manlaibaatarb405e2c2021-04-16 00:04:21 -070010823
James Reed1f94a6e2018-05-30 15:06:58 -070010824 def test_addmm_grad(self):
Shen Li10224432021-08-12 11:39:31 -070010825 """ This test checks several things:
10826 1. An expand node was inserted before the addmm operating on the
10827 bias term.
10828 2. The fused form of addmm appears in the ultimate graph that's
10829 executed.
10830 3. A sum op was emitted for accumulating gradients along the 0th
10831 (expanded) dimension of the bias term.
10832 4. The correct symbolic representation for the backward pass of the
10833 mm operator was emitted (x.t() -> mm)
James Reed1f94a6e2018-05-30 15:06:58 -070010834
Shen Li10224432021-08-12 11:39:31 -070010835 TODO: we should actually check these conditions once we have a way
10836 to dump the GraphExecutor state. Namely the processed forward graph
10837 and the backward graph.
James Reed1f94a6e2018-05-30 15:06:58 -070010838 """
10839 @torch.jit.script
10840 def addmm_grad_test(b, x, w):
10841 return torch.addmm(b, x, w)
10842
10843 # Initialize param and input values
10844 w_init = torch.rand(2, 5)
10845 b_init = torch.rand(5)
10846 x = torch.rand(3, 2)
10847
10848 # Clone trainable params
10849 b = b_init.clone()
10850 b.requires_grad_()
10851 w = w_init.clone()
10852 w.requires_grad_()
10853
10854 # Test symbolic differentiation
10855 y = addmm_grad_test(b, x, w)
10856 y.sum().backward()
10857
10858 # clone params for autograd reference
10859 b_ref = b_init.clone()
10860 b_ref.requires_grad_()
10861 w_ref = w_init.clone()
10862 w_ref.requires_grad_()
10863 y_ref = torch.addmm(b_ref, x, w_ref)
10864 y_ref.sum().backward()
10865
10866 self.assertEqual(w.grad, w_ref.grad)
10867 self.assertEqual(b.grad, b_ref.grad)
10868
jieje926f752021-08-21 09:05:04 -070010869 @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
10870 def test_batch_norm_inference_backward_cuda(self):
10871 with enable_profiling_mode_for_profiling_tests():
10872 class MyBatchNorm(torch.nn.Module):
10873 def __init__(self, num_features, affine, track_running_stats):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010874 super().__init__()
jieje926f752021-08-21 09:05:04 -070010875 self.bn = torch.nn.BatchNorm2d(
10876 num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
10877
10878 def forward(self, x: torch.Tensor):
10879 o = self.bn(x)
10880 o = torch.nn.functional.relu(o)
10881 return o
10882
10883 batch = 4
10884 c = 2
10885 hw = 3
10886 # Initialize param and input values
10887 x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10888 grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10889
10890 training = False
10891 affine = True
10892 track_running_stats = True
10893
10894 module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
10895 ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
10896 module.eval()
10897 ref_module.eval()
10898
10899 jit_module = torch.jit.script(module)
10900 ref_module.load_state_dict(module.state_dict())
10901
10902 x = x_init.detach().clone()
10903 x.requires_grad_()
10904 x_ref = x_init.detach().clone()
10905 x_ref.requires_grad_()
10906
10907 # Test symbolic differentiation
10908 # Run Forward and Backward thrice to trigger autodiff graph
10909 for i in range(0, 3):
10910 y = jit_module(x)
10911 y.backward(grad)
10912 x.grad.zero_()
10913
10914 module.bn.running_mean.zero_()
10915 module.bn.running_var.fill_(1.0)
10916 ref_module.bn.running_mean.zero_()
10917 ref_module.bn.running_var.fill_(1.0)
10918
10919 # run jitted module
10920 y = jit_module(x)
10921 y.backward(grad)
10922 # reference computation
10923 y_ref = ref_module(x_ref)
10924 y_ref.backward(grad)
10925
10926 self.assertEqual(y_ref, y)
10927 self.assertEqual(x.grad, x_ref.grad)
10928 self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
10929 self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
10930
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010931 def test_zeros(self):
10932 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070010933 __constants__ = ['d']
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010934
10935 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000010936 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070010937 self.d = torch.device('cpu')
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010938
10939 @torch.jit.script_method
10940 def create(self):
Shen Li10224432021-08-12 11:39:31 -070010941 return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010942
10943 r = M().create()
10944 self.assertEqual(r.dtype, torch.float)
10945 self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
10946
Elias Ellison0f428812019-09-19 15:44:44 -070010947 def fn():
10948 return torch.zeros((1, 2, 3))
10949
10950 self.checkScript(fn, ())
10951
Zachary DeVito3d43a822018-08-23 15:06:16 -070010952 def test_vararg_zeros(self):
10953 def foo():
10954 return torch.zeros(3, 4, 5, dtype=torch.int)
10955
10956 self.checkScript(foo, ())
10957
Shen Li10224432021-08-12 11:39:31 -070010958 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand")
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010959 def test_rand(self):
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010960 def test_rand():
10961 a = torch.rand([3, 4])
10962 return a + 1.0 - a
10963
10964 self.checkScript(test_rand, ())
Elias Ellison930fb2f2019-04-08 14:44:45 -070010965 fn = torch.jit.script(test_rand)
10966 out = fn()
10967 self.assertEqual(out.dtype, torch.double)
10968 g = fn.graph_for()
10969 # Testing shape analysis correctly setting type
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080010970 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
Shen Li10224432021-08-12 11:39:31 -070010971 FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
10972 .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g)
Elias Ellison930fb2f2019-04-08 14:44:45 -070010973
10974 @torch.jit.script
10975 def randint():
10976 return torch.randint(0, 5, [1, 2])
10977 out = randint()
Peter Bell8d0cbce2022-07-20 19:05:55 +010010978 self.assertEqual(out.dtype, torch.int64)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080010979 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
Peter Bell8d0cbce2022-07-20 19:05:55 +010010980 FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \
10981 .check_not("Float(*, *, requires_grad=0, device=cpu)") \
10982 .check_not("Double(*, *, requires_grad=0, device=cpu)") \
10983 .run(randint.graph_for())
Zachary DeVitoef1c15f2018-06-01 14:24:18 -070010984
jjsjann123fde282f2022-03-21 14:22:31 -070010985 @unittest.skipIf(not RUN_CUDA, "no CUDA")
10986 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
10987 def test_autodiff_complex(self):
10988 def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10989 return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10990
10991 @torch.jit.script
10992 def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10993 return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10994
10995 x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10996 y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10997 W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True)
10998 W.data /= 4
10999
11000 with enable_profiling_mode_for_profiling_tests():
11001 for i in range(4):
11002 self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None))
11003
11004
jiej4d703d02021-02-04 16:30:35 -080011005 def test_linear_grad(self):
11006 with enable_profiling_mode_for_profiling_tests():
11007 def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
11008 return torch.nn.functional.linear(x, w, b)
11009
11010 x_init = torch.randn(4, 2)
11011 w_init = torch.randn(3, 2)
11012 b_init = torch.randn(3)
11013 grad = torch.randn(4, 3)
11014
11015 with disable_autodiff_subgraph_inlining():
11016 # script module
11017 jit_t = torch.jit.script(t)
11018
11019 x = x_init.detach().requires_grad_()
11020 w = w_init.detach().requires_grad_()
11021 b = b_init.detach().requires_grad_()
11022 x_ref = x_init.detach().requires_grad_()
11023 w_ref = w_init.detach().requires_grad_()
11024 b_ref = b_init.detach().requires_grad_()
11025
11026 # profiling/optimization runs
11027 jit_o = jit_t(x, w, b)
11028 jit_o.backward(grad)
11029 jit_o = jit_t(x, w, b)
11030 jit_o.backward(grad)
11031
11032 x.grad.zero_()
11033 w.grad.zero_()
11034 b.grad.zero_()
11035 jit_o = jit_t(x, w, b)
11036 jit_o.backward(grad)
11037 o = t(x_ref, w_ref, b_ref)
11038 o.backward(grad)
11039
11040 self.assertEqual(jit_o, o)
11041 self.assertEqual(x.grad, x_ref.grad)
11042 self.assertEqual(w.grad, w_ref.grad)
11043 self.assertEqual(b.grad, b_ref.grad)
11044
11045 x.grad.zero_()
11046 w.grad.zero_()
11047 x_ref.grad.zero_()
11048 w_ref.grad.zero_()
11049 jit_o = jit_t(x, w, None)
11050 jit_o.backward(grad)
11051 o = t(x_ref, w_ref, None)
11052 o.backward(grad)
11053
11054 self.assertEqual(jit_o, o)
11055 self.assertEqual(x.grad, x_ref.grad)
11056 self.assertEqual(w.grad, w_ref.grad)
11057
Yanbo Liang490c1cf2022-12-19 04:14:11 +000011058 @skipIfTorchDynamo("TorchDynamo doesn't support profile")
Shen Li10224432021-08-12 11:39:31 -070011059 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
Nikolay Korovaiko47faee22019-10-29 11:40:04 -070011060 def test_rand_profiling(self):
11061 def test_rand():
11062 a = torch.rand([3, 4])
11063 return a + 1.0 - a
11064
Nikolay Korovaiko47faee22019-10-29 11:40:04 -070011065 # Testing shape analysis correctly setting type
Nikolay Korovaikofe261022020-09-13 15:56:30 -070011066 with enable_profiling_mode_for_profiling_tests():
11067 with num_profiled_runs(1):
11068 fn = torch.jit.script(test_rand)
11069 out = fn()
11070 graph_str = torch.jit.last_executed_optimized_graph()
11071 self.assertEqual(out.dtype, torch.double)
Shen Li10224432021-08-12 11:39:31 -070011072 FileCheck().check("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \
11073 .check_not("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str)
Nikolay Korovaikofe261022020-09-13 15:56:30 -070011074
11075 # fn = self.checkScript(test_rand, ())
11076 # out = fn()
11077 # self.assertEqual(out.dtype, torch.double)
Nikolay Korovaiko47faee22019-10-29 11:40:04 -070011078
11079 @torch.jit.script
11080 def randint():
11081 return torch.randint(0, 5, [1, 2])
11082
Nikolay Korovaikofe261022020-09-13 15:56:30 -070011083 with enable_profiling_mode_for_profiling_tests():
11084 with num_profiled_runs(1):
11085 out = randint()
11086 graph_str = torch.jit.last_executed_optimized_graph()
Peter Bell8d0cbce2022-07-20 19:05:55 +010011087 self.assertEqual(out.dtype, torch.int64)
11088 FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str)
Shen Li10224432021-08-12 11:39:31 -070011089
Nikolay Korovaiko47faee22019-10-29 11:40:04 -070011090
Richard Zou8489c4c2018-06-21 15:43:38 -040011091 def test_erase_number_types(self):
11092 def func(a):
11093 b = 7 + 1 + 3
11094 c = a + b
11095 c += b
11096 return c
11097
James Reed0b16b032018-07-25 16:55:09 -070011098 graph = torch.jit.script(func).graph
eellisond8d83712019-02-22 17:54:09 -080011099 FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
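        # the erase_number_types pass is expected to rewrite the number-typed values,
        # so the int-typed prim::Constant seen above should be gone after the pass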
Gary Migueldec5aa22021-07-09 16:13:27 -070011100 self.run_pass("erase_number_types", graph)
11101 FileCheck().check_not("int = prim::Constant").run(str(graph))
Richard Zou8489c4c2018-06-21 15:43:38 -040011102
Henry Tuf6eb8112022-05-12 00:48:39 +000011103 def test_refine_tuple_types(self):
11104 # TupleConstruct output type is not correct here.
11105 graph_str = """
11106 graph(%a : Float(123), %b : Float(4, 5, 6)):
11107 %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
11108 return (%c)
11109 """
11110 graph = parse_ir(graph_str)
11111 torch._C._jit_pass_refine_tuple_types(graph)
11112
11113 # After the pass, the output type should've been updated.
11114 self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))
11115
11116 # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
11117
Kimish Patelf954dd72020-05-12 14:36:25 -070011118 def test_remove_dropout(self):
11119 weight_0_shape = (20, 5)
11120 weight_1_shape = (20, 20)
11121 input_shape = (10, 5)
11122
11123 class M(torch.nn.Module):
11124 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011125 super().__init__()
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -070011126 self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape))
11127 self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape))
Kimish Patelf954dd72020-05-12 14:36:25 -070011128
11129 def forward(self, x):
11130 o = F.linear(x, self.weight_0)
11131 o = F.dropout(o, training=self.training)
11132 o = F.linear(o, self.weight_1)
11133 return o
11134
11135 data = torch.rand(input_shape)
11136 m = M()
11137 m = torch.jit.script(m)
Shen Li10224432021-08-12 11:39:31 -070011138 with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'):
Kimish Patelf954dd72020-05-12 14:36:25 -070011139 torch._C._jit_pass_remove_dropout(m._c)
11140 m.eval()
11141 ref_res = m(data)
11142 # Need to inline otherwise we see instances of Function.
11143 # We would have to use torch.linear/dropout to get around it otherwise.
11144 from torch.jit._recursive import wrap_cpp_module
11145 m = wrap_cpp_module(torch._C._freeze_module(m._c))
11146 torch._C._jit_pass_remove_dropout(m._c)
11147 res = m(data)
11148 FileCheck().check_not("aten::dropout").run(str(m.graph))
Philip Meier99203582021-08-19 12:45:32 -070011149 torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3)
Kimish Patelf954dd72020-05-12 14:36:25 -070011150
Elias Ellisone1428cf2020-07-02 13:30:48 -070011151 def test_unfold_zero_dim(self):
11152 def fn(x):
11153 return x.unfold(0, 1, 1)
11154
11155 graph = torch.jit.script(fn).graph
11156 torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False)
11157 out_dims = fn(torch.tensor(0.3923)).ndim
11158 self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims)
11159
Adam Paszkea6036892018-11-26 09:18:43 -080011160 def test_mm_batching(self):
Adam Paszkea6036892018-11-26 09:18:43 -080011161
Elias Ellison0e3a05e2020-05-06 11:27:59 -070011162 with enable_profiling_mode_for_profiling_tests():
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011163 lstm_cell = torch.jit.script(LSTMCellS)
Adam Paszkea6036892018-11-26 09:18:43 -080011164
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011165 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
11166 for i in range(x.size(0)):
11167 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
11168 return hx
Adam Paszkea6036892018-11-26 09:18:43 -080011169
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011170 slstm = torch.jit.script(lstm)
Adam Paszkea6036892018-11-26 09:18:43 -080011171
Shen Li10224432021-08-12 11:39:31 -070011172 inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011173 slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True)
11174 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
11175 slstm(*inputs, profile_and_replay=True).sum().backward()
Adam Paszkea6036892018-11-26 09:18:43 -080011176
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011177 fw_graph = slstm.graph_for(*inputs)
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011178 if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
Nikolay Korovaiko78bd0062019-11-13 14:20:40 -080011179 bw_graph = backward_graph(slstm, diff_graph_idx=0)
Shen Li10224432021-08-12 11:39:31 -070011180 self.assertTrue('prim::MMBatchSide' in str(fw_graph))
11181 self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080011182
11183 sout = slstm(*inputs)
11184 out = lstm(*inputs)
Hongyu Caide27f422020-02-07 10:39:39 -080011185 self.assertEqual(sout, out)
Shen Li10224432021-08-12 11:39:31 -070011186 self.assertEqual(torch.autograd.grad(sout.sum(), inputs),
11187 torch.autograd.grad(out.sum(), inputs))
Adam Paszkea6036892018-11-26 09:18:43 -080011188
Adam Paszkef45a3d52018-06-06 09:36:12 +020011189 def test_loop_unrolling(self):
11190 def fn(x):
Richard Zou67f6f932018-08-27 08:53:56 -070011191 y = 0
Zachary DeVitoa9492452018-07-23 13:58:32 -070011192 for i in range(int(x)):
Elias Ellison7fa996f2019-03-06 11:42:19 -080011193 y -= i
Adam Paszkef45a3d52018-06-06 09:36:12 +020011194 return y
11195
James Reed0b16b032018-07-25 16:55:09 -070011196 graph = torch.jit.script(fn).graph
Shen Li10224432021-08-12 11:39:31 -070011197 self.run_pass('loop_unrolling', graph)
Elias Ellison7fa996f2019-03-06 11:42:19 -080011198 unroll_factor = 8
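        # after unrolling, the main loop body should contain unroll_factor subs,
        # followed by an epilogue loop with a single sub for the remaining iterations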
Shen Li10224432021-08-12 11:39:31 -070011199 FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
11200 .check("prim::Loop").check("aten::sub").run(str(graph))
Adam Paszkef45a3d52018-06-06 09:36:12 +020011201 self.checkScript(fn, (torch.tensor(10),))
11202
11203 def test_loop_unrolling_const(self):
11204 def fn():
Richard Zou67f6f932018-08-27 08:53:56 -070011205 y = 0
James Reedde6bb3f2019-01-26 17:39:34 -080011206 for _ in range(10):
Elias Ellison7fa996f2019-03-06 11:42:19 -080011207 y -= 1
Adam Paszkef45a3d52018-06-06 09:36:12 +020011208 return y
11209
11210 def fn2():
Richard Zou67f6f932018-08-27 08:53:56 -070011211 y = 0
Adam Paszkef45a3d52018-06-06 09:36:12 +020011212 for i in range(10):
Elias Ellison7fa996f2019-03-06 11:42:19 -080011213 y -= i
Adam Paszkef45a3d52018-06-06 09:36:12 +020011214 return y
11215
11216 def check(fn, name):
James Reed0b16b032018-07-25 16:55:09 -070011217 graph = torch.jit.script(fn).graph
Shen Li10224432021-08-12 11:39:31 -070011218 self.run_pass('loop_unrolling', graph)
Elias Ellison7fa996f2019-03-06 11:42:19 -080011219 # entirely unrolled
11220 FileCheck().check_not("prim::Loop'").run(str(graph))
Adam Paszkef45a3d52018-06-06 09:36:12 +020011221 self.checkScript(fn, ())
11222
Shen Li10224432021-08-12 11:39:31 -070011223 check(fn, 'add_const')
11224 check(fn2, 'add_iter')
Adam Paszkef45a3d52018-06-06 09:36:12 +020011225
11226 def test_loop_unrolling_nested(self):
11227 def fn(x):
Richard Zou67f6f932018-08-27 08:53:56 -070011228 y = 0
James Reedde6bb3f2019-01-26 17:39:34 -080011229 for _ in range(10):
Zachary DeVitoa9492452018-07-23 13:58:32 -070011230 for j in range(int(x)):
Elias Ellison7fa996f2019-03-06 11:42:19 -080011231 y -= j
Adam Paszkef45a3d52018-06-06 09:36:12 +020011232 return y
11233
James Reed0b16b032018-07-25 16:55:09 -070011234 graph = torch.jit.script(fn).graph
Shen Li10224432021-08-12 11:39:31 -070011235 self.run_pass('loop_unrolling', graph)
Elias Ellison7fa996f2019-03-06 11:42:19 -080011236 # inner loop with 8 subs followed by loop epilogue
11237 unroll_factor = 8
Shen Li10224432021-08-12 11:39:31 -070011238 FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
11239 .check("prim::Loop").check("aten::sub").run(str(graph))
Adam Paszkef45a3d52018-06-06 09:36:12 +020011240 self.checkScript(fn, (torch.tensor(10),))
11241
11242 def test_loop_unroll_unused_counter(self):
11243 def fn(x):
Richard Zou67f6f932018-08-27 08:53:56 -070011244 y = 0
James Reedde6bb3f2019-01-26 17:39:34 -080011245 for _ in range(int(x)):
Elias Ellison7fa996f2019-03-06 11:42:19 -080011246 y -= 1
Adam Paszkef45a3d52018-06-06 09:36:12 +020011247 return y
11248
James Reed0b16b032018-07-25 16:55:09 -070011249 graph = torch.jit.script(fn).graph
Shen Li10224432021-08-12 11:39:31 -070011250 self.run_pass('loop_unrolling', graph)
11251 FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
11252 .run(str(graph))
Adam Paszkef45a3d52018-06-06 09:36:12 +020011253
11254 def test_loop_unroll_negative(self):
11255 def fn(x):
Richard Zou67f6f932018-08-27 08:53:56 -070011256 y = 0
James Reedde6bb3f2019-01-26 17:39:34 -080011257 for _ in range(int(x)):
Adam Paszkef45a3d52018-06-06 09:36:12 +020011258 y += 1
11259 return y
11260
11261 self.checkScript(fn, (torch.tensor(-20),))
11262 self.checkScript(fn, (torch.tensor(-2),))
11263 self.checkScript(fn, (torch.tensor(-1),))
11264 self.checkScript(fn, (torch.tensor(0),))
11265 self.checkScript(fn, (torch.tensor(1),))
11266 self.checkScript(fn, (torch.tensor(2),))
11267
James Reed544605d2018-06-14 16:46:08 -070011268 def test_where(self):
11269 def fn(x, y):
11270 return torch.where(x > 0.0, x, y)
11271
Shen Li10224432021-08-12 11:39:31 -070011272 self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
James Reed544605d2018-06-14 16:46:08 -070011273
Zachary DeVitoc8ac8782018-10-16 21:03:18 -070011274 def test_where_method(self):
11275 def fn(x, y):
11276 return x.where(x > 0.0, y)
11277
Shen Li10224432021-08-12 11:39:31 -070011278 self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
Zachary DeVitoc8ac8782018-10-16 21:03:18 -070011279
Tugsbayasgalan (Tugsuu) Manlaibaatar2ea70a62021-12-02 10:50:25 -080011280 def test_union_to_number(self):
11281 @torch.jit.script
11282 def fn(x: Union[int, complex, float], y: Union[int, complex, float]):
11283 return x + y
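        # a Union over int, complex and float should be lowered to the builtin Scalar
        # type in the graph signature, which the check below looks for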
11284 FileCheck().check(": Scalar):").run(fn.graph)
11285
James Reed80292962018-06-20 14:51:53 -070011286 def test_reassign_module_lhs(self):
Shen Li10224432021-08-12 11:39:31 -070011287 with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''):
James Reed80292962018-06-20 14:51:53 -070011288 class ReassignSelfLHS(torch.jit.ScriptModule):
11289 @torch.jit.script_method
11290 def forward(self, x):
James Reedde6bb3f2019-01-26 17:39:34 -080011291 for _ in range(20):
James Reed80292962018-06-20 14:51:53 -070011292 self = x
11293 return self
11294
11295 ReassignSelfLHS()
11296
11297 def test_reassign_module_rhs(self):
Shen Li10224432021-08-12 11:39:31 -070011298 with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'):
James Reed80292962018-06-20 14:51:53 -070011299 class ReassignSelfRHS(torch.jit.ScriptModule):
11300 @torch.jit.script_method
11301 def forward(self, x):
James Reedde6bb3f2019-01-26 17:39:34 -080011302 for _ in range(20):
James Reed80292962018-06-20 14:51:53 -070011303 x = self
11304 return self
11305
11306 ReassignSelfRHS()
11307
James Reed80292962018-06-20 14:51:53 -070011308 def test_unknown_builtin(self):
Shen Li10224432021-08-12 11:39:31 -070011309 with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
James Reed80292962018-06-20 14:51:53 -070011310 @torch.jit.script
11311 def unknown_builtin(x):
11312 return x.splork(3)
11313
Adam Paszke1f134532018-07-31 14:23:33 -070011314 def test_return_tuple(self):
Zachary DeVito6ce799e2018-08-27 14:30:25 -070011315 def return_tuple(x):
11316 a = (x, x)
11317 return a, x
11318 self.checkScript(return_tuple, (torch.rand(4),))
James Reed80292962018-06-20 14:51:53 -070011319
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011320 def test_add_tuple_optional(self):
Shen Li10224432021-08-12 11:39:31 -070011321 def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011322 changed_input = input[0] + 1
Shen Li10224432021-08-12 11:39:31 -070011323 value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:]
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011324 return value[2]
Shen Li10224432021-08-12 11:39:31 -070011325 inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None)
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011326 self.checkScript(foo, (inp,))
11327
11328 def test_add_tuple_non_optional(self):
11329 def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
11330 changed_input = input[0] + 1
Shen Li10224432021-08-12 11:39:31 -070011331 value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:]
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011332 return torch.sum(value[2]) + 4
Shen Li10224432021-08-12 11:39:31 -070011333 inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4))
Tugsbayasgalan Manlaibaatar076961e2021-04-09 10:22:14 -070011334 self.checkScript(foo, (inp,))
11335
11336 def test_add_tuple_different_types(self):
11337 def foo(a: Tuple[int, float], b: Tuple[int]) -> int:
11338 c: Tuple[int, float, int] = a + b
11339 d: Tuple[int, float, int, int] = c + b
11340 return d[3] + 1
11341 a = (1, 2.0)
11342 b = (3,)
11343 self.checkScript(foo, (a, b))
11344
11345 def test_add_tuple_same_types(self):
11346 def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int:
11347 c: Tuple[int, int, int, int, int] = a + b
11348 d: Tuple[int, int, int, int, int, int, int, int] = c + b
11349 return d[6] - 2
11350 a = (1, 2)
11351 b = (3, 4, 5)
11352 self.checkScript(foo, (a, b))
11353
James Reed80292962018-06-20 14:51:53 -070011354 def test_method_no_self(self):
Shen Li10224432021-08-12 11:39:31 -070011355 with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
James Reed80292962018-06-20 14:51:53 -070011356 class MethodNoSelf(torch.jit.ScriptModule):
Edward Yangba810742019-03-21 09:06:30 -070011357 @torch.jit.script_method # noqa: B902
Michael Suo62b10722019-12-06 17:48:20 -080011358 def forward(): # noqa: B902
James Reed80292962018-06-20 14:51:53 -070011359 return torch.zeros(3, 4)
11360
11361 MethodNoSelf()
11362
11363 def test_return_stmt_not_at_end(self):
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080011364 def return_stmt(x):
11365 if bool(x > 3):
11366 return x + 3
11367 else:
11368 return x
11369 self.checkScript(return_stmt, (torch.rand(1),))
James Reed80292962018-06-20 14:51:53 -070011370
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011371 def test_for_in_range(self):
11372 def fn():
11373 c = 0
11374 for i in range(100):
11375 c += i
11376 return c
David Riazatidefd23b2019-06-25 16:17:49 -070011377 self.checkScript(fn, ())
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011378
11379 def test_for_in_range_dynamic(self):
11380 def fn():
11381 c = 0
11382 for i in range(100):
11383 acc = 0
11384 for j in range(i):
11385 acc += j
11386 c += acc
11387 return c
11388 self.checkScript(fn, (), optimize=False)
11389
11390 def test_for_in_range_ast(self):
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011391 def test_script_for_in_range_ast():
11392 c = 0
11393 for i in range(100):
11394 acc = 0
11395 for j in range(i):
11396 acc += j
11397 c += acc
11398 return c
11399
David Riazatidefd23b2019-06-25 16:17:49 -070011400 self.checkScript(test_script_for_in_range_ast, ())
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011401
11402 def test_for_in_range_if_ast(self):
11403 @torch.jit.script
11404 def test_script_for_in_range_if_ast(x):
11405 output = x
11406 for i in range(20):
11407 if i == 0:
11408 output = x.unsqueeze(0)
11409 else:
11410 output = torch.cat((output, x.unsqueeze(0)), dim=0)
11411 return output
11412 inputs = self._make_scalar_vars([0], torch.int64)
11413
11414 self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
11415
11416 def test_for_in_range_start_end(self):
11417 def fn():
11418 x = 0
11419 for i in range(7, 100):
11420 x += i
11421 return x
David Riazatidefd23b2019-06-25 16:17:49 -070011422 self.checkScript(fn, ())
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011423
11424 def test_for_in_range_start_end_step(self):
11425 def fn(start, end, step):
11426 # type: (int, int, int) -> int
11427 x = 0
11428 for i in range(start, end, step):
11429 x += i
11430 return x
11431
David Riazatidefd23b2019-06-25 16:17:49 -070011432 self.checkScript(fn, (7, 100, 7))
11433 self.checkScript(fn, (7, 100, -7))
11434 self.checkScript(fn, (2, -11, -3))
11435 self.checkScript(fn, (2, -11, 3))
11436 self.checkScript(fn, (2, 10, 3))
11437 self.checkScript(fn, (-2, -10, -10))
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011438
11439 def test_for_in_range_zero_step(self):
11440 @torch.jit.script
11441 def fn():
11442 x = 0
11443 for i in range(2, -11, 0):
11444 x += i
11445 return x
David Riazatidefd23b2019-06-25 16:17:49 -070011446
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011447 with self.assertRaisesRegex(RuntimeError, "must not be zero"):
11448 fn()
11449
Elias Ellisonff8b7ef2019-09-27 17:11:43 -070011450 def test_range_args(self):
Shen Li10224432021-08-12 11:39:31 -070011451 with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
James Reed80292962018-06-20 14:51:53 -070011452 @torch.jit.script
11453 def range_no_arg(x):
James Reedde6bb3f2019-01-26 17:39:34 -080011454 for _ in range():
James Reed80292962018-06-20 14:51:53 -070011455 x += 1
11456 return x
Shen Li10224432021-08-12 11:39:31 -070011457 with self.assertRaisesRegex(RuntimeError, r'found float'):
Elias Ellisonff8b7ef2019-09-27 17:11:43 -070011458 @torch.jit.script
11459 def range_non_float():
Shen Li10224432021-08-12 11:39:31 -070011460 for i in range(.5):
Elias Ellisonff8b7ef2019-09-27 17:11:43 -070011461 print(i)
11462
James Reed1e9ad6e2021-03-26 11:28:42 -070011463 def test_parse_empty_tuple_annotation(self):
Shen Li10224432021-08-12 11:39:31 -070011464 cu = torch.jit.CompilationUnit('''
James Reed1e9ad6e2021-03-26 11:28:42 -070011465 def foo(x : Tuple[()]) -> Tuple[()]:
11466 return x
Shen Li10224432021-08-12 11:39:31 -070011467 ''')
James Reed1e9ad6e2021-03-26 11:28:42 -070011468
Shen Li10224432021-08-12 11:39:31 -070011469 foo_code = cu.find_function('foo').code
James Reeda3c06e62021-04-12 17:21:52 -070011470 FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code)
James Reed1e9ad6e2021-03-26 11:28:42 -070011471
11472 def test_parse_empty_tuple_annotation_element_error(self):
11473 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -070011474 RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'):
11475 cu = torch.jit.CompilationUnit('''
James Reed1e9ad6e2021-03-26 11:28:42 -070011476 def foo(x : Tuple[(int,)]) -> Tuple[(int,)]:
11477 return x
Shen Li10224432021-08-12 11:39:31 -070011478 ''')
James Reed1e9ad6e2021-03-26 11:28:42 -070011479
James Reed3db23332021-03-26 11:28:42 -070011480 def test_parse_none_type_annotation(self):
Shen Li10224432021-08-12 11:39:31 -070011481 cu = torch.jit.CompilationUnit('''
James Reed3db23332021-03-26 11:28:42 -070011482 def foo(x : NoneType) -> NoneType:
11483 return x
Shen Li10224432021-08-12 11:39:31 -070011484 ''')
James Reed3db23332021-03-26 11:28:42 -070011485
Shen Li10224432021-08-12 11:39:31 -070011486 foo_code = cu.find_function('foo').code
James Reed68e07962021-04-12 17:33:20 -070011487 FileCheck().check(": NoneType").check("-> NoneType").run(foo_code)
James Reed3db23332021-03-26 11:28:42 -070011488
James Reeda3c06e62021-04-12 17:21:52 -070011489 def test_empty_tuple_str(self):
11490 empty_tuple_type = torch._C.TupleType([])
Shen Li10224432021-08-12 11:39:31 -070011491 g = {'Tuple' : typing.Tuple}
James Reeda3c06e62021-04-12 17:21:52 -070011492 python_type = eval(empty_tuple_type.annotation_str, g)
11493 assert python_type is typing.Tuple[()]
11494
James Reed68e07962021-04-12 17:33:20 -070011495 def test_none_type_str(self):
11496 none_type = torch._C.NoneType.get()
Shen Li10224432021-08-12 11:39:31 -070011497 g = {'NoneType' : type(None)}
James Reed68e07962021-04-12 17:33:20 -070011498 python_type = eval(none_type.annotation_str, g)
11499 assert python_type is type(None)
11500
Animesh Jain1d90d6e2022-07-07 18:57:31 +000011501 @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
Elias Ellisonfdeef452019-11-04 09:18:09 -080011502 def test_zip_enumerate_modulelist(self):
11503 class Sub(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011504 def forward(self, thing):
11505 return thing - 2
11506
11507 class Double(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011508 def forward(self, thing):
11509 return thing * 2
11510
11511            # zipping over two module lists
11512 class ZipModLists(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011513 def __init__(self, mods, mods2):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011514 super().__init__()
Elias Ellisonfdeef452019-11-04 09:18:09 -080011515 self.mods = mods
11516 self.mods2 = mods2
11517
11518 def forward(self, x):
11519 iter = 0
11520 for mod1, mod2 in zip(self.mods, self.mods2):
11521 x = mod2(mod1(x))
11522 iter += 1
11523 return x, iter
11524
11525 class ZipWithValues(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -070011526 __constants__ = ['tup_larger', 'tup_smaller']
Elias Ellisonfdeef452019-11-04 09:18:09 -080011527
11528 def __init__(self, mods, mods2):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011529 super().__init__()
Elias Ellisonfdeef452019-11-04 09:18:09 -080011530 self.mods = mods
11531 self.mods2 = mods2
11532 self.tup_larger = list(range(len(mods2) + 1))
11533 self.tup_smaller = list(range(max(len(mods2) + 1, 1)))
11534
11535 def forward(self, x):
11536 iter = 0
11537 x2 = x
11538 for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2):
11539 x = mod2(mod1(x)) + val
11540 iter += 1
11541 for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2):
11542 x2 = mod2(mod1(x2)) + val
11543 iter += 1
11544 return x, iter
11545
Shen Li10224432021-08-12 11:39:31 -070011546 mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()])
Elias Ellisonfdeef452019-11-04 09:18:09 -080011547 for i in range(len(mods)):
11548 for j in range(len(mods)):
11549 mod = ZipModLists(mods[i], mods[j])
Shen Li10224432021-08-12 11:39:31 -070011550 self.checkModule(mod, (torch.tensor(.5),))
Elias Ellisonfdeef452019-11-04 09:18:09 -080011551 mod2 = ZipWithValues(mods[i], mods[j])
Shen Li10224432021-08-12 11:39:31 -070011552 self.checkModule(mod2, (torch.tensor(.5),))
11553
Elias Ellisonfdeef452019-11-04 09:18:09 -080011554
11555 def test_enumerate_modlist_range(self):
11556 class Double(torch.nn.Module):
11557 def forward(self, thing):
11558 return thing * 2
11559
11560 class Mod(torch.nn.Module):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011561 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011562 super().__init__()
Elias Ellisonfdeef452019-11-04 09:18:09 -080011563 self.mods = nn.ModuleList([Double(), Double()])
11564
11565 def forward(self, x):
11566 x2 = x
11567 iter = 0
11568 for val, mod in enumerate(self.mods):
11569 x2 = mod(x2) * val
11570 iter += 1
11571 return iter, x, x2
11572
Shen Li10224432021-08-12 11:39:31 -070011573 self.checkModule(Mod(), (torch.tensor(.5),))
Elias Ellisonfdeef452019-11-04 09:18:09 -080011574
ettiee2cf576e2020-03-11 10:58:25 -070011575 # variable length, modulelist
Elias Ellisonfdeef452019-11-04 09:18:09 -080011576 class Mod2(Mod):
11577 def forward(self, x):
11578 for val, mod in zip(range(int(x)), self.mods):
11579 x = mod(x) * val
11580 return x
11581
Shen Li10224432021-08-12 11:39:31 -070011582 with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011583 torch.jit.script(Mod2())
11584
11585 # modulelist, variable length
11586 class Mod3(Mod):
11587 def forward(self, x):
11588 for val, mod in zip(self.mods, range(int(x))):
11589 x = mod(x) * val
11590 return x
11591
Shen Li10224432021-08-12 11:39:31 -070011592 with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
Elias Ellisonfdeef452019-11-04 09:18:09 -080011593 torch.jit.script(Mod3())
James Reed80292962018-06-20 14:51:53 -070011594
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011595 def test_for_in_enumerate(self):
11596 def fn(x):
11597 # type: (List[int]) -> int
11598 sum = 0
11599 for (i, v) in enumerate(x):
11600 sum += i * v
11601
11602 return sum
11603
11604 self.checkScript(fn, ([1, 2, 3, 4, 5],))
11605
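        # enumerate() also accepts an explicit start value, either positionally
        # or as the `start` keyword argument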
Yu Guo4c04f6d2022-06-29 19:37:51 -070011606 def fn_enumerate_start_arg(x):
11607 # type: (List[int]) -> int
11608 sum = 0
11609 for (i, v) in enumerate(x, 1):
11610 sum += i * v
11611
11612 return sum
11613
11614 self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],))
11615
11616 def fn_enumerate_start_kwarg(x):
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011617 # type: (List[int]) -> int
11618 sum = 0
11619 for (i, v) in enumerate(x, start=1):
11620 sum += i * v
11621
11622 return sum
11623
Yu Guo4c04f6d2022-06-29 19:37:51 -070011624 self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],))
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011625
11626 def fn_nested_enumerate(x):
11627 # type: (List[int]) -> int
11628 sum = 0
11629 for (i, (j, v)) in enumerate(enumerate(x)):
11630 sum += i * j * v
11631
11632 return sum
11633
Yu Guo4c04f6d2022-06-29 19:37:51 -070011634 self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],))
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011635
Shen Li10224432021-08-12 11:39:31 -070011636 with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'):
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011637 @torch.jit.script
11638 def enumerate_no_arg(x):
11639 # type: (List[int]) -> int
11640 sum = 0
11641 for _ in enumerate():
11642 sum += 1
11643
11644 return sum
11645
Shen Li10224432021-08-12 11:39:31 -070011646 with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'):
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011647 @torch.jit.script
11648 def enumerate_too_many_args(x):
11649 # type: (List[int]) -> int
11650 sum = 0
11651 for _ in enumerate(x, x, x):
11652 sum += 1
11653
11654 return sum
11655
Elias Ellison60cb56d2019-11-04 14:37:23 -080011656 def test_list_comprehension_modulelist(self):
11657 class Inner(torch.nn.Module):
11658 def forward(self, x):
11659 return x + 10
11660
11661 class M(torch.nn.Module):
Elias Ellison60cb56d2019-11-04 14:37:23 -080011662 def __init__(self, mod_list):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011663 super().__init__()
Elias Ellison60cb56d2019-11-04 14:37:23 -080011664 self.module_list = mod_list
11665
11666 def forward(self, x):
Shen Li10224432021-08-12 11:39:31 -070011667 out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list])
Elias Ellison60cb56d2019-11-04 14:37:23 -080011668 return out
11669
11670 mod = M(nn.ModuleList([Inner(), Inner()]))
11671 self.checkModule(mod, (torch.tensor(3),))
11672
11673 mod = M(nn.ModuleList([]))
11674 torch.jit.script(mod)
11675
11676 class M2(M):
11677 def __init__(self, mod_list):
Xuehai Pan046e88a2023-02-12 22:20:50 +000011678 super().__init__(mod_list)
Elias Ellison60cb56d2019-11-04 14:37:23 -080011679
11680 def forward(self, x):
11681 out = [mod(x) for mod in self.module_list]
11682 return out
11683
11684 mod = M2(nn.ModuleList([Inner(), Inner()]))
11685 self.checkModule(mod, (torch.tensor(3),))
11686
11687 mod = M2(nn.ModuleList([]))
11688 # defaults to List of Tensor for empty modulelist
Shen Li10224432021-08-12 11:39:31 -070011689 self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), [])
Elias Ellison60cb56d2019-11-04 14:37:23 -080011690
11691 def bad_type_annotation():
Hong Xua6a72ac2020-02-21 08:29:32 -080011692 out = torch.jit.annotate(int, [x for x in [1, 2, 3]]) # noqa: C416
Elias Ellison60cb56d2019-11-04 14:37:23 -080011693 return out
11694
Ansley Usseryc60075d2021-09-10 16:18:33 -070011695 with self.assertRaisesRegex(Exception, "Expected an annotation"
11696 " of type List"):
Elias Ellison60cb56d2019-11-04 14:37:23 -080011697 torch.jit.script(bad_type_annotation)
11698
Elias Ellison9ada7ab2020-04-08 09:57:38 -070011699 def test_list_comprehension_variable_write(self):
11700 # i in comprehension doesn't write to function scope
11701 def foo():
11702 i = 1
11703 x = [i if i != 5 else 3 for i in range(7)] # noqa: C416
11704 return i, x
11705
11706 self.assertEqual(foo(), torch.jit.script(foo)())
11707
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011708 def test_for_in_zip(self):
11709 def fn(x, y):
11710 # type: (List[int], List[int]) -> int
11711 sum = 0
11712 for (i, j) in zip(x, y):
11713 sum += i * j
11714
11715 return sum
11716
11717 self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]))
11718
11719 def fn_multi_inputs(x, y, z):
11720 # type: (List[int], List[int], List[int]) -> int
11721 sum = 0
11722 for (i, j, k) in zip(x, y, z):
11723 sum += i * j * k
11724
11725 return sum
11726
11727 self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11728
11729 def fn_nested_zip(x, y, z):
11730 # type: (List[int], List[int], List[int]) -> int
11731 sum = 0
11732 for (i, (j, k)) in zip(x, zip(y, z)):
David Riazatidefd23b2019-06-25 16:17:49 -070011733 sum += i * j * k
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011734
11735 return sum
11736
11737            self.checkScript(fn_nested_zip, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11738
Shen Li10224432021-08-12 11:39:31 -070011739 with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'):
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011740 @torch.jit.script
11741 def zip_no_arg(x):
11742 # type: (List[int]) -> int
11743 sum = 0
11744 for _ in zip():
11745 sum += 1
11746
11747 return sum
11748
Shen Li10224432021-08-12 11:39:31 -070011749 with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'):
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011750 @torch.jit.script
11751 def fn_nested_zip_wrong_target_assign(x, y, z):
11752 # type: (List[int], List[int], List[int]) -> int
11753 sum = 0
11754 for (i, (j, k)) in zip(x, y, z):
David Riazatidefd23b2019-06-25 16:17:49 -070011755 sum += i * j * k
Wanchao Liange0f5ab22019-06-22 00:57:24 -070011756
11757 return sum
11758
11759 def test_for_in_zip_enumerate(self):
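        # zip() and enumerate() can be nested and unpacked in a single loop target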
11760 def fn_zip_enumerate(x, y):
11761 # type: (List[int], List[int]) -> int
11762 sum = 0
11763 for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)):
11764 sum += i * j * v * k
11765
11766 return sum
11767
11768 self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5]))
11769
11770 def fn_enumerate_zip(x, y):
11771 # type: (List[int], List[int]) -> int
11772 sum = 0
11773 for (i, (j, v)) in enumerate(zip(x, y)):
11774 sum += i * j * v
11775
11776 return sum
11777
11778 self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
11779
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011780 def test_for_in_tensors(self):
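        # iterating over a tensor yields slices along its first dimension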
11781 def test_sizes(x):
11782 sumz = 0
11783 for s in x:
11784 sumz += 1
11785 return sumz
11786 self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11787 self.checkScript(test_sizes, (torch.rand(777),))
11788 self.checkScript(test_sizes, (torch.rand(0),))
11789
11790 def test_for_in_tensors_rank0(self):
11791 with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
11792 @torch.jit.script
11793 def test_sizes(x):
11794 sumz = 0
11795 for s in x:
11796 sumz += 1
11797 return sumz
11798
11799 test_sizes(torch.tensor(1))
11800
11801 def test_for_in_tensors_fail_scalar(self):
11802 with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
11803 @torch.jit.script
11804 def test_sizes(x):
11805 # type: (float) -> int
11806 sumz = 0
Sam Estepe3900d22021-04-19 13:14:27 -070011807 for s in x:
Wanchao Liang45b91bd2019-06-22 01:34:56 -070011808 sumz += 1
11809 return sumz
11810
11811 test_sizes(0.0)
11812
11813 def test_for_in_tensors_nested(self):
11814 def test_sizes(x):
11815 sumz = 0
11816 for n in x:
11817 for t in n:
11818 sumz += 1
11819 return sumz
11820
11821 self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11822
Yanan Caofa29a642021-04-11 14:20:09 -070011823 # to avoid defining sum_list in multiple tests
11824 def get_sum_list_fn(self):
11825 def sum_list(a):
11826 # type: (List[int]) -> int
11827 sum = 0
11828 for i in a:
11829 sum += i
11830
11831 return sum
11832
11833 return sum_list
11834
11835 def test_sum_list_diff_elms(self):
11836 self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
11837
11838 def test_sum_list_empty(self):
11839 self.checkScript(self.get_sum_list_fn(), ([],))
11840
11841 def test_sum_list_one(self):
11842 self.checkScript(self.get_sum_list_fn(), ([1],))
11843
11844 def test_sum_list_literal(self):
Shen Li10224432021-08-12 11:39:31 -070011845
Yanan Caofa29a642021-04-11 14:20:09 -070011846 def sum_list():
11847 # type: () -> int
11848 sum = 0
11849 for i in [1, 2, 3, 4, 5]:
11850 sum += i
11851
11852 return sum
11853
11854 self.checkScript(sum_list, ())
11855
11856 def test_sum_list_wrong_type(self):
11857
11858 with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
11859 @torch.jit.script
11860 def sum_list(a):
11861 # type: (int) -> int
11862 sum = 0
11863 for i in a: # noqa: T484
11864 sum += i
11865
11866 return sum
11867
11868 sum_list(1)
11869
11870 def test_list_iterables(self):
Shen Li10224432021-08-12 11:39:31 -070011871 with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
11872 cu = torch.jit.CompilationUnit('''
Yanan Caofa29a642021-04-11 14:20:09 -070011873 def list_iterables(x):
11874 for i, j in [2, 3, 4], [5, 6, 7]:
11875 x += i
11876 x += j
11877 return x
Shen Li10224432021-08-12 11:39:31 -070011878 ''')
Yanan Caofa29a642021-04-11 14:20:09 -070011879
11880 def test_for_in_string(self):
11881 def test_strings(x):
11882 # type: (str) -> str
11883 reverse = ""
11884 for c in x:
11885 reverse = c + reverse
11886 return reverse
11887
11888 self.checkScript(test_strings, ("hello",))
11889 self.checkScript(test_strings, ("",))
11890
11891 def test_list_strings(x):
11892 # type: (List[str]) -> str
11893 result = ""
11894 for sub_str in x:
11895 result += sub_str
11896 return result
11897
11898 self.checkScript(test_list_strings, (["hello", "world"],))
11899 self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
11900
11901 def test_for_in_dict(self):
11902 def test_dicts(x):
11903 # type: (Dict[str, int]) -> int
11904 sum = 0
11905 for key in x:
11906 sum += x[key]
11907 return sum
11908
11909 self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
11910
11911 def test_dict_keys_values(x):
11912 # type: (Dict[str, int]) -> Tuple[str, int]
11913 key_str = ""
11914 sum = 0
11915 for key in x.keys():
11916 key_str += key
11917 for val in x.values():
11918 sum += val
11919 return key_str, sum
11920
11921            self.checkScript(test_dict_keys_values, ({"a": 1, "b": 2, "c": 3},))
11922
11923 def test_for_tuple_unpack(self):
11924 def for_tuple_unpack(x, y):
11925 for i, j in [[3, 4], [5, 6], [7, 8]]:
11926 x += i
11927 y += j
11928 return x, y
11929
11930 self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))
11931
11932 def nested_tuple_unpack(x, y):
11933 # type: (List[int], List[int]) -> int
11934 sum = 0
11935 for i, (j, k), v in zip(x, enumerate(x), y):
11936 sum += i + j + k + v
11937 return sum
11938
11939 self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
11940
11941 def test_for_tuple_assign(self):
11942 def test_simple_assign(x):
11943 # type: (Tuple[int, float]) -> float
11944 sum = 0.0
11945 for a in x:
11946 sum += float(a)
11947 return sum
11948
11949 self.checkScript(test_simple_assign, ((1, 2.5),))
11950
11951 def test_tuple_assign(x):
11952 # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
11953 sum = 0
11954 for a in x:
11955 sum += a[0]
11956 sum += a[1]
11957 return sum
11958
Shen Li10224432021-08-12 11:39:31 -070011959 self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
Yanan Caofa29a642021-04-11 14:20:09 -070011960
11961 def test_single_starred_lhs(self):
Shen Li10224432021-08-12 11:39:31 -070011962 with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
11963 ' of another non-starred expression'):
11964 cu = torch.jit.CompilationUnit('''
Yanan Caofa29a642021-04-11 14:20:09 -070011965 def single_starred_lhs(x):
11966 a = (x, x, x)
11967 *b, = a
11968 return b
Shen Li10224432021-08-12 11:39:31 -070011969 ''')
Yanan Caofa29a642021-04-11 14:20:09 -070011970
11971 def test_singleton_tuple_unpack(self):
11972 def foo(a):
Shen Li10224432021-08-12 11:39:31 -070011973 b, = (a,)
Yanan Caofa29a642021-04-11 14:20:09 -070011974 return b + 1
11975 self.checkScript(foo, (torch.rand(3),))
11976
11977 def test_tuple_assignments(self):
11978 def var_tuple_assign(x, y):
11979 # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
11980 (a, b), c = x, y
11981 return a + b + c
11982
11983 tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
11984 self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))
11985
11986 def nested_tuple_assign(x, y, z):
11987 # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
11988 a, (b, (c, d)), (e, f) = x, y, z
11989 return a + b + c + d + e + f
11990
11991 self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))
11992
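        # assignment targets may mix subscripts with nested tuple unpacking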
11993 def subscript_tuple_assign(a, x, i):
11994 # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
11995 a[i], (x[i], b) = 1, (2, 3)
11996 return a[i] + 1, x + 5, b
11997
Shen Li10224432021-08-12 11:39:31 -070011998 self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
Yanan Caofa29a642021-04-11 14:20:09 -070011999
12000 def star_tuple_assign():
12001 # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
12002 a, (b, *c), *d = 1, (2, 3, 4), 5, 6
12003 return a, b, c, d
12004
12005 self.checkScript(star_tuple_assign, ())
12006
12007 def subscript_tuple_augmented_assign(a):
12008 # type: (Tuple[int, int]) -> Tuple[int, int]
12009 a[0] += 1
12010 return a
12011
Shen Li10224432021-08-12 11:39:31 -070012012 with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
Yanan Caofa29a642021-04-11 14:20:09 -070012013 scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
12014
12015 class AttrTupleAssignmentTestClass:
12016 def __init__(self, a: int, b: int):
12017 self.a = a
12018 self.b = b
12019
12020 def set_ab(self, a: int, b: int):
12021 self.a, self.b = (a, b)
12022
12023 def get(self) -> Tuple[int, int]:
12024 return (self.a, self.b)
12025
12026 make_global(AttrTupleAssignmentTestClass)
12027
12028 @torch.jit.script
12029 def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int):
12030 o.set_ab(a, b)
12031 return o
12032
12033 o = AttrTupleAssignmentTestClass(1, 2)
12034 self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4))
12035
12036 def test_multiple_assign(self):
12037 def test():
12038 a = b, c = d, f = (1, 1)
12039
12040 # side effect
12041 ten = torch.tensor(1)
12042 ten1 = ten2 = ten.add_(1)
12043
12044 # ordering
12045 x = 1
12046 y = 3
12047 x, y = y, x + y
12048
12049 return a, b, c, d, f, ten, ten1, ten2, x, y
12050
12051 self.checkScript(test, ())
12052
James Reed80292962018-06-20 14:51:53 -070012053 def test_multi_reduction(self):
Michael Suo5fbaf0e2018-11-01 23:58:35 -070012054 with self.assertRaisesRegex(
Shen Li10224432021-08-12 11:39:31 -070012055 RuntimeError,
12056 'augmented assignment can only have one LHS expression'):
12057 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012058 def multi_reduction(x):
12059 a, b += x
12060 return a, b
Shen Li10224432021-08-12 11:39:31 -070012061 ''')
James Reed80292962018-06-20 14:51:53 -070012062
12063 def test_invalid_call_arguments(self):
Shen Li10224432021-08-12 11:39:31 -070012064 with self.assertRaisesRegex(RuntimeError, 'but instead found type '):
James Reed80292962018-06-20 14:51:53 -070012065 @torch.jit.script
12066 def invalid_call_arguments(x):
12067 return torch.unsqueeze(3, 4, 5, 6, 7, 8)
12068
12069 def test_invalid_lhs_assignment(self):
Shen Li10224432021-08-12 11:39:31 -070012070 with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
12071 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012072 def invalid_lhs_assignment(x):
12073 x + 1 = x
12074 return x
Shen Li10224432021-08-12 11:39:31 -070012075 ''')
James Reed80292962018-06-20 14:51:53 -070012076
12077 def test_multi_starred_expr_lhs(self):
Shen Li10224432021-08-12 11:39:31 -070012078 with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
12079 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012080 def multi_starred_expr_lhs():
12081 a, *b, *c = [1, 2, 3, 4, 5, 6]
12082 return a
Shen Li10224432021-08-12 11:39:31 -070012083 ''')
James Reed80292962018-06-20 14:51:53 -070012084
12085 def test_pack_tuple_into_non_var(self):
Shen Li10224432021-08-12 11:39:31 -070012086 with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
12087 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012088 def pack_tuple_into_non_var(x):
12089 a, *1 = (3, 4, 5)
12090 return x
Shen Li10224432021-08-12 11:39:31 -070012091 ''')
James Reed80292962018-06-20 14:51:53 -070012092
12093 def test_print_kwargs(self):
Shen Li10224432021-08-12 11:39:31 -070012094 with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
12095 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012096 def print_kwargs(x):
12097 print(x, flush=True)
12098 return x
Shen Li10224432021-08-12 11:39:31 -070012099 ''')
James Reed80292962018-06-20 14:51:53 -070012100
12101 def test_builtin_use_as_value(self):
Shen Li10224432021-08-12 11:39:31 -070012102 with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
James Reed80292962018-06-20 14:51:53 -070012103 @torch.jit.script
12104 def builtin_use_as_value(x):
12105 return x.unsqueeze
12106
12107 def test_wrong_use_as_tuple(self):
Shen Li10224432021-08-12 11:39:31 -070012108 with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
James Reed80292962018-06-20 14:51:53 -070012109 def test_fn():
12110 return 3
12111
12112 @torch.jit.script
12113 def wrong_use_as_tuple(self):
12114 a, b = test_fn
12115 return a
12116
12117 def test_wrong_attr_lookup(self):
Shen Li10224432021-08-12 11:39:31 -070012118 with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
James Reed80292962018-06-20 14:51:53 -070012119 @torch.jit.script
12120 def wrong_attr_lookup(self, x):
12121 a = x.unsqueeze.myattr
12122 return a
12123
12124 def test_wrong_use_as_callable(self):
Shen Li10224432021-08-12 11:39:31 -070012125 with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
James Reed80292962018-06-20 14:51:53 -070012126 @torch.jit.script
12127 def wrong_use_as_callable(x):
12128 return x(3, 4, 5)
12129
James Reed80292962018-06-20 14:51:53 -070012130 def test_python_val_doesnt_have_attr(self):
Shen Li10224432021-08-12 11:39:31 -070012131 with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
James Reed80292962018-06-20 14:51:53 -070012132
12133 @torch.jit.script
12134 def python_val_doesnt_have_attr():
Zachary DeVito22c9bc32018-08-28 11:19:39 -070012135 # this has to be a module otherwise attr lookup would not be
12136 # allowed in the first place
12137 return shutil.abcd
James Reed80292962018-06-20 14:51:53 -070012138
12139 def test_wrong_module_attr_lookup(self):
Shen Li10224432021-08-12 11:39:31 -070012140 with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'):
James Reed80292962018-06-20 14:51:53 -070012141 import io
12142
12143 @torch.jit.script
12144 def wrong_module_attr_lookup():
12145 return io.BytesIO
12146
12147 def test_wrong_method_call_inputs(self):
Shen Li10224432021-08-12 11:39:31 -070012148 with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'):
Zsolt Dollensteinb0043072021-08-12 10:56:55 -070012149 class SomeModule(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070012150
James Reed80292962018-06-20 14:51:53 -070012151 @torch.jit.script_method
12152 def foo(self, x, y):
12153 return x
12154
12155 @torch.jit.script_method
12156 def forward(self, x, y):
12157 return self.foo(x)
12158 SomeModule()
12159
12160 def test_single_starred_expr_for_loop(self):
Shen Li10224432021-08-12 11:39:31 -070012161 with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'):
12162 cu = torch.jit.CompilationUnit('''
James Reed80292962018-06-20 14:51:53 -070012163 def test():
12164 x = 0
12165 for *a in [1, 2, 3]:
12166 x = x + 1
12167 return x
Shen Li10224432021-08-12 11:39:31 -070012168 ''')
James Reed80292962018-06-20 14:51:53 -070012169
James Reed80292962018-06-20 14:51:53 -070012170 def test_call_ge(self):
Shen Li10224432021-08-12 11:39:31 -070012171 with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'):
Zachary DeVito93bd2912018-08-30 13:51:45 -070012172 @_trace(torch.zeros(1, 2, 3))
James Reed80292962018-06-20 14:51:53 -070012173 def foo(x):
12174 return x
12175
12176 @torch.jit.script
12177 def test_fn():
Richard Zou8489c4c2018-06-21 15:43:38 -040012178 return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
James Reed80292962018-06-20 14:51:53 -070012179
12180 def test_wrong_return_type(self):
Shen Li10224432021-08-12 11:39:31 -070012181 with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
davidriazati7a370db2019-07-16 12:50:02 -070012182 @torch.jit.ignore
James Reed80292962018-06-20 14:51:53 -070012183 def somefunc():
12184 # type: () -> Tuple[Tuple[Tensor, Tensor]]
Elias Ellison561037a2019-03-07 09:12:35 -080012185 return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
James Reed80292962018-06-20 14:51:53 -070012186
12187 @torch.jit.script
12188 def wrong_return_type():
12189 return somefunc()
Zachary DeVito289a8c92018-09-11 15:01:48 -070012190 wrong_return_type()
James Reed80292962018-06-20 14:51:53 -070012191
James Reedc8cc2462018-06-21 10:38:03 -070012192 # Tests for calling between different front-end modes
12193 def test_call_python_fn_from_tracing_fn(self):
12194 def python_fn(x):
12195 return torch.neg(x)
12196
Zachary DeVito93bd2912018-08-30 13:51:45 -070012197 @_trace(torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012198 def traced_fn(x):
12199 return python_fn(x) + 1
12200
12201 # The neg op in the python function should be properly inlined to the
12202 # graph
eellisonbd7fcce2019-03-06 13:41:13 -080012203 FileCheck().check("aten::neg").run(str(traced_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012204
12205 def test_call_python_mod_from_tracing_fn(self):
12206 class PythonMod(torch.nn.Module):
12207 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012208 super().__init__()
Will Feng1aa90192019-02-07 11:58:50 -080012209 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
James Reedc8cc2462018-06-21 10:38:03 -070012210
12211 def forward(self, x):
12212 return torch.mm(x, self.param)
12213
12214 pm = PythonMod()
12215
Zachary DeVito93bd2912018-08-30 13:51:45 -070012216 @_trace(torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012217 def traced_fn(x):
James Reedbeeec472018-08-28 20:21:21 -070012218 return pm(x) + 1.0
James Reedc8cc2462018-06-21 10:38:03 -070012219
12220 # Note: the parameter self.param from the Python module is inlined
12221 # into the graph
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012222 self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
12223 FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012224
Michael Suo60f6cc92019-08-30 01:28:28 -070012225 @_tmp_donotuse_dont_inline_everything
James Reedc8cc2462018-06-21 10:38:03 -070012226 def test_call_traced_fn_from_tracing_fn(self):
Zachary DeVito93bd2912018-08-30 13:51:45 -070012227 @_trace(torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012228 def traced_fn1(x):
12229 return torch.neg(x)
12230
Zachary DeVito93bd2912018-08-30 13:51:45 -070012231 @_trace(torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012232 def traced_fn(x):
12233 return traced_fn1(x) + 1
12234
Shen Li10224432021-08-12 11:39:31 -070012235 FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \
12236 .run(str(traced_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012237
Zachary DeVito8c57ce82019-06-12 17:09:29 -070012238 @unittest.skip("error in first class mode")
James Reedc8cc2462018-06-21 10:38:03 -070012239 def test_call_traced_mod_from_tracing_fn(self):
12240 class TracedModule(torch.nn.Module):
12241 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012242 super().__init__()
Will Feng1aa90192019-02-07 11:58:50 -080012243 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
James Reedc8cc2462018-06-21 10:38:03 -070012244
12245 def forward(self, x):
12246 return torch.mm(x, self.param)
12247
Zachary DeVito93bd2912018-08-30 13:51:45 -070012248 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012249
Zachary DeVito8c57ce82019-06-12 17:09:29 -070012250 with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
12251 @_trace(torch.rand(3, 4))
12252 def traced_fn(x):
12253 return tm(x) + 1.0
James Reedc8cc2462018-06-21 10:38:03 -070012254
Michael Suo60f6cc92019-08-30 01:28:28 -070012255 @_tmp_donotuse_dont_inline_everything
James Reedc8cc2462018-06-21 10:38:03 -070012256 def test_call_script_fn_from_tracing_fn(self):
12257 @torch.jit.script
12258 def script_fn(x):
12259 return torch.neg(x)
12260
Zachary DeVito93bd2912018-08-30 13:51:45 -070012261 @_trace(torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012262 def traced_fn(x):
12263 return script_fn(x) + 1
12264
Shen Li10224432021-08-12 11:39:31 -070012265 FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012266
Zachary DeVito8c57ce82019-06-12 17:09:29 -070012267 @unittest.skip("error in first class mode")
James Reedc8cc2462018-06-21 10:38:03 -070012268 def test_call_script_mod_from_tracing_fn(self):
Zachary DeVito8c57ce82019-06-12 17:09:29 -070012269 with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012270 class ScriptMod(torch.jit.ScriptModule):
12271 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012272 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070012273 self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
James Reedc8cc2462018-06-21 10:38:03 -070012274
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012275 @torch.jit.script_method
12276 def forward(self, x):
12277 for _i in range(4):
12278 x += self.param
12279 return x
James Reedc8cc2462018-06-21 10:38:03 -070012280
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012281 sm = ScriptMod()
James Reedc8cc2462018-06-21 10:38:03 -070012282
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012283 @_trace(torch.rand(3, 4))
12284 def traced_fn(x):
12285 return sm(x) + 1.0
James Reedc8cc2462018-06-21 10:38:03 -070012286
Shen Li10224432021-08-12 11:39:31 -070012287
James Reedc8cc2462018-06-21 10:38:03 -070012288 def test_call_python_fn_from_traced_module(self):
12289 def python_fn(x):
12290 return torch.neg(x)
12291
12292 class TracedModule(torch.nn.Module):
12293 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012294 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012295 self.param = torch.nn.Parameter(torch.rand(4, 3))
12296
12297 def forward(self, x):
12298 return torch.mm(python_fn(x), self.param)
12299
Zachary DeVito93bd2912018-08-30 13:51:45 -070012300 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012301
12302 # Note: parameter self.param from the traced module should appear as
12303 # an input to the graph and the neg op from the Python function should
12304 # be properly inlined
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012305 self.assertTrue(len(list(tm.graph.inputs())) == 2)
12306 FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012307
12308 def test_call_python_mod_from_traced_module(self):
12309 class PythonModule(torch.nn.Module):
12310 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012311 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012312 self.param = torch.nn.Parameter(torch.rand(5, 7))
12313
12314 def forward(self, x):
12315 return torch.mm(x, self.param)
12316
12317 class TracedModule(torch.nn.Module):
12318 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012319 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012320 self.param = torch.nn.Parameter(torch.rand(4, 5))
12321 self.mod = PythonModule()
12322
12323 def forward(self, x):
James Reedbeeec472018-08-28 20:21:21 -070012324 return self.mod(torch.mm(x, self.param)) + 1.0
James Reedc8cc2462018-06-21 10:38:03 -070012325
Zachary DeVito93bd2912018-08-30 13:51:45 -070012326 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
James Reedc8cc2462018-06-21 10:38:03 -070012327
Shen Li10224432021-08-12 11:39:31 -070012328 FileCheck().check_not("value=<Tensor>").check("aten::mm")\
12329 .check("prim::CallMethod[name=\"forward\"]").check("aten::add") \
12330 .run(str(tm.graph))
James Reed309b28e2019-11-06 15:02:56 -080012331 FileCheck().check("aten::mm").run(str(tm.mod.graph))
12332
eellisone01fc562019-11-11 11:24:43 -080012333 def test_op_dtype(self):
Shen Li10224432021-08-12 11:39:31 -070012334
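        # torch.arange infers its dtype from the argument types (int vs. float);
        # the scripted function should match eager mode in both value and dtype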
eellisone01fc562019-11-11 11:24:43 -080012335 def check_equal_and_dtype(a, b):
12336 self.assertEqual(a, b)
12337 self.assertEqual(a.dtype, b.dtype)
12338
12339 def fn():
12340 a = torch.arange(10)
12341 b = torch.arange(10, dtype=torch.float)
12342 c = torch.arange(1, 10, 2)
12343 d = torch.arange(1, 10, 2, dtype=torch.float)
Shen Li10224432021-08-12 11:39:31 -070012344 e = torch.arange(1, 10., 2)
12345 f = torch.arange(1, 10., 2, dtype=torch.float)
eellisone01fc562019-11-11 11:24:43 -080012346 return a, b, c, d, e, f
12347
12348 scripted_fn = torch.jit.script(fn)
12349 eager_out = fn()
12350 script_out = scripted_fn()
12351 for a, b in zip(eager_out, script_out):
12352 check_equal_and_dtype(a, b)
12353
Mike Ruberry089203f2022-05-29 21:28:45 +000012354 def test_floor_div(self):
12355 @torch.jit.script
12356 def foo(a, b):
12357 # type: (int, int) -> int
12358 return a // b
12359 for i in range(-8, 8):
12360 for j in range(-8, 8):
12361 if j != 0:
12362 self.assertEqual(foo(i, j), i // j)
12363
Elias Ellisonf48a8902019-12-10 07:49:11 -080012364 def test_floordiv(self):
Shen Li10224432021-08-12 11:39:31 -070012365 funcs_template = dedent('''
Elias Ellisonf48a8902019-12-10 07:49:11 -080012366 def fn():
12367 ten = {a_construct}
12368 ten_or_scalar = {b_construct}
12369 return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar)
Shen Li10224432021-08-12 11:39:31 -070012370 ''')
Elias Ellisonf48a8902019-12-10 07:49:11 -080012371
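        # each (a_construct, b_construct) pair below is substituted into the
        # template; e.g. with "torch.tensor([3, 2])" and "2" it expands to:
        #   def fn():
        #       ten = torch.tensor([3, 2])
        #       ten_or_scalar = 2
        #       return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar)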
Shen Li10224432021-08-12 11:39:31 -070012372 lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"]
Elias Ellisonf48a8902019-12-10 07:49:11 -080012373 rhs = ["1.5", "2", "4", "1.1"] + lhs
12374 for tensor in lhs:
12375 for tensor_or_scalar in rhs:
Shen Li10224432021-08-12 11:39:31 -070012376 funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar)
Elias Ellisonf48a8902019-12-10 07:49:11 -080012377 scope = {}
12378 execWrapper(funcs_str, globals(), scope)
12379 cu = torch.jit.CompilationUnit(funcs_str)
12380 f_script = cu.fn
Shen Li10224432021-08-12 11:39:31 -070012381 f = scope['fn']
Mike Ruberry089203f2022-05-29 21:28:45 +000012382 self.assertEqual(f_script(), f())
Elias Ellisonf48a8902019-12-10 07:49:11 -080012383
James Reedc8cc2462018-06-21 10:38:03 -070012384 def test_call_python_fn_from_script_fn(self):
davidriazati7a370db2019-07-16 12:50:02 -070012385 @torch.jit.ignore
James Reedc8cc2462018-06-21 10:38:03 -070012386 def python_fn(x):
12387 return torch.neg(x)
12388
12389 @torch.jit.script
12390 def script_fn(x):
12391 return python_fn(x) + 1
12392
12393 # Note: the call to python_fn appears as `^python_fn()` and is called
12394 # as a PythonOp in the interpreter
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012395 a = torch.tensor(1)
12396 self.assertEqual(script_fn(a), torch.tensor(0))
12397 FileCheck().check("python_fn").run(str(script_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012398
12399 def test_call_python_mod_from_script_fn(self):
12400 class PythonModule(torch.nn.Module):
12401 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012402 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012403 self.param = torch.nn.Parameter(torch.rand(5, 7))
12404
12405 def forward(self, x):
12406 return torch.mm(x, self.param)
12407
12408 pm = PythonModule()
12409
12410 @torch.jit.script
12411 def script_fn(x):
12412 return pm(x) + 1
12413
12414 # Note: call to pm(x) appears as ^<python_value>() in the trace.
12415 # Parameters are NOT inlined.
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012416 FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012417
Michael Suo755f91b2019-08-19 18:41:08 -070012418 @_tmp_donotuse_dont_inline_everything
James Reedc8cc2462018-06-21 10:38:03 -070012419 def test_call_script_fn_from_script_fn(self):
12420 @torch.jit.script
12421 def script_fn1(x):
12422 return torch.neg(x)
12423
12424 @torch.jit.script
12425 def script_fn(x):
12426 return script_fn1(x) + 1
12427
Michael Suo755f91b2019-08-19 18:41:08 -070012428 FileCheck().check("prim::CallFunction").run(str(script_fn.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012429
James Reedc8cc2462018-06-21 10:38:03 -070012430 def test_call_script_mod_from_script_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012431 with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
Zachary DeVito31524bd2019-04-25 15:43:52 -070012432 class ScriptMod(torch.jit.ScriptModule):
Zachary DeVito31524bd2019-04-25 15:43:52 -070012433 @torch.jit.script_method
12434 def forward(self, x):
12435 return torch.mm(x, torch.zeros([4, 3]))
James Reedc8cc2462018-06-21 10:38:03 -070012436
Zachary DeVito31524bd2019-04-25 15:43:52 -070012437 sm = ScriptMod()
James Reedc8cc2462018-06-21 10:38:03 -070012438
Zachary DeVito31524bd2019-04-25 15:43:52 -070012439 @torch.jit.script
12440 def script_fn(x):
12441 return sm(x) + 1
James Reedc8cc2462018-06-21 10:38:03 -070012442
James Reedc8cc2462018-06-21 10:38:03 -070012443 def test_call_python_fn_from_script_module(self):
davidriazati7a370db2019-07-16 12:50:02 -070012444 @torch.jit.ignore
James Reedc8cc2462018-06-21 10:38:03 -070012445 def python_fn(x):
12446 return torch.neg(x)
12447
12448 class ScriptMod(torch.jit.ScriptModule):
12449 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012450 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012451 self.param = torch.nn.Parameter(torch.rand(4, 3))
12452
12453 @torch.jit.script_method
12454 def forward(self, x):
12455 return python_fn(torch.mm(x, self.param))
12456
12457 sm = ScriptMod()
Shen Li10224432021-08-12 11:39:31 -070012458 FileCheck().check("aten::mm").check("python_fn") \
12459 .run(str(sm.forward.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012460
12461 def test_call_python_mod_from_script_module(self):
12462 class PythonMod(torch.nn.Module):
12463 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012464 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012465 self.param = torch.nn.Parameter(torch.rand(3, 5))
12466
davidriazati7a370db2019-07-16 12:50:02 -070012467 @torch.jit.ignore
James Reedc8cc2462018-06-21 10:38:03 -070012468 def forward(self, x):
12469 return torch.mm(x, self.param)
12470
12471 class ScriptMod(torch.jit.ScriptModule):
12472 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012473 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012474 self.param = torch.nn.Parameter(torch.rand(4, 3))
12475 self.pm = PythonMod()
12476
12477 @torch.jit.script_method
12478 def forward(self, x):
12479 return self.pm(torch.mm(x, self.param))
12480
12481 sm = ScriptMod()
davidriazati7a370db2019-07-16 12:50:02 -070012482 # Note: the call into PythonMod appears as ^forward(). Parameters
James Reedc8cc2462018-06-21 10:38:03 -070012483 # are NOT inlined
davidriazati7a370db2019-07-16 12:50:02 -070012484 FileCheck().check("aten::mm").check("forward").run(str(sm.graph))
James Reedc8cc2462018-06-21 10:38:03 -070012485
Michael Suo755f91b2019-08-19 18:41:08 -070012486 @_tmp_donotuse_dont_inline_everything
James Reedc8cc2462018-06-21 10:38:03 -070012487 def test_call_script_fn_from_script_module(self):
12488 @torch.jit.script
12489 def script_fn(x):
12490 return torch.neg(x)
12491
12492 class ScriptMod(torch.jit.ScriptModule):
12493 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012494 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012495 self.param = torch.nn.Parameter(torch.rand(4, 3))
12496
12497 @torch.jit.script_method
12498 def forward(self, x):
12499 return script_fn(torch.mm(x, self.param))
12500
12501 sm = ScriptMod()
Shen Li10224432021-08-12 11:39:31 -070012502 graph = (sm.forward.graph)
Michael Suo755f91b2019-08-19 18:41:08 -070012503 FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph))
James Reedc8cc2462018-06-21 10:38:03 -070012504
Michael Suo755f91b2019-08-19 18:41:08 -070012505 @_tmp_donotuse_dont_inline_everything
James Reedc8cc2462018-06-21 10:38:03 -070012506 def test_call_script_mod_from_script_module(self):
12507 class ScriptMod1(torch.jit.ScriptModule):
12508 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012509 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012510 self.param = torch.nn.Parameter(torch.rand(3, 5))
12511
12512 @torch.jit.script_method
12513 def forward(self, x):
12514 return torch.mm(x, self.param)
12515
12516 class ScriptMod(torch.jit.ScriptModule):
12517 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012518 super().__init__()
James Reedc8cc2462018-06-21 10:38:03 -070012519 self.param = torch.nn.Parameter(torch.rand(4, 3))
12520 self.tm = ScriptMod1()
12521
12522 @torch.jit.script_method
12523 def forward(self, x):
12524 return self.tm(torch.mm(x, self.param))
12525
12526 sm = ScriptMod()
12527 # Note: the parameters from both modules should appear in the flattened
12528 # input list to the graph. The mm op from ScriptMod1 should be properly
12529 # inlined
Elias Ellison7fc3aa82019-03-04 22:38:41 -080012530        # 3 '%' values in the graph input list, and two mm ops in the body
Shen Li10224432021-08-12 11:39:31 -070012531 FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph))
James Reed2e23bc12018-06-28 16:24:08 -070012532
12533 def test_module_with_params_called_fails(self):
Shen Li10224432021-08-12 11:39:31 -070012534 with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
James Reed2e23bc12018-06-28 16:24:08 -070012535 class ScriptMod(torch.jit.ScriptModule):
12536 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012537 super().__init__()
James Reed2e23bc12018-06-28 16:24:08 -070012538 self.param = torch.nn.Parameter(torch.rand(3, 3))
12539
12540 @torch.jit.script_method
12541 def forward(self, x):
12542 return torch.mm(x, self.param)
12543
12544 sm = ScriptMod()
12545
12546 @torch.jit.script
12547 def some_func(x):
12548 return sm(x)
James Reedc8cc2462018-06-21 10:38:03 -070012549
Elias Ellison26f52752019-05-06 13:25:57 -070012550 def test_tuple_index_to_list(self):
12551 def test_non_constant_input(a):
12552 # type: (bool) -> int
12553 if a:
12554 b = 1
12555 else:
12556 b = 0
12557 c = (0, 1)
12558 return c[b]
12559
12560 self.checkScript(test_non_constant_input, (True,))
12561 self.checkScript(test_non_constant_input, (False,))
12562
Shen Li10224432021-08-12 11:39:31 -070012563 with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"):
Elias Ellison26f52752019-05-06 13:25:57 -070012564 @torch.jit.script
12565 def test_non_constant_input(a):
12566 # type: (bool) -> None
12567 if a:
12568 b = 1
12569 else:
12570 b = 0
12571 c = (0, 1.1)
12572 print(c[b])
12573
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012574 def test_tuple_indexing(self):
12575 def tuple_index(a):
12576 if bool(a):
12577 b = (1, 2)
12578 else:
12579 b = (0, 2)
12580 return b[-2], b[1]
12581
Elias Ellison411cf432019-02-25 16:11:47 -080012582 self.checkScript(tuple_index, (torch.tensor([0]),))
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012583 self.checkScript(tuple_index, (torch.tensor([1]),))
12584 self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
12585 tuple_comp = torch.jit.script(tuple_index)
Shen Li10224432021-08-12 11:39:31 -070012586 FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012587
Elias Ellison26f52752019-05-06 13:25:57 -070012588 with self.assertRaisesRegex(RuntimeError, "index must be an integer"):
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012589 @torch.jit.script
Elias Ellison26f52752019-05-06 13:25:57 -070012590 def test_indexing_float():
12591 c = (1, 2)
12592 return c[0.1]
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012593
12594 def test_indexing_out_of_bounds_pos():
12595 c = (1, 2)
12596 return c[2]
12597
Shen Li10224432021-08-12 11:39:31 -070012598 self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
12599 "out of range")
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012600
12601 def test_indexing_out_of_bounds_neg():
12602 c = (1, 2)
12603 return c[-3]
12604
Shen Li10224432021-08-12 11:39:31 -070012605            self.checkScriptRaisesRegex(test_indexing_out_of_bounds_neg, (), Exception,
12606 "out of range")
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012607
Elias Ellison881adb52019-06-07 11:11:42 -070012608 def negative_index():
12609 tup = (1, 2, 3, 4)
12610 return tup[-1]
12611
12612 self.checkScript(negative_index, [])
12613
12614 def really_negative_index():
12615 tup = (1, 2, 3, 4)
12616 return tup[-100]
12617
Shen Li10224432021-08-12 11:39:31 -070012618 self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range")
Elias Ellison881adb52019-06-07 11:11:42 -070012619
12620 def negative_slice():
12621 tup = (1, 2, 3, 4)
12622 return tup[-3:4]
12623
12624 self.checkScript(negative_slice, [])
12625
12626 def really_slice_out_of_bounds():
12627 tup = (1, 2, 3, 4)
12628 return tup[-300:4000]
12629
12630 self.checkScript(really_slice_out_of_bounds, [])
12631
Xiang Gaoeae139e2019-02-10 18:10:59 -080012632 def test_namedtuple_attr(self):
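        # Tensor.max(dim=...) returns a (values, indices) namedtuple; attribute
        # access on the result should behave the same way under scripting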
12633 def f(x):
12634 return x.max(dim=1).indices + torch.max(x, dim=1).indices
12635
12636 self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
12637
Zhengxu Chen8176ab62021-04-16 15:45:49 -070012638 with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
Xiang Gaoeae139e2019-02-10 18:10:59 -080012639 @torch.jit.script
12640 def g1(x):
12641 return x.max(dim=1).unknown_symbol
12642
Zhengxu Chen8176ab62021-04-16 15:45:49 -070012643 with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
Xiang Gaoeae139e2019-02-10 18:10:59 -080012644 @torch.jit.script
12645 def g2(x):
12646 print((x, x, x).__doc__)
12647 return x
12648
Elias Ellison1ec06762020-03-31 19:22:14 -070012649 def test_tuple_len(self):
12650 @torch.jit.script
12651 def foo():
12652 return len((1, "str", None))
12653
12654 self.assertEqual(foo(), 3)
12655
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012656 @torch.jit.script
12657 def test_indexing_end_out_of_bounds():
12658 c = (1, 2)
12659 return c[2:10]
12660
Zachary DeVito056cfaf2018-12-18 10:27:26 -080012661 self.assertEqual(test_indexing_end_out_of_bounds(), ())
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012662
Elias Ellison38d122e2020-01-22 12:09:46 -080012663 def test_lower_nested_tuples(self):
12664 @torch.jit.script
12665 def test():
12666 return ((1, 2), 3)
12667
Shen Li10224432021-08-12 11:39:31 -070012668 self.run_pass('constant_propagation', test.graph)
Elias Ellison38d122e2020-01-22 12:09:46 -080012669 FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph)
12670 # fails if a tuple can't be lowered
Shen Li10224432021-08-12 11:39:31 -070012671 self.run_pass('lower_all_tuples', test.graph)
Elias Ellison38d122e2020-01-22 12:09:46 -080012672
Elias Ellison137150b2018-11-06 11:52:40 -080012673 def test_unwrap_optional_builtin(self):
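        # torch.jit._unwrap_optional narrows Optional[T] to T and raises if the
        # wrapped value is None, in both eager mode and TorchScript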
12674 def test(x):
12675 # type: (Optional[int]) -> int
12676 x = torch.jit._unwrap_optional(x)
Elias Ellison561037a2019-03-07 09:12:35 -080012677 x = x + x # noqa: T484
Elias Ellison137150b2018-11-06 11:52:40 -080012678 return x
12679
12680 self.checkScript(test, (3,))
12681
12682 with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
12683 test(None)
12684
12685 test_script = torch.jit.script(test)
12686 with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
12687 test_script(None)
12688
David Riazati0c375572018-11-12 12:13:10 -080012689 @torch.jit.script
12690 def test_test():
12691 return torch.jit._unwrap_optional(1)
12692
Shen Li10224432021-08-12 11:39:31 -070012693 with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"):
Elias Ellison137150b2018-11-06 11:52:40 -080012694 @torch.jit.script
David Riazati0c375572018-11-12 12:13:10 -080012695 def test_no_type():
12696 # type: () -> int
12697 return torch.jit._unwrap_optional(None)
Elias Ellison137150b2018-11-06 11:52:40 -080012698
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012699 def test_indexing_error(self):
Wanchao Liangedeb4db2019-07-10 14:38:12 -070012700 with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"):
Elias Ellisonf9b7ce92018-10-23 17:50:15 -070012701 @torch.jit.script
12702 def test_wrong_type():
12703 a = 8
12704 return a[0]
12705
Ailing Zhang50ee1f32019-06-11 11:25:20 -070012706 def test_unsupported_builtin_error(self):
Shen Li10224432021-08-12 11:39:31 -070012707 with self.assertRaisesRegex(RuntimeError,
12708 "Python builtin <built-in function hypot> is currently"):
Ailing Zhang50ee1f32019-06-11 11:25:20 -070012709 @torch.jit.script
12710 def test_unsupported(a):
12711 return math.hypot(a, 2.0)
12712
James Reed0b16b032018-07-25 16:55:09 -070012713 def test_annotated_script_fn(self):
12714 @torch.jit.script
12715 def foo(x, y, z):
12716 # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
12717 return x
12718
Zachary DeVito31524bd2019-04-25 15:43:52 -070012719 self.assertExpected(str(foo.schema))
James Reed0b16b032018-07-25 16:55:09 -070012720
12721 def test_annotated_script_method(self):
12722 class SM(torch.jit.ScriptModule):
12723 @torch.jit.script_method
12724 def forward(self, x, y):
12725 # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
12726 return y, y, y
12727
12728 sm = SM()
12729
James Reed44982832019-11-20 16:14:41 -080012730 self.assertExpectedStripMangled(str(sm.forward.schema))
James Reed0b16b032018-07-25 16:55:09 -070012731
12732 def test_annotated_script_fn_return_mismatch(self):
Zachary DeVito056cfaf2018-12-18 10:27:26 -080012733 with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
James Reed0b16b032018-07-25 16:55:09 -070012734 @torch.jit.script
12735 def return_tup(x):
12736 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
Elias Ellison561037a2019-03-07 09:12:35 -080012737 return x, x # noqa: T484
James Reed0b16b032018-07-25 16:55:09 -070012738
12739 def test_annotated_script_fn_arg_mismatch(self):
Mikhail Zolotukhinfbecb462019-06-13 17:01:29 -070012740 with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"):
James Reed0b16b032018-07-25 16:55:09 -070012741 @torch.jit.script
12742 def tuple_arg(x):
12743 # type: (Tuple[Tensor, Tensor]) -> Tensor
Elias Ellison561037a2019-03-07 09:12:35 -080012744 return x + 1 # noqa: T484
James Reed0b16b032018-07-25 16:55:09 -070012745
Adam Paszke1f134532018-07-31 14:23:33 -070012746 def test_script_non_tensor_args_outputs(self):
12747 @torch.jit.script
12748 def fn(x, y):
12749 # type: (Tensor, float) -> float
12750 return float((x + y).sum())
12751
12752 x = torch.ones(2, 2)
12753 z = fn(x, 1)
12754 self.assertIsInstance(z, float)
Shen Li10224432021-08-12 11:39:31 -070012755 self.assertEqual(z, 8.)
Adam Paszke1f134532018-07-31 14:23:33 -070012756
Shen Li10224432021-08-12 11:39:31 -070012757 @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
James Reed0b16b032018-07-25 16:55:09 -070012758 def test_inline_and_run_annotated_script_fn(self):
12759 @torch.jit.script
12760 def to_inline(x, y):
12761 # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
12762 return y
12763
12764 @torch.jit.script
12765 def some_func(x):
12766 return to_inline((x, x), x)
12767
James Reed9c818bf2018-08-02 10:13:42 -070012768 x = torch.rand(3, 4)
12769 self.assertEqual(some_func(x), x)
James Reed0b16b032018-07-25 16:55:09 -070012770
James Reed851c18d2018-07-27 22:21:05 -070012771 def test_file_format_serialization(self):
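        # write a set of random byte buffers as records with PyTorchFileWriter,
        # then read them back with PyTorchFileReader and check they round-trip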
James Reed851c18d2018-07-27 22:21:05 -070012772 filename = tempfile.mktemp()
12773 writer = torch._C.PyTorchFileWriter(filename)
Shen Li10224432021-08-12 11:39:31 -070012774 buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
James Reed851c18d2018-07-27 22:21:05 -070012775 offsets = []
Zachary DeVito170ff772018-11-30 19:15:09 -080012776 for i, buf in enumerate(buffers):
12777 writer.write_record(str(i), buf, len(buf))
12778 offsets.append(i)
James Reed851c18d2018-07-27 22:21:05 -070012779 serialized_offsets = pickle.dumps(offsets)
Zachary DeVito170ff772018-11-30 19:15:09 -080012780 writer.write_record("meta", serialized_offsets, len(serialized_offsets))
James Reed851c18d2018-07-27 22:21:05 -070012781 writer.write_end_of_file()
12782
12783 reader = torch._C.PyTorchFileReader(filename)
Zachary DeVito170ff772018-11-30 19:15:09 -080012784 serialized_offsets_read = reader.get_record("meta")
James Reed851c18d2018-07-27 22:21:05 -070012785        parsed_serialized_offsets = pickle.loads(serialized_offsets_read)
12786
12787 for i, offset in enumerate(parsed_serialized_offsets):
Zachary DeVito170ff772018-11-30 19:15:09 -080012788 data = reader.get_record(str(offset))
Shen Li10224432021-08-12 11:39:31 -070012789 assert(data == buffers[i])
James Reed851c18d2018-07-27 22:21:05 -070012790
Elias Ellison70db5362018-11-01 10:30:33 -070012791    # returns, for each supported type, the input type annotation and the corresponding return type annotation
12792 def type_input_return_pairs(self):
12793 return [
Shen Li10224432021-08-12 11:39:31 -070012794 ('Tensor', 'Tensor'),
12795 ('torch.Tensor', 'Tensor'),
12796 ('str', 'str'),
12797 ('int', 'int'),
12798 ('bool', 'bool'),
12799 ('BroadcastingList3[float]', 'List[float]'),
12800 ('BroadcastingList2[int]', 'List[int]'),
12801 ('List[int]', 'List[int]'),
12802 ('Optional[int]', 'Optional[int]'),
Elias Ellison70db5362018-11-01 10:30:33 -070012803 ]
12804
12805 # replacing code input & return type pair
12806    # helper that substitutes an input/return type pair into a code template
12807 return code.format(input=pair[0], output=pair[1])
12808
James Reed32bb4042018-08-14 17:56:52 -070012809 # ***** Type annotation tests ****
12810 # Test combinations of:
12811 # {String frontend, Python AST Frontend}
12812 # {Python 3-style type annotations, MyPy-style type comments}
12813 # {Script method, Script function}
12814
12815 # String frontend , Python 3-style type annotations , Script function
12816 def test_annot_string_py3_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012817 code = '''
Elias Ellison70db5362018-11-01 10:30:33 -070012818 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
James Reed32bb4042018-08-14 17:56:52 -070012819 return x, x
Shen Li10224432021-08-12 11:39:31 -070012820 '''
Elias Ellison70db5362018-11-01 10:30:33 -070012821 test_str = []
12822 for pair in self.type_input_return_pairs():
12823 cu = torch.jit.CompilationUnit(self.format_code(code, pair))
Zachary DeVito31524bd2019-04-25 15:43:52 -070012824 test_str.append(str(cu.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012825 self.assertExpected("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012826
12827 # String frontend , Python 3-style type annotations , Script method
12828 def test_annot_string_py3_method(self):
12829 class TestModule(torch.jit.ScriptModule):
12830 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012831 super().__init__()
James Reed32bb4042018-08-14 17:56:52 -070012832
Shen Li10224432021-08-12 11:39:31 -070012833 code = '''
Elias Ellison70db5362018-11-01 10:30:33 -070012834 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
James Reed32bb4042018-08-14 17:56:52 -070012835 return x, x
Shen Li10224432021-08-12 11:39:31 -070012836 '''
Elias Ellison70db5362018-11-01 10:30:33 -070012837 test_str = []
12838 for pair in self.type_input_return_pairs():
Michael Suo34126272019-10-12 09:49:56 -070012839 # clear the class registry as we will be defining foo multiple times
12840 jit_utils.clear_class_registry()
Elias Ellison70db5362018-11-01 10:30:33 -070012841 tm = TestModule()
12842 tm.define(self.format_code(code, pair))
Zachary DeVito31524bd2019-04-25 15:43:52 -070012843 test_str.append(str(tm.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012844 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012845
12846 # String frontend , MyPy-style type comments , Script function
12847 def test_annot_string_mypy_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012848 code = '''
James Reed32bb4042018-08-14 17:56:52 -070012849 def foo(x, y):
Elias Ellison70db5362018-11-01 10:30:33 -070012850 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
James Reed32bb4042018-08-14 17:56:52 -070012851 return x, x
Shen Li10224432021-08-12 11:39:31 -070012852 '''
Elias Ellison70db5362018-11-01 10:30:33 -070012853 test_str = []
12854 for pair in self.type_input_return_pairs():
12855 cu = torch.jit.CompilationUnit(self.format_code(code, pair))
Zachary DeVito31524bd2019-04-25 15:43:52 -070012856 test_str.append(str(cu.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012857 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012858
12859 # String frontend , MyPy-style type comments , Script method
12860 def test_annot_string_mypy_method(self):
12861 class TestModule(torch.jit.ScriptModule):
12862 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000012863 super().__init__()
James Reed32bb4042018-08-14 17:56:52 -070012864
Shen Li10224432021-08-12 11:39:31 -070012865 code = '''
Elias Ellison70db5362018-11-01 10:30:33 -070012866 def foo(self, x, y):
12867 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12868 return x, x
Shen Li10224432021-08-12 11:39:31 -070012869 '''
James Reed32bb4042018-08-14 17:56:52 -070012870
Elias Ellison70db5362018-11-01 10:30:33 -070012871 test_str = []
12872 for pair in self.type_input_return_pairs():
Michael Suo34126272019-10-12 09:49:56 -070012873 # clear the class registry as we will be defining foo multiple times
12874 jit_utils.clear_class_registry()
Elias Ellison70db5362018-11-01 10:30:33 -070012875 tm = TestModule()
12876 tm.define(self.format_code(code, pair))
Zachary DeVito31524bd2019-04-25 15:43:52 -070012877 test_str.append(str(tm.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012878 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012879
James Reed32bb4042018-08-14 17:56:52 -070012880    # Python AST frontend, Python 3-style type annotations, Script function
James Reed32bb4042018-08-14 17:56:52 -070012881 def test_annot_ast_py3_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012882 code = dedent('''
Elias Ellison70db5362018-11-01 10:30:33 -070012883 from typing import Tuple, List, Optional
James Reed32bb4042018-08-14 17:56:52 -070012884 from torch import Tensor
David Riazati556ff8e2018-11-08 11:24:36 -080012885 from torch.jit.annotations import BroadcastingList2, BroadcastingList3
James Reed32bb4042018-08-14 17:56:52 -070012886 import torch
12887 @torch.jit.script
Elias Ellison70db5362018-11-01 10:30:33 -070012888 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
James Reed32bb4042018-08-14 17:56:52 -070012889 return x, x
Shen Li10224432021-08-12 11:39:31 -070012890 ''')
Elias Ellison70db5362018-11-01 10:30:33 -070012891 test_str = []
12892 for pair in self.type_input_return_pairs():
Shen Li10224432021-08-12 11:39:31 -070012893 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
Zachary DeVito31524bd2019-04-25 15:43:52 -070012894 test_str.append(str(fn.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012895 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012896
davidriazati7d3d5b72019-05-10 17:04:48 -070012897 def test_multiline_annot_ast_py3_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012898 code = dedent('''
davidriazati7d3d5b72019-05-10 17:04:48 -070012899 from typing import Tuple, List, Optional
12900 from torch import Tensor
12901 from torch.jit.annotations import BroadcastingList2, BroadcastingList3
12902 import torch
12903 @torch.jit.script
12904 def foo(x, # type: {input}
12905 y # type: Tuple[Tensor, Tensor]
12906 ):
12907 # type: (...) -> Tuple[{output}, {output}]
12908 return x, x
Shen Li10224432021-08-12 11:39:31 -070012909 ''')
davidriazati7d3d5b72019-05-10 17:04:48 -070012910 test_str = []
12911
12912 for pair in self.type_input_return_pairs():
Shen Li10224432021-08-12 11:39:31 -070012913 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
davidriazati7d3d5b72019-05-10 17:04:48 -070012914 args = fn.schema.arguments
12915 returns = fn.schema.returns
12916 self.assertEqual(str(args[0].type), pair[1])
12917 self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]")
Shen Li10224432021-08-12 11:39:31 -070012918 self.assertEqual(str(returns[0].type), "Tuple[{}, {}]".format(pair[1], pair[1]))
davidriazati7d3d5b72019-05-10 17:04:48 -070012919
12920 def test_bad_multiline_annotations(self):
12921 with self.assertRaisesRegex(RuntimeError, "Return type line"):
12922 @torch.jit.script
Shen Li10224432021-08-12 11:39:31 -070012923 def bad_type_line(a, # type: Tensor
12924 b, # type: Tensor
12925 c # type: Tensor
12926 ):
davidriazati7d3d5b72019-05-10 17:04:48 -070012927 # type: (int, int, int) -> Tensor
davidriazati63c05bf2019-05-13 11:29:04 -070012928 # type: bad type line # noqa: F723
davidriazati7d3d5b72019-05-10 17:04:48 -070012929
12930 return a + b + c
12931
12932 with self.assertRaisesRegex(RuntimeError, "Return type line"):
12933 @torch.jit.script
Shen Li10224432021-08-12 11:39:31 -070012934 def bad_return_line(a, # type: Tensor
12935 b,
12936 c # type: Tensor
12937 ):
davidriazati7d3d5b72019-05-10 17:04:48 -070012938 # type: (int, int, int) -> Tensor
12939 return a + b + c
12940
12941 # TODO: this should be supported but is difficult to parse
12942 with self.assertRaisesRegex(RuntimeError, "Number of type annotations"):
12943 @torch.jit.script
Shen Li10224432021-08-12 11:39:31 -070012944 def missing_type(a, # type: Tensor
12945 b,
12946 c # type: Tensor
12947 ):
davidriazati7d3d5b72019-05-10 17:04:48 -070012948 # type: (...) -> Tensor
12949 return a + b + c
12950
James Reed32bb4042018-08-14 17:56:52 -070012951    # Python AST frontend, Python 3-style type annotations, Script method
James Reed32bb4042018-08-14 17:56:52 -070012952 def test_annot_ast_py3_method(self):
Shen Li10224432021-08-12 11:39:31 -070012953 code = dedent('''
Elias Ellison70db5362018-11-01 10:30:33 -070012954 from typing import Tuple, List, Optional
James Reed32bb4042018-08-14 17:56:52 -070012955 from torch import Tensor
David Riazati556ff8e2018-11-08 11:24:36 -080012956 from torch.jit.annotations import BroadcastingList2, \\
12957 BroadcastingList3
James Reed32bb4042018-08-14 17:56:52 -070012958 import torch
12959 class FooModule(torch.jit.ScriptModule):
12960 @torch.jit.script_method
Elias Ellison70db5362018-11-01 10:30:33 -070012961 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
James Reed32bb4042018-08-14 17:56:52 -070012962 return x, x
12963 instance = FooModule()
Shen Li10224432021-08-12 11:39:31 -070012964 ''')
Elias Ellison70db5362018-11-01 10:30:33 -070012965
12966 test_str = []
12967 for pair in self.type_input_return_pairs():
Shen Li10224432021-08-12 11:39:31 -070012968 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
Zachary DeVito31524bd2019-04-25 15:43:52 -070012969 test_str.append(str(fn.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012970 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012971
12972    # Python AST frontend, MyPy-style type comments, Script function
James Reed32bb4042018-08-14 17:56:52 -070012973 def test_annot_ast_mypy_fn(self):
Shen Li10224432021-08-12 11:39:31 -070012974 code = dedent('''
James Reed32bb4042018-08-14 17:56:52 -070012975 import torch
12976 @torch.jit.script
12977 def foo(x, y):
Elias Ellison70db5362018-11-01 10:30:33 -070012978 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
James Reed32bb4042018-08-14 17:56:52 -070012979 return x, x
Shen Li10224432021-08-12 11:39:31 -070012980 ''')
Elias Ellison70db5362018-11-01 10:30:33 -070012981
12982 test_str = []
12983 for pair in self.type_input_return_pairs():
Shen Li10224432021-08-12 11:39:31 -070012984 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
Zachary DeVito31524bd2019-04-25 15:43:52 -070012985 test_str.append(str(fn.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070012986 self.assertExpected("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070012987
12988    # Python AST frontend, MyPy-style type comments, Script method
James Reed32bb4042018-08-14 17:56:52 -070012989 def test_annot_ast_mypy_method(self):
Shen Li10224432021-08-12 11:39:31 -070012990 code = dedent('''
James Reed32bb4042018-08-14 17:56:52 -070012991 import torch
12992 class FooModule(torch.jit.ScriptModule):
12993 @torch.jit.script_method
12994 def foo(self, x, y):
Elias Ellison70db5362018-11-01 10:30:33 -070012995 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
James Reed32bb4042018-08-14 17:56:52 -070012996 return x, x
12997 instance = FooModule()
Shen Li10224432021-08-12 11:39:31 -070012998 ''')
James Reed32bb4042018-08-14 17:56:52 -070012999
Elias Ellison70db5362018-11-01 10:30:33 -070013000 test_str = []
13001 for pair in self.type_input_return_pairs():
Shen Li10224432021-08-12 11:39:31 -070013002 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
Zachary DeVito31524bd2019-04-25 15:43:52 -070013003 test_str.append(str(fn.foo.schema))
Ansley Ussery6831d8e2021-09-03 06:10:37 -070013004 self.assertExpectedStripMangled("\n".join(test_str) + "\n")
James Reed32bb4042018-08-14 17:56:52 -070013005
Yanan Caob9acfcd2021-02-08 11:16:52 -080013006 # Tests that "# type: ignore[*]" is supported in type lines and is
13007 # properly ignored.
13008 def test_mypy_type_ignore(self):
13009 @torch.jit.script
13010 def foo(x): # type: ignore
13011 return x
13012
13013 @torch.jit.script
13014 def bar(x): # type: ignore[no-redef]
13015 return x
13016
James Reed585e6b52018-08-20 14:07:44 -070013017 def test_method_casts_script(self):
Shen Li10224432021-08-12 11:39:31 -070013018 cast_types = [
13019 'byte', 'char', 'double', 'float', 'int', 'long', 'short'
13020 ]
James Reed585e6b52018-08-20 14:07:44 -070013021
13022 for cast_type in cast_types:
Shen Li10224432021-08-12 11:39:31 -070013023 cu = torch.jit.CompilationUnit('''
James Reed585e6b52018-08-20 14:07:44 -070013024 def cast_to(x):
13025 return x.{cast_type}()
Shen Li10224432021-08-12 11:39:31 -070013026 '''.format(cast_type=cast_type))
James Reed585e6b52018-08-20 14:07:44 -070013027
13028 x = torch.rand(3, 4, 5) * 128
13029 cu_result = cu.cast_to(x)
13030 reference = getattr(x, cast_type)()
13031 self.assertEqual(cu_result, reference)
13032
James Reed278e3042018-09-14 10:04:05 -070013033 def test_string_frontend_elif(self):
Shen Li10224432021-08-12 11:39:31 -070013034 code = '''
David Riazatidefd23b2019-06-25 16:17:49 -070013035 def func(niter):
13036 # type: (int)
James Reed278e3042018-09-14 10:04:05 -070013037 rv = 0
13038 for i in range(niter):
13039 if i % 3 == 0 and i % 5 == 0:
13040 rv += 35
13041 elif i % 3 == 0:
13042 rv += 3
13043 elif i % 5 == 0:
13044 rv += 5
13045 else:
13046 rv += i
13047 return rv
Shen Li10224432021-08-12 11:39:31 -070013048 '''
James Reed278e3042018-09-14 10:04:05 -070013049
David Riazatidefd23b2019-06-25 16:17:49 -070013050 self.checkScript(dedent(code), (101,))
James Reed278e3042018-09-14 10:04:05 -070013051
Elias Ellisonbee63442019-12-12 13:30:46 -080013052 def test_module_parameters_and_buffers(self):
David Riazatiaf78d4c2018-10-23 09:02:50 -070013053 weights = torch.randn(10, 10)
13054 bias = torch.randn(10)
13055 weights2 = torch.randn(10, 10)
13056 bias2 = torch.randn(10)
13057
David Riazatiaf78d4c2018-10-23 09:02:50 -070013058 class TestLinear(torch.nn.Module):
13059 def __init__(self, in_features, out_features):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013060 super().__init__()
David Riazatiaf78d4c2018-10-23 09:02:50 -070013061 self.in_features = in_features
13062 self.out_features = out_features
Yukio Siraichi93bf0ae2021-04-11 15:43:54 -070013063 self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
13064 self.bias = torch.nn.Parameter(torch.empty(out_features))
Shen Li10224432021-08-12 11:39:31 -070013065 self.register_buffer('counter', torch.ones(out_features))
David Riazatiaf78d4c2018-10-23 09:02:50 -070013066 self.reset_parameters()
13067
13068 def reset_parameters(self):
13069 torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
13070 if self.bias is not None:
13071 fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
13072 bound = 1 / math.sqrt(fan_in)
13073 torch.nn.init.uniform_(self.bias, -bound, bound)
13074
David Riazatiaf78d4c2018-10-23 09:02:50 -070013075 def forward(self, input):
13076 return F.linear(input, self.weight, self.bias) + self.counter
13077
13078 # Initialize a ScriptModule that uses the weak module above multiple times
13079 class Strong(torch.jit.ScriptModule):
13080 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013081 super().__init__()
David Riazatiaf78d4c2018-10-23 09:02:50 -070013082 self.fc1 = TestLinear(10, 10)
13083 self.fc1.weight = torch.nn.Parameter(weights)
13084 self.fc1.bias = torch.nn.Parameter(bias)
13085 self.fc2 = TestLinear(10, 10)
13086 self.fc2.weight = torch.nn.Parameter(weights2)
13087 self.fc2.bias = torch.nn.Parameter(bias2)
13088
13089 @torch.jit.script_method
13090 def forward(self, x):
13091 return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
13092
13093 strong_mod = Strong()
David Riazatiaf78d4c2018-10-23 09:02:50 -070013094
13095 # Run same calculation as module
13096 inp = torch.ones(10)
13097 lin = torch.nn.Linear(10, 10)
13098 lin.weight = torch.nn.Parameter(weights)
13099 lin.bias = torch.nn.Parameter(bias)
13100 lin2 = torch.nn.Linear(10, 10)
13101 lin2.weight = torch.nn.Parameter(weights2)
13102 lin2.bias = torch.nn.Parameter(bias2)
Shen Li10224432021-08-12 11:39:31 -070013103 expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
David Riazatiaf78d4c2018-10-23 09:02:50 -070013104
13105 self.assertEqual(strong_mod(inp), expected_result)
David Riazati4655b7b2018-12-06 21:50:35 -080013106 self.assertExportImportModule(strong_mod, (inp,))
David Riazatiaf78d4c2018-10-23 09:02:50 -070013107
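    # A condensed sketch of what the test above relies on: parameters created with
    # torch.nn.Parameter and buffers added via register_buffer() are both carried over
    # when a module is compiled with torch.jit.script. The helper below is illustrative
    # only (its name and values are not part of the original suite) and is not collected
    # by the test runner because its name does not start with `test_`.
    def _sketch_parameters_and_buffers(self):
        class Affine(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(3, 3))
                self.register_buffer('shift', torch.zeros(3))

            def forward(self, x):
                return x @ self.weight + self.shift

        scripted = torch.jit.script(Affine())
        # The compiled module exposes the same parameters and buffers as the eager one.
        self.assertEqual(len(list(scripted.parameters())), 1)
        self.assertEqual(len(list(scripted.buffers())), 1)
        # ones(2, 3) @ ones(3, 3) gives 3.0 everywhere; the zero buffer leaves it unchanged.
        self.assertEqual(scripted(torch.ones(2, 3)), torch.full((2, 3), 3.0))
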
Elias Ellisonbee63442019-12-12 13:30:46 -080013108 def test_module_copying(self):
David Riazatiaf78d4c2018-10-23 09:02:50 -070013109 class Submodule(torch.nn.Module):
David Riazatiaf78d4c2018-10-23 09:02:50 -070013110 def forward(self, x):
13111 return x + 100
13112
David Riazatiaf78d4c2018-10-23 09:02:50 -070013113 class Weak(torch.nn.Module):
13114 def __init__(self, in_features, out_features):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013115 super().__init__()
David Riazatiaf78d4c2018-10-23 09:02:50 -070013116 self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
David Riazati5163a282018-11-13 13:49:01 -080013117 self.bias = torch.nn.Parameter(torch.ones(out_features))
David Riazatiaf78d4c2018-10-23 09:02:50 -070013118 self.register_buffer("buffer", torch.ones(out_features))
13119 self.submodule = Submodule()
13120
David Riazatiaf78d4c2018-10-23 09:02:50 -070013121 def forward(self, x):
Shen Li10224432021-08-12 11:39:31 -070013122 return F.linear(x, self.weight, self.bias) \
13123 + self.buffer + self.submodule(x)
David Riazatiaf78d4c2018-10-23 09:02:50 -070013124
13125 class Strong(torch.jit.ScriptModule):
13126 def __init__(self, weak):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013127 super().__init__()
David Riazatiaf78d4c2018-10-23 09:02:50 -070013128 self.weak = weak
13129
13130 @torch.jit.script_method
13131 def forward(self, x):
13132 return self.weak(x)
13133
13134 inp = torch.ones(5, 5) * 5
13135 weak_mod = Weak(5, 5)
13136 strong_mod = Strong(weak_mod)
13137
13138 self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
13139 self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
13140
13141 self.assertIs(strong_mod.weak.weight, weak_mod.weight)
13142 self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
Elias Ellisonbee63442019-12-12 13:30:46 -080013143 # strong_mod.weak.submodule has been recursively scripted
13144 self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule)
David Riazatiaf78d4c2018-10-23 09:02:50 -070013145
13146 weak_mod.weight.data += torch.ones(5, 5) * 100
13147 self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
13148
13149 # Re-assignment is not tracked
13150 weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
13151 self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
13152
David Riazatibc74ec82018-10-31 09:29:18 -070013153 def test_backend_cudnn_enabled(self):
13154 # Only test that this compiles
13155 @torch.jit.script
13156 def fn(x):
13157 if torch.backends.cudnn.enabled:
13158 x = x + 2
13159 else:
13160 x = x + 3
13161 return x
13162
Zachary DeVitoce0d3e92018-10-26 10:35:10 -070013163 def test_inplace_add(self):
Shen Li10224432021-08-12 11:39:31 -070013164
Zachary DeVitoce0d3e92018-10-26 10:35:10 -070013165 def foo(a, b):
13166 c = a + b
13167 c.add_(b)
13168 return c
13169 self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13170
13171 def test_add_out(self):
13172 def foo(a, b):
13173 c = a + b
13174 e = 2 * a
13175 torch.add(c, b, out=e)
13176 return e
13177 self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13178
Tugsbayasgalan Manlaibaatarb80c6f82021-04-11 02:28:49 -070013179 def test_tuple_error_msg(self):
13180 def fn(t: Any):
13181 if isinstance(t, tuple):
13182 a, b = t
13183 return a + b
Shen Li10224432021-08-12 11:39:31 -070013184 with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"):
Tugsbayasgalan Manlaibaatarb80c6f82021-04-11 02:28:49 -070013185 s = torch.jit.script(fn)
13186
Michael Suo5fbaf0e2018-11-01 23:58:35 -070013187 def test_augmented_assign(self):
13188 def foo(a, b):
13189 a += b
13190 a -= b
13191 a /= b
13192 a *= b
13193 return a, b
Elias Ellison7fc3aa82019-03-04 22:38:41 -080013194 self.checkScript(foo, (torch.rand(3), torch.rand(3)))
Michael Suo5fbaf0e2018-11-01 23:58:35 -070013195
Tugsbayasgalan Manlaibaatare658d7c2021-02-23 10:38:21 -080013196 def test_ignored_props(self):
13197 class A(nn.Module):
13198 __jit_ignored_attributes__ = ["ignored", "ignored_return_val"]
13199
Tugsbayasgalan Manlaibaatare658d7c2021-02-23 10:38:21 -080013200 @property
13201 def ignored(self):
13202 raise ValueError("shouldn't be called")
13203
13204 @property
13205 def ignored_return_val(self):
13206 return 1
13207
13208 @torch.jit.ignore
13209 def call(self):
13210 return self.ignored_return_val
13211
13212 f = torch.jit.script(A())
13213        # crude way to check that scripting raised no error
13214 self.assertTrue(isinstance(f, torch.jit.ScriptModule))
13215 self.assertTrue(isinstance(f.call(), property))
13216
David Riazati23e3a122018-11-05 16:53:45 -080013218 def test_pass(self):
13219 def foo(x):
13220 # type: (bool) -> int
13221 for _i in range(3):
13222 pass
13223 if x:
13224 pass
13225 else:
13226 pass
13227 return 3
13228
13229 self.checkScript(foo, (True,))
13230
Michael Suo21991c02018-11-07 23:06:03 -080013231 def test_lhs_indexing(self):
13232 def foo(a, b):
13233 a = a.clone()
13234 a[0] = b
13235 return a
13236 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13237
Michael Suo2fa3c832018-11-26 12:02:09 -080013238 def test_lhs_advanced_indexing_assignment(self):
13239 def foo(x, y):
13240 a = torch.exp(x)
13241 b = x == 1
13242 a[b] = y[b]
13243 return a
13244 self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13245
13246 def test_lhs_advanced_indexing_augmented_assignment(self):
13247 def foo(x, y):
13248 a = torch.exp(x)
13249 b = x == 1
13250 a[b] += y[b]
13251 return a
13252 self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13253
Michael Suo21991c02018-11-07 23:06:03 -080013254 def test_lhs_indexing_list(self):
13255 def foo(a, b):
13256 ls = [a]
13257 ls[0] = b
13258 return ls
13259 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13260
James Reed19759172018-11-29 20:30:02 -080013261 def test_inplace_copy_script(self):
13262 def foo(x):
13263 a = torch.rand(3, 4)
13264 a.copy_(x)
13265 return a
13266 self.checkScript(foo, (torch.rand(3, 4),))
13267
Michael Suo21991c02018-11-07 23:06:03 -080013268 def test_lhs_indexing_increment(self):
13269 def foo(a, b):
13270 a[0] += b
13271 return a
13272 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13273
13274 def test_lhs_indexing_increment_list(self):
13275 def foo(a, b):
13276 a = a.clone()
13277 ls = [a, b]
13278 ls[0] += b
13279 return ls
13280 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13281
13282 def test_lhs_indexing_increment_list_prim(self):
13283 def foo():
13284 ls = [1, 2, 3]
13285 ls[0] += 5
13286 return ls
13287 self.checkScript(foo, ())
13288
13289 def test_lhs_indexing_multi(self):
13290 def foo(a, b):
13291 a = a.clone()
13292 foo, a[0], bar = (1, b, 3)
13293 return foo, a, bar
13294 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13295
David Riazatid75f7512018-11-27 19:33:47 -080013296 def test_bool_dispatch(self):
Michael Suoca1b8eb2020-07-13 16:57:41 -070013297 with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list
David Riazatia23863f2018-12-03 23:49:39 -080013298 def kwarg_false(x):
13299 # type: (Tensor) -> Tensor
13300 return F.max_pool1d(x, 1, 1, return_indices=False)
13301 self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013302
David Riazatia23863f2018-12-03 23:49:39 -080013303 def kwarg_true(x):
13304 # type: (Tensor) -> Tuple[Tensor, Tensor]
13305 return F.max_pool1d(x, 1, 1, return_indices=True)
13306 self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013307
David Riazatia23863f2018-12-03 23:49:39 -080013308 def full_kwarg_false(x):
13309 # type: (Tensor) -> Tensor
13310 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
13311 self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013312
David Riazatia23863f2018-12-03 23:49:39 -080013313 def full_kwarg_true(x):
13314 # type: (Tensor) -> Tuple[Tensor, Tensor]
13315 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
13316 self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013317
David Riazatia23863f2018-12-03 23:49:39 -080013318 def use_default(x):
13319 # type: (Tensor) -> Tensor
13320 return F.max_pool1d(x, 1, 1)
13321 self.checkScript(use_default, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013322
David Riazatia23863f2018-12-03 23:49:39 -080013323 def arg_false(x):
13324 # type: (Tensor) -> Tensor
13325 return F.max_pool1d(x, 1, 1, 0, 1, False, False)
13326 self.checkScript(arg_false, (torch.randn(3, 3, 3),))
David Riazatid75f7512018-11-27 19:33:47 -080013327
David Riazatia23863f2018-12-03 23:49:39 -080013328 def arg_true(x):
13329 # type: (Tensor) -> Tuple[Tensor, Tensor]
13330 return F.max_pool1d(x, 1, 1, 0, 1, False, True)
13331 self.checkScript(arg_true, (torch.randn(3, 3, 3),))
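        # All of the checks above go through boolean dispatch: the return type of
        # F.max_pool1d depends on return_indices, so the script compiler selects between
        # the Tensor and (Tensor, Tensor) overloads from the compile-time constant value
        # of that flag.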
David Riazatid75f7512018-11-27 19:33:47 -080013332
David Riazati89c3dbc2018-11-29 22:18:43 -080013333 def test_infer_size(self):
13334 from torch._C import _infer_size
13335
13336 def fn(x, y):
13337 # type: (Tensor, Tensor) -> List[int]
13338 return _infer_size(x.size(), y.size())
13339
13340 self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
13341
David Riazati7b0ef312019-03-29 18:23:28 -070013342 def test_hash(self):
13343 def tester(fn, inputs):
13344 for x in inputs:
13345 for y in inputs:
13346 if x == y:
13347 self.assertEqual(fn(x), fn(y))
13348 else:
13349 self.assertNotEqual(fn(x), fn(y))
13350
13351 @torch.jit.script
13352 def int_hash(x):
13353 # type: (int) -> int
13354 return hash(x)
13355
13356 @torch.jit.script
13357 def float_hash(x):
13358 # type: (float) -> int
13359 return hash(x)
13360
13361 @torch.jit.script
13362 def str_hash(x):
13363 # type: (str) -> int
13364 return hash(x)
13365
13366 tester(int_hash, (20, 21, 22))
13367 tester(float_hash, (20.0, 21.00001, 22.443))
13368 tester(str_hash, ("", "hello", "a"))
13369
Elias Ellison7ab25b22020-03-23 17:01:56 -070013370 def test_id(self):
Ailing Zhang77bbbf02020-03-26 20:18:12 -070013371 with self.assertRaisesRegex(RuntimeError, "Expected a value"):
Elias Ellison7ab25b22020-03-23 17:01:56 -070013372 @torch.jit.script
13373 def test_id_scalars():
Ailing Zhang77bbbf02020-03-26 20:18:12 -070013374 return id(2) == id(None)
Elias Ellison7ab25b22020-03-23 17:01:56 -070013375
13376 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +000013377 class FooTest:
Elias Ellison7ab25b22020-03-23 17:01:56 -070013378 def __init__(self, x):
13379 self.foo = x
13380
13381 def getFooTest(self):
13382 return self.foo
13383
13384 @torch.jit.script
13385 def test_id_class_types():
13386 obj1 = FooTest(torch.tensor(3))
13387 obj2 = FooTest(torch.tensor(2))
13388 assert obj1 is not obj2
13389 assert id(obj1) != id(obj2)
13390 assert id(obj1) != id(None)
13391 return True
13392
13393 self.assertTrue(test_id_class_types())
13394
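    # The mutable-DCE tests below inspect the optimized graph with FileCheck: check()
    # asserts that a pattern appears, check_not() that it does not, and
    # check_count(..., exactly=True) that it appears an exact number of times; run()
    # applies the accumulated checks to the printed IR. This pins down which ops survive
    # dead code elimination in the presence of mutation.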
Michael Suob768db02018-12-03 13:27:59 -080013395 def test_mutable_dce(self):
13396 @torch.jit.script
13397 def foo():
13398 a = torch.rand(2, 3)
13399 a += torch.rand(2, 3)
13400 b = torch.rand(2, 3)
13401 b += torch.rand(2, 3)
13402 # b should be cleaned up but not a
13403 return a
13404
Shen Li10224432021-08-12 11:39:31 -070013405 FileCheck().check_count("aten::rand", 2, exactly=True) \
13406 .check_count("aten::add", 1, exactly=True).run(str(foo.graph))
Michael Suob768db02018-12-03 13:27:59 -080013407
13408 def test_mutable_dce_block(self):
13409 @torch.jit.script
13410 def foo():
13411 a = torch.rand(2, 3)
13412 a += torch.rand(2, 3)
13413 b = torch.rand(2, 3)
13414 if bool(a > torch.zeros(2, 3)):
13415 b += torch.rand(2, 3)
13416 a += torch.rand(2, 3)
13417 # a should be cleaned up but not b
13418 return b
13419
Shen Li10224432021-08-12 11:39:31 -070013420 FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
13421 .run(str(foo.graph))
Michael Suob768db02018-12-03 13:27:59 -080013422
13423 def test_mutable_dce_graph_input(self):
13424 @torch.jit.script
13425 def foo(a):
13426 a += torch.rand(2, 3)
13427 # shouldn't clean up `a` even though it's not used in the output
13428
Elias Ellison411cf432019-02-25 16:11:47 -080013429 FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
Michael Suob768db02018-12-03 13:27:59 -080013430
13431 def test_mutable_dce_list(self):
13432 @torch.jit.script
13433 def foo(a):
13434 l = []
13435 l.append(a)
13436 c = l[0]
13437 b = torch.rand(2, 3)
13438 c += torch.rand(2, 3)
13439 return b
13440
Elias Ellison411cf432019-02-25 16:11:47 -080013441 # c does not get cleaned up because there is a wildcard + mutation
13442 FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
Michael Suob768db02018-12-03 13:27:59 -080013443
13444 def test_mutable_dce_loop(self):
13445 @torch.jit.script
13446 def foo(a):
13447 l = []
13448 l.append(a)
13449 i = 0
13450 b = torch.rand(2, 3)
13451 while i < 1:
13452 dead = torch.rand(2, 3)
13453 c = l[0]
13454 c += torch.rand(2, 3)
13455 i += 1
13456 return b
13457
Shen Li10224432021-08-12 11:39:31 -070013458 FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \
13459 .check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
Michael Suob768db02018-12-03 13:27:59 -080013460
Michael Suo194acd02019-08-29 23:30:42 -070013461 def test_mutable_dce_indirect_wildcards(self):
13462 def fn():
13463 x = torch.ones(2, 3)
13464 x_1 = x.view(-1)
13465 l = []
13466 l.append(x_1)
13467 x_view = l[0]
13468 x.add_(torch.ones(2, 3))
13469 return x_view
13470 self.checkScript(fn, ())
13471
13472 def test_mutable_dce_indirect_wildcard_write(self):
13473 def fn():
13474 indexes = torch.jit.annotate(List[Tensor], [])
13475 word_ids = torch.zeros(10, dtype=torch.int32)
13476 word_ids[1] = 1
13477 indexes.append(word_ids)
13478
13479 return word_ids
13480 self.checkScript(fn, ())
13481
Michael Suodc84ff12019-01-30 11:06:32 -080013482 def test_mutable_dce_wildcards(self):
13483 def fn():
13484 x = torch.ones(2, 3)
13485 l = []
13486 l.append(x)
13487 x_view = l[0]
13488 x.add_(torch.ones(2, 3))
13489 return x_view
13490
Nikolay Korovaiko5b702ab2019-11-11 13:39:03 -080013491 self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE)
Michael Suodc84ff12019-01-30 11:06:32 -080013492
David Riazati1dbc7cf2018-12-17 13:08:03 -080013493 def test_cpp_function_tensor_str(self):
13494 x = torch.randn(2, 2)
13495 scale = torch.randn(2, 2, requires_grad=True)
13496 shift = torch.randn(2, 2, requires_grad=True)
13497
13498 @torch.jit.script
13499 def fn(x, scale, shift):
13500 return scale * x + shift
13501
13502 with self.capture_stdout() as captured:
13503 print(fn(x, scale, shift))
13504
David Riazatibe01c902019-04-01 11:58:28 -070013505 def test_string_index(self):
13506 def fn(x):
Elias Ellison19d3a7a2019-11-07 17:26:31 -080013507 # type: (str)
13508 return x[2], x[-1]
David Riazatibe01c902019-04-01 11:58:28 -070013509
13510 self.checkScript(fn, ("abcde",))
13511
13512 def test_ord(self):
13513 def fn(x):
13514 # type: (str) -> int
13515 return ord(x)
13516
13517 self.checkScript(fn, ("h"))
13518 self.checkScript(fn, ("y"))
13519
David Riazatie7b95262019-04-19 10:20:43 -070013520 def index_str_to_tensor(s):
Ansley Ussery6831d8e2021-09-03 06:10:37 -070013521 # type: (str) -> Tensor
13522 return torch.tensor(ord(s)) # noqa: T484
David Riazatie7b95262019-04-19 10:20:43 -070013523
Shen Li10224432021-08-12 11:39:31 -070013524 s = u'\u00a3'.encode('utf8')[:1]
David Riazatie7b95262019-04-19 10:20:43 -070013525 self.checkScript(index_str_to_tensor, (s,))
13526
Ailing Zhangff1172d2019-06-16 09:45:14 -070013527 def test_chr(self):
13528 def fn(x):
13529 # type: (int) -> str
13530 return chr(x)
13531
13532 self.checkScript(fn, (1,))
13533 self.checkScript(fn, (97,))
13534
13535 def test_round(self):
13536 def round_float(x):
13537 # type: (float) -> float
13538 return round(x)
13539
13540 def round_int(x):
13541 # type: (int) -> float
13542 return round(x)
13543
13544 self.checkScript(round_float, (1.5,))
13545 self.checkScript(round_int, (2,))
13546
Ailing Zhangff1172d2019-06-16 09:45:14 -070013547 def test_convert_base(self):
13548 def test_hex(x):
13549 # type: (int) -> str
13550 return hex(x)
13551
13552 def test_oct(x):
13553 # type: (int) -> str
13554 return oct(x)
13555
13556 def test_bin(x):
13557 # type: (int) -> str
13558 return bin(x)
13559
13560 numbers = [-1000, -10, 0, 1, 10, 2343]
13561 for n in numbers:
13562 self.checkScript(test_bin, (n,))
13563 self.checkScript(test_oct, (n,))
13564 self.checkScript(test_hex, (n,))
13565
Shen Li10224432021-08-12 11:39:31 -070013566 @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
davidriazaticd28ff52019-05-17 14:29:37 -070013567 def test_get_set_state(self):
Michael Suo77c08aa2019-08-11 15:43:28 -070013568 class Root(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070013569 __constants__ = ['number']
davidriazaticd28ff52019-05-17 14:29:37 -070013570
Michael Suo77c08aa2019-08-11 15:43:28 -070013571 def __init__(self, number):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013572 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070013573 self.register_buffer('buffer1', torch.ones(2, 2))
13574 self.register_buffer('buffer2', torch.ones(2, 2))
davidriazaticd28ff52019-05-17 14:29:37 -070013575 self.number = number
davidriazaticd28ff52019-05-17 14:29:37 -070013576
13577 @torch.jit.script_method
13578 def __getstate__(self):
davidriazati00460922019-10-07 13:50:23 -070013579 return (self.buffer1, self.buffer2, 74, self.training)
davidriazaticd28ff52019-05-17 14:29:37 -070013580
13581 @torch.jit.script_method
13582 def __setstate__(self, state):
davidriazaticd28ff52019-05-17 14:29:37 -070013583 self.buffer1 = state[0] + 10
13584 self.buffer2 = state[1] + 10
davidriazati00460922019-10-07 13:50:23 -070013585 self.training = state[3]
davidriazaticd28ff52019-05-17 14:29:37 -070013586
Michael Suo77c08aa2019-08-11 15:43:28 -070013587 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070013588 __constants__ = ['number']
Michael Suo77c08aa2019-08-11 15:43:28 -070013589
13590 def __init__(self, number, submodule):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013591 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070013592 self.register_buffer('buffer1', torch.ones(2, 2))
13593 self.register_buffer('buffer2', torch.ones(2, 2))
Michael Suo77c08aa2019-08-11 15:43:28 -070013594 self.number = number
13595 self.submodule = submodule
13596
13597 @torch.jit.script_method
13598 def __getstate__(self):
davidriazati00460922019-10-07 13:50:23 -070013599 return (self.buffer1, self.buffer2, 74, self.submodule, self.training)
Michael Suo77c08aa2019-08-11 15:43:28 -070013600
13601 @torch.jit.script_method
13602 def __setstate__(self, state):
13603 self.buffer1 = state[0] + 10
13604 self.buffer2 = state[1] + 10
13605 self.submodule = state[3]
davidriazati00460922019-10-07 13:50:23 -070013606 self.training = state[4]
Michael Suo77c08aa2019-08-11 15:43:28 -070013607
davidriazaticd28ff52019-05-17 14:29:37 -070013608 with TemporaryFileName() as fname:
Michael Suo77c08aa2019-08-11 15:43:28 -070013609 m = M(23, submodule=Root(99))
davidriazaticd28ff52019-05-17 14:29:37 -070013610 m.save(fname)
13611 loaded = torch.jit.load(fname)
13612
13613 # Check original module
13614 self.assertEqual(m.buffer1, torch.ones(2, 2))
13615 self.assertEqual(m.buffer2, torch.ones(2, 2))
13616
13617 # Check top level module
13618 self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10)
13619 self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13620
13621 # Check submodule
13622 self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10)
13623 self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10)
13624
davidriazatifcdfc352019-07-22 12:25:09 -070013625 # Check simpler module
13626 class NoArgState(torch.nn.Module):
13627 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000013628 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070013629 self.register_buffer('buffer1', torch.ones(2, 2))
13630 self.register_buffer('buffer2', torch.ones(2, 2))
davidriazatifcdfc352019-07-22 12:25:09 -070013631
13632 def forward(self):
13633 pass
13634
13635 @torch.jit.export
13636 def __getstate__(self):
Michael Suo34126272019-10-12 09:49:56 -070013637 return 5, self.training
davidriazatifcdfc352019-07-22 12:25:09 -070013638
13639 @torch.jit.export
Michael Suo77c08aa2019-08-11 15:43:28 -070013640 def __setstate__(self, state):
davidriazati00460922019-10-07 13:50:23 -070013641 self.buffer1 = torch.ones(2, 2) + state[0]
davidriazatifcdfc352019-07-22 12:25:09 -070013642 self.buffer2 = torch.ones(2, 2) + 10
davidriazati00460922019-10-07 13:50:23 -070013643 self.training = state[1]
davidriazatifcdfc352019-07-22 12:25:09 -070013644
13645 with TemporaryFileName() as fname:
13646 m = torch.jit.script(NoArgState())
13647 m.save(fname)
13648 loaded = torch.jit.load(fname)
Michael Suo77c08aa2019-08-11 15:43:28 -070013649 self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5)
davidriazatifcdfc352019-07-22 12:25:09 -070013650 self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13651
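    # A condensed sketch of the __getstate__/__setstate__ contract exercised above:
    # whatever __getstate__ returns is what gets serialized, and __setstate__ rebuilds
    # the module's attributes from that state on load. The helper below is illustrative
    # only (not part of the original suite) and is not collected as a test.
    def _sketch_get_set_state_roundtrip(self):
        class Scaled(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer('scale', torch.ones(2))

            def forward(self, x):
                return x * self.scale

            @torch.jit.export
            def __getstate__(self):
                return (self.scale, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                # The stored state can be transformed before being reassigned.
                self.scale = state[0] * 2
                self.training = state[1]

        with TemporaryFileName() as fname:
            torch.jit.script(Scaled()).save(fname)
            loaded = torch.jit.load(fname)
            # __setstate__ doubled the buffer while the module was being loaded.
            self.assertEqual(loaded.scale, torch.ones(2) * 2)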
13653
David Riazatibe01c902019-04-01 11:58:28 -070013654 def test_string_slicing(self):
13655 def fn1(x):
13656 # type: (str) -> str
13657 return x[1:3]
13658
13659 def fn2(x):
13660 # type: (str) -> str
13661 return x[-1:3]
13662
13663 def fn3(x):
13664 # type: (str) -> str
13665 return x[3:1]
13666
13667 def fn4(x):
13668 # type: (str) -> str
13669 return x[3:100]
13670
13671 self.checkScript(fn1, ("abcdefghi",))
13672 self.checkScript(fn2, ("abcdefghi",))
13673 self.checkScript(fn3, ("abcdefghi",))
13674 self.checkScript(fn4, ("abcdefghi",))
13675
Elias Ellisonca76c822019-07-26 16:36:55 -070013676 def test_early_return_closure(self):
Shen Li10224432021-08-12 11:39:31 -070013677 code = dedent('''
Elias Ellisonca76c822019-07-26 16:36:55 -070013678 def tanh(self):
13679 output = torch.tanh(self)
13680 def backward(grad_output):
13681 pass
13682 return output, backward
Shen Li10224432021-08-12 11:39:31 -070013683 ''')
Elias Ellisonca76c822019-07-26 16:36:55 -070013684 cu = torch.jit.CompilationUnit(code)
13685 g = cu.tanh.graph
Shen Li10224432021-08-12 11:39:31 -070013686 FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \
13687 .check_next("return").run(g)
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013688
Shen Li10224432021-08-12 11:39:31 -070013689 code = dedent('''
Elias Ellisonca76c822019-07-26 16:36:55 -070013690 def tanh(self):
13691 output = torch.tanh(self)
13692 def backward(grad_output):
13693 a = 1
Elias Ellisondaa85cf2020-05-15 12:20:13 -070013694 if output:
Elias Ellisonca76c822019-07-26 16:36:55 -070013695 return 1
13696 else:
13697 a = 2
13698 return a
13699 return output, backward
Shen Li10224432021-08-12 11:39:31 -070013700 ''')
Elias Ellisonca76c822019-07-26 16:36:55 -070013701 cu = torch.jit.CompilationUnit(code)
13702 g = cu.tanh.graph
Shen Li10224432021-08-12 11:39:31 -070013703 FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \
13704 .run(g)
Elias Ellisonca76c822019-07-26 16:36:55 -070013705
Shen Li10224432021-08-12 11:39:31 -070013706 code = dedent('''
Elias Ellisonca76c822019-07-26 16:36:55 -070013707 def loop_in_closure(self):
13708 output = torch.tanh(self)
13709 def backward(grad_output):
13710 for i in range(3):
13711 return 1
13712 return 4
13713 return output, backward
Shen Li10224432021-08-12 11:39:31 -070013714 ''')
Elias Ellisonca76c822019-07-26 16:36:55 -070013715 cu = torch.jit.CompilationUnit(code)
13716 fc = FileCheck()
James Reed68e07962021-04-12 17:33:20 -070013717 fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct")
Elias Ellisonca76c822019-07-26 16:36:55 -070013718 # Loop then two if's added in exit transform
Michael Suodc817632020-10-28 16:23:58 -070013719 fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2)
Elias Ellisonca76c822019-07-26 16:36:55 -070013720 fc.run(cu.loop_in_closure.graph)
13721
Shen Li10224432021-08-12 11:39:31 -070013722 code = dedent('''
Elias Ellisonca76c822019-07-26 16:36:55 -070013723 def tanh(self):
13724 output = torch.tanh(self)
13725 def backward(grad_output):
Elias Ellisond1b8da72020-11-20 11:14:59 -080013726 if 1 == 1:
Elias Ellisonca76c822019-07-26 16:36:55 -070013727 return 1
13728 else:
13729 return 1.
13730 return output, backward
Shen Li10224432021-08-12 11:39:31 -070013731 ''')
Elias Ellisonca76c822019-07-26 16:36:55 -070013732 with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"):
13733 cu = torch.jit.CompilationUnit(code)
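        # Nested functions are lowered to prim::Closure subgraphs, and like any scripted
        # function a closure must return a single consistent type along every exit path,
        # which is the error the case above triggers.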
13734
Michael Suo755f91b2019-08-19 18:41:08 -070013735 @_inline_everything
Elias Ellisonca76c822019-07-26 16:36:55 -070013736 def test_early_return_fork_join(self):
13737 @torch.jit.script
13738 def foo(x):
13739 if x.dim() == 2:
13740 return torch.neg(x), x
13741 else:
13742 return torch.neg(x), x + 1
13743
13744 x = torch.rand(3, 4)
13745
13746 @torch.jit.script
13747 def wait_script(x):
13748 fut = torch.jit._fork(foo, x)
13749 y_hat = foo(x)
13750 y = torch.jit._wait(fut)
13751 return y, y_hat
13752
Shen Li10224432021-08-12 11:39:31 -070013753 FileCheck().check("with prim::fork").check("prim::If").check("return")\
13754 .run(wait_script.graph)
Elias Ellisonca76c822019-07-26 16:36:55 -070013755
13756 def test_early_return_type_refinement(self):
13757 @torch.jit.script
13758 def test(x):
13759 # type: (Optional[int]) -> int
13760 if x is None:
13761 return 1
13762 else:
13763 return x
13764 self.assertEqual(test(None), 1)
13765 self.assertEqual(test(2), 2)
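        # The `x is None` check refines the type of `x` to `int` on the other path, so the
        # bare `return x` above satisfies the declared int return type without a cast.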
13766
Elias Ellisoned4ee092019-08-07 08:54:14 -070013767 def test_exceptions_with_control_flow(self):
13768 def test_num_ifs(func, num_ifs):
13769 g = torch.jit.script(func).graph
13770 FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g)
13771
13772 def no_guard_ifs_added(x):
13773 # type: (int) -> int
13774 if x == 1:
13775 return 1
13776 else:
13777 if x == 2:
13778 raise RuntimeError("hi")
13779 else:
13780 raise RuntimeError("hi")
13781
13782 self.checkScript(no_guard_ifs_added, (1,))
13783 self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "")
13784 test_num_ifs(no_guard_ifs_added, 2)
13785
13786 # FUNCTION LOOKS LIKE:
13787 # graph(%x.1 : int):
13788 # %7 : str = prim::Constant[value="Exception"]()
13789 # %2 : int = prim::Constant[value=1]()
13790 # %5 : int = prim::Constant[value=2]()
13791 # %19 : int = prim::Uninitialized()
13792 # %3 : bool = aten::eq(%x.1, %2)
13793 # %20 : int = prim::If(%3)
13794 # block0():
13795 # -> (%2)
13796 # block1():
13797 # %6 : bool = aten::eq(%x.1, %5)
13798 # = prim::If(%6)
13799 # block0():
13800 # = prim::RaiseException(%7)
13801 # -> ()
13802 # block1():
13803 # = prim::RaiseException(%7)
13804 # -> ()
13805 # -> (%19)
13806 # return (%20)
13807
13808 def no_ifs_added(x):
13809 # type: (int) -> int
13810 if x < 0:
Yanan Caobdcf3202020-08-01 13:02:20 -070013811 raise RuntimeError("hi")
Elias Ellisoned4ee092019-08-07 08:54:14 -070013812 return x
13813
13814 self.checkScript(no_ifs_added, (1,))
13815 self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13816 test_num_ifs(no_ifs_added, 1)
13817
13818 def test_if_might(x):
13819 # type: (int)
13820 if x > 0:
13821 if x == 1:
13822 return 1
13823 else:
13824 a = 2
13825 else:
Yanan Caobdcf3202020-08-01 13:02:20 -070013826 raise RuntimeError("hi")
Elias Ellisoned4ee092019-08-07 08:54:14 -070013827 return a + 2
13828
13829 self.checkScript(test_if_might, (1,))
13830 self.checkScript(test_if_might, (3,))
13831 self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13832 test_num_ifs(test_if_might, 3) # one if added to guard a + 2
13833
13834 def test_loop_no_escape(x):
13835 # type: (int)
13836 if x >= 0:
13837 for i in range(x):
Yanan Caobdcf3202020-08-01 13:02:20 -070013838 raise RuntimeError("hi")
Elias Ellisoned4ee092019-08-07 08:54:14 -070013839 else:
13840 return 5
13841 return x + 3
13842
13843 self.checkScript(test_loop_no_escape, (0,))
13844 self.checkScript(test_loop_no_escape, (-1,))
13845 self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")
13846
Elias Ellisondaa85cf2020-05-15 12:20:13 -070013847 # if guard gets optimized away
13848 test_num_ifs(test_loop_no_escape, 1)
Elias Ellisoned4ee092019-08-07 08:54:14 -070013849
13850 def test_loop_exception_with_continue(x):
13851 # type: (int)
13852 i = 0
13853 for i in range(5):
13854 if i == x:
Yanan Caobdcf3202020-08-01 13:02:20 -070013855 raise RuntimeError("hi")
Elias Ellisoned4ee092019-08-07 08:54:14 -070013856 else:
13857 continue
13858 print(i)
13859 return i + 5
13860
13861 self.checkScript(test_loop_exception_with_continue, (-1,))
Shen Li10224432021-08-12 11:39:31 -070013862 self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "")
13863 test_num_ifs(test_loop_exception_with_continue, 1) # no ifs added to guard print
13864
Elias Ellisoned4ee092019-08-07 08:54:14 -070013865
13866 def test_exception_exits_closure(self):
Shen Li10224432021-08-12 11:39:31 -070013867 code = dedent('''
Elias Ellisoned4ee092019-08-07 08:54:14 -070013868 def no_return_func(self):
13869 # type: (Tensor) -> Tensor
13870 output = torch.tanh(self)
13871 def backward(grad_output):
Yanan Caobdcf3202020-08-01 13:02:20 -070013872 raise RuntimeError("Hi")
Shen Li10224432021-08-12 11:39:31 -070013873 ''')
Elias Ellisoned4ee092019-08-07 08:54:14 -070013874 with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13875 cu = torch.jit.CompilationUnit(code)
13876
Shen Li10224432021-08-12 11:39:31 -070013877 code = dedent('''
Elias Ellisoned4ee092019-08-07 08:54:14 -070013878 def test_exit_pair_reset(x):
13879 # type: (int) -> int
13880 if x > 0:
13881 a = 0
13882 def backward(grad_output):
Yanan Caobdcf3202020-08-01 13:02:20 -070013883 raise RuntimeError("Hi")
Elias Ellison011db3b2019-08-27 22:15:27 -070013884 a = a + 1
Elias Ellisoned4ee092019-08-07 08:54:14 -070013885 else:
13886 return x
13887 return a + 1
Shen Li10224432021-08-12 11:39:31 -070013888 ''')
Elias Ellisoned4ee092019-08-07 08:54:14 -070013889 func = torch.jit.CompilationUnit(code).test_exit_pair_reset
Shen Li10224432021-08-12 11:39:31 -070013890 self.assertEqual(func(1,), 2)
13891 self.assertEqual(func(-1,), -1)
Elias Ellisondaa85cf2020-05-15 12:20:13 -070013892 # final a + 1 gets inlined into the first branch and optimized away
13893 FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)
Elias Ellisoned4ee092019-08-07 08:54:14 -070013894
Elias Ellisonca76c822019-07-26 16:36:55 -070013895 def test_non_final_return(self):
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013896 def simple(x):
13897 if bool(x > 3):
13898 return x + 1
13899 else:
13900 return x + 2
13901 raise RuntimeError("nope")
13902
13903 def nest(x):
13904 x = x + 1
13905 if bool(x > 3):
13906 if bool(x > 4):
13907 x += 1
13908 return x + 1
13909 else:
13910 return x + 2
13911
13912 def early_ret(x):
13913 x = x + 1
13914 if bool(x > 3):
13915 return x + 1
13916 x = x + 1
13917 return x + 2
13918
13919 def nest_early_ret(x):
13920 x = x + 1
13921 if bool(x > 3):
13922 if bool(x > 4):
13923 return x + 2
13924 return x + 1
13925 x = x + 1
13926 return x + 2
13927
Elias Ellisonca76c822019-07-26 16:36:55 -070013928 def not_early_ret(x):
13929 s = ""
13930 if bool(x > 3):
13931 if bool(x > 4):
13932 return 1, s
13933 s += "foo"
13934 else:
13935 s += "5"
13936 s += "hi"
13937 return 7, s
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013938
Elias Ellisonca76c822019-07-26 16:36:55 -070013939 def not_total_ret(x):
13940 s = ""
13941 if bool(x > 3):
13942 if bool(x > 4):
13943 return 1, s
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013944 else:
Elias Ellisonca76c822019-07-26 16:36:55 -070013945 return 2, s
13946 else:
13947 s += "5"
13948 return 7, s
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013949
Elias Ellisonca76c822019-07-26 16:36:55 -070013950 for i in range(3):
Shen Li10224432021-08-12 11:39:31 -070013951 for func in [simple, nest, early_ret, nest_early_ret, not_early_ret,
13952 not_total_ret]:
Elias Ellisonca76c822019-07-26 16:36:55 -070013953 self.checkScript(func, (torch.tensor(2.5 + i),))
13954
13955 def vars_used_after_ret(x):
13956 # type: (int) -> int
13957 if x == 0:
13958 return x
13959 else:
13960 y = 2
13961 z = 3
13962 return x + y * z
13963
13964 self.checkScript(vars_used_after_ret, (1,))
13965 self.checkScript(vars_used_after_ret, (0,))
13966
13967 def complicated(x):
13968 # type: (int) -> int
13969 if x:
13970 if x == 2:
13971 return 1
13972 assert 1 == 2
13973 else:
13974 if x == 3:
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080013975 return 2
Elias Ellisonca76c822019-07-26 16:36:55 -070013976 assert 1 == 2
13977 else:
13978 a = 2
13979 b = 3
13980 else:
13981 a = 4
13982 b = 1
13983 return a + b
13984 assert 1 == 2
13985
13986 for i in range(4):
13987 self.checkScript(complicated, (i,))
13988
Elias Ellisonca76c822019-07-26 16:36:55 -070013989 def test_partial_returns(self):
13990 with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13991 @torch.jit.script
13992 def no_ret():
13993 # type: () -> int
13994 pass
13995
13996 with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13997 @torch.jit.script
Sam Estepe3900d22021-04-19 13:14:27 -070013998 def partial(x):
Elias Ellisonca76c822019-07-26 16:36:55 -070013999 # type: (Tensor) -> int
14000 if x:
14001 return 1
14002
14003 with self.assertRaisesRegex(RuntimeError, "does not return along all"):
14004 @torch.jit.script
Sam Estepe3900d22021-04-19 13:14:27 -070014005 def typed_none():
Elias Ellisonca76c822019-07-26 16:36:55 -070014006 # type: () -> Optional[int]
14007 pass
14008
14009 @torch.jit.script
14010 def none_ret():
14011 pass
14012
14013 self.assertIs(none_ret(), None)
14014 FileCheck().check(": None").run(none_ret.graph)
14015
14016 def test_early_returns_loops(self):
14017 def nest_while_ret(x):
14018 # type: (int) -> int
14019 y = 4
14020 while x < 4:
14021 if x < 3:
14022 return y
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080014023 else:
Elias Ellisonca76c822019-07-26 16:36:55 -070014024 y = y + 1
14025 break
14026 y = y + 2
14027 y = y + 1
14028 return y
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080014029
Elias Ellisonca76c822019-07-26 16:36:55 -070014030 self.checkScript(nest_while_ret, (2,))
14031 self.checkScript(nest_while_ret, (3,))
14032 self.checkScript(nest_while_ret, (4,))
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080014033
Elias Ellisonca76c822019-07-26 16:36:55 -070014034 def loop_ret(x, y):
14035 # type: (int, int) -> (int)
14036 i = 0
14037 for i in range(x):
14038 if x == y:
14039 return x + y
14040 i = i + y
14041 i = i - 1
14042 return i
14043
14044 self.checkScript(loop_ret, (3, 3))
14045 self.checkScript(loop_ret, (2, 3))
14046 self.checkScript(loop_ret, (3, 1))
14047
14048 def test_will_ret(y):
14049 # type: (int) -> int
14050 for i in range(y):
14051 return 2
14052 return 1
14053
14054 self.checkScript(test_will_ret, (0,))
14055 self.checkScript(test_will_ret, (1,))
14056
14057 def test_loop_nest_ret(y):
14058 # type: (int) -> int
14059 for i in range(y):
14060 for i in range(y - 2):
14061 return 10
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080014062 return 5
Elias Ellisonca76c822019-07-26 16:36:55 -070014063 return 0
14064
14065 self.checkScript(test_loop_nest_ret, (0,))
14066 self.checkScript(test_loop_nest_ret, (1,))
14067 self.checkScript(test_loop_nest_ret, (2,))
Zachary DeVito6bf05bf2018-12-21 13:46:12 -080014068
davidriazatic08f3d02019-04-24 16:44:29 -070014069 def test_nn_init(self):
14070 tests = (
Shen Li10224432021-08-12 11:39:31 -070014071 ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"),
14072 ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14073 ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14074 ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14075 ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14076 ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14077 ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
davidriazatic08f3d02019-04-24 16:44:29 -070014078 )
14079
14080 for name, args_fn, type_str in tests:
14081 # Build test code
Shen Li10224432021-08-12 11:39:31 -070014082 arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))])
davidriazatic08f3d02019-04-24 16:44:29 -070014083
Shen Li10224432021-08-12 11:39:31 -070014084 code = dedent('''
davidriazatic08f3d02019-04-24 16:44:29 -070014085 def test({arg_str}):
14086 # type: ({type_str})
14087 return torch.nn.init.{name}({arg_str})
Shen Li10224432021-08-12 11:39:31 -070014088 ''').format(arg_str=arg_str, type_str=type_str, name=name)
davidriazatic08f3d02019-04-24 16:44:29 -070014089 cu = torch.jit.CompilationUnit(code)
14090
14091 # Compare functions
14092 init_fn = getattr(torch.nn.init, name)
14093 script_out = self.runAndSaveRNG(cu.test, args_fn())
14094 eager_out = self.runAndSaveRNG(init_fn, args_fn())
14095 self.assertEqual(script_out, eager_out)
14096
14097 FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
14098
Elias Ellisondaa85cf2020-05-15 12:20:13 -070014099 def test_early_return_rewrite(self):
14100 def test_foo(x: bool):
14101 if x:
14102 return 1
14103 return 2
14104
14105 self.checkScript(test_foo, (True,))
14106 self.checkScript(test_foo, (False,))
Shen Li10224432021-08-12 11:39:31 -070014107 FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)
Elias Ellisondaa85cf2020-05-15 12:20:13 -070014108
14109 def test_multiple(x: int):
14110 if x == 5:
14111 return x * x
14112 else:
14113 y = 2 * x
14114
14115 z = y * 2
14116 if z == 8:
14117 return 1
14118
14119 if z != 16:
14120 z = z - 2
14121 abc = 4
14122 else:
14123 return 3
14124
14125 z = z * abc
14126 return z * z * z
14127
14128 self.checkScript(test_multiple, (5,))
14129 self.checkScript(test_multiple, (2,))
14130 self.checkScript(test_multiple, (4,))
14131 self.checkScript(test_multiple, (3,))
14132 self.checkScript(test_multiple, (10,))
14133
14134 graph = torch.jit.script(test_multiple).graph
14135 FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
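        # Early returns are implemented by an exit transform: code that follows a
        # conditional `return` is wrapped in guarding prim::If blocks, and the FileCheck
        # counts above track how many of those guards remain after optimization.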
Elias Ellisondaa85cf2020-05-15 12:20:13 -070014136
Elias Ellison040bc1d2020-01-31 18:19:53 -080014137 def test_is_scripting_metacompile(self):
14138 @torch.jit.script
14139 def foo():
14140 if torch.jit.is_scripting():
14141 return 1
14142 else:
Yuxin Wu9a007ba2020-09-17 18:07:07 -070014143 print("hello") + 2 # will not be compiled
Elias Ellison040bc1d2020-01-31 18:19:53 -080014144
14145 self.assertEqual(foo(), 1)
14146
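    # Because the untaken branch is never compiled, torch.jit.is_scripting() is commonly
    # used to fence Python-only code off from the compiler. A minimal sketch follows; the
    # helper is illustrative only and is not collected as a test.
    def _sketch_is_scripting_guard(self):
        def add_one(x):
            if torch.jit.is_scripting():
                return x + 1
            else:
                # Arbitrary Python is allowed here; this branch is dropped when scripting.
                print("running eagerly on", x)
                return x + 1

        scripted = torch.jit.script(add_one)
        self.assertEqual(scripted(torch.tensor(1)), torch.tensor(2))
        self.assertEqual(add_one(torch.tensor(1)), torch.tensor(2))
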
Elias Ellisond1b8da72020-11-20 11:14:59 -080014147 def test_boolean_literal_constant_metacompile(self):
14148 class Mod(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -070014149 __constants__ = ['val']
Elias Ellisond1b8da72020-11-20 11:14:59 -080014150
14151 def __init__(self, val):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014152 super().__init__()
Elias Ellisond1b8da72020-11-20 11:14:59 -080014153 self.val = val
14154
14155 def forward(self):
14156 if self.val:
14157 return 1
14158 else:
14159 return "2"
14160
14161 self.checkModule(Mod(True), ())
14162 self.checkModule(Mod(False), ())
14163
14164 @torch.jit.script
14165 def foo():
14166 if True:
14167 return 1
14168 else:
14169 return "2"
14170
14171 self.assertEqual(foo(), 1)
14172
Yuxin Wu9a007ba2020-09-17 18:07:07 -070014173 def test_assert_is_scripting_metacompile(self):
14174 def foo():
14175 assert not torch.jit.is_scripting(), "TestErrorMsg"
14176 print("hello") + 2 # will not be compiled
14177
14178 f = torch.jit.script(foo)
14179 with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"):
14180 f()
14181
Elias Ellison9ecc33d2019-08-07 19:11:55 -070014182 def test_isinstance_metacompile(self):
14183 @torch.jit.script
14184 def test_primitive_type(x):
14185 # type: (int) -> int
14186 if isinstance(x, int):
14187 return x + 1
14188 else:
14189 return x - 1
14190
14191 self.assertEqual(test_primitive_type(1), 2)
14192 with self.assertRaisesRegex(Exception, "Expected a value of type"):
14193 test_primitive_type(1.5)
14194
Shen Li10224432021-08-12 11:39:31 -070014195 _MyNamedTuple = namedtuple('_MyNamedTuple', ['value'])
Elias Ellison9ecc33d2019-08-07 19:11:55 -070014196
14197 @torch.jit.script
14198 def test_non_primitive_types(x):
14199 # type: (_MyNamedTuple) -> Tensor
14200 if isinstance(1, _MyNamedTuple):
14201 return 10
14202
14203 if isinstance(x, _MyNamedTuple):
14204 return x.value + 1
14205 else:
14206 return 1
14207
14208 out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
14209 self.assertEqual(out, torch.tensor(6.0))
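        # As with is_scripting(), isinstance() against a statically known type is resolved
        # at compile time: the `isinstance(1, _MyNamedTuple)` branch above can never match,
        # so it is dropped rather than compiled (and type-checked).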
14210
tmanlaibaatarfee585b2020-10-29 11:58:10 -070014211 def test_namedtuple_type_inference(self):
Shen Li10224432021-08-12 11:39:31 -070014212 _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
14213 _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])
tmanlaibaatarfee585b2020-10-29 11:58:10 -070014214
14215 def test_check_named_tuple_value():
14216 named_tuple = _AnnotatedNamedTuple(1)
14217 return named_tuple.value
14218
14219 self.checkScript(test_check_named_tuple_value, ())
14220
14221 def test_error():
14222 return _UnannotatedNamedTuple(1)
14223
Shen Li10224432021-08-12 11:39:31 -070014224 with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
14225 r"for argument \'value\' but instead found type \'int\'."):
tmanlaibaatarfee585b2020-10-29 11:58:10 -070014226 torch.jit.script(test_error)
14227
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014228 def test_namedtuple_default_values_simple_type(self):
Shen Li10224432021-08-12 11:39:31 -070014229
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014230 class Point(NamedTuple):
14231 x: Optional[int] = None
14232 y: int = 2
14233
14234 make_global(Point)
14235
14236 class M(torch.nn.Module):
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014237 def forward(self, point: Point):
14238 return point
14239
14240 p = Point(x=3, y=2)
14241
14242 self.checkModule(M(), (p,))
14243 self.checkModule(M(), (Point(),))
14244
14245 m = torch.jit.script(M())
14246
Shen Li10224432021-08-12 11:39:31 -070014247 FileCheck().check(r"NamedTuple(x : int? = None, y : int = 2))") \
14248 .run(m.graph)
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014249
14250 def test_namedtuple_default_values_missing(self):
Shen Li10224432021-08-12 11:39:31 -070014251
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014252 class Point(NamedTuple):
14253 x: Optional[int]
14254 y: int
14255 z: int = 3
14256
14257 make_global(Point)
14258
14259 class M(torch.nn.Module):
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014260 def forward(self, point: Point):
14261 return point
14262
14263 p1 = Point(x=3, y=2)
14264 p2 = Point(x=3, y=2, z=1)
14265
14266 self.checkModule(M(), (p1,))
14267 self.checkModule(M(), (p2,))
14268
14269 m = torch.jit.script(M())
14270
Shen Li10224432021-08-12 11:39:31 -070014271 FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))") \
14272 .run(m.graph)
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014273
14274 def test_namedtuple_default_values_container_type(self):
Shen Li10224432021-08-12 11:39:31 -070014275
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014276 class Point(NamedTuple):
14277 x: Optional[List[int]] = None
14278 y: List[int] = [1, 2, 3]
14279 z: Optional[Dict[str, int]] = {"a": 1}
14280
14281 make_global(Point)
14282
14283 class M(torch.nn.Module):
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014284 def forward(self, point: Point):
14285 return point
14286
14287 p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2})
14288
14289 self.checkModule(M(), (p,))
14290 self.checkModule(M(), (Point(),))
14291
14292 m = torch.jit.script(M())
14293
Shen Li10224432021-08-12 11:39:31 -070014294 first_line = r"NamedTuple(x : int[]? = None, y : int[] = " \
14295 r"[1, 2, 3], z : Dict(str, int)? = {a: 1}))"
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014296
Shen Li10224432021-08-12 11:39:31 -070014297 FileCheck().check(first_line) \
14298 .run(m.graph)
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014299
14300 def test_namedtuple_default_values_Tensor_type(self):
Shen Li10224432021-08-12 11:39:31 -070014301
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014302 class Point(NamedTuple):
14303 x: torch.Tensor = torch.rand(2, 3)
14304
14305 make_global(Point)
14306
14307 class M(torch.nn.Module):
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014308 def forward(self, point: Point):
14309 return point
14310
14311 p = Point(x=torch.rand(2, 3))
14312
Shen Li10224432021-08-12 11:39:31 -070014313 with self.assertRaisesRegex(RuntimeError, "Tensors are not "
14314 "supported as default NamedTuple "
14315 "fields"):
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014316 m = torch.jit.script(M())
14317 m(p)
14318
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014319 def test_namedtuple_default_values_using_factory_constructor(self):
14320 Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2))
14321
14322 make_global(Pair)
14323
14324 @torch.jit.script
14325 def fn(x: Pair) -> Pair:
14326 return x
14327
14328 # TODO: We can't use `checkScript` with the NamedTuple factory
14329         # constructor. Using the factory constructor with TorchScript
14330         # creates an anonymous `NamedTuple` class instead of
14331 # preserving the actual name. For example, the actual generated
14332 # signature in this case is:
14333 # graph(%x.1 : NamedTuple(x : Tensor, y : Tensor))
14334 # It looks like similar test cases have had this issue as well
14335 # (see: `test_namedtuple_python`).
Shen Li10224432021-08-12 11:39:31 -070014336 FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))") \
14337 .check_next(r"return (%x.1)") \
14338 .run(fn.graph)
Ansley Ussery0fbc4712021-06-26 15:17:10 -070014339
Zachary DeVitobecf0802019-10-01 16:37:34 -070014340 def test_isinstance_dynamic(self):
14341 @torch.jit.script
14342 def foo(a):
14343 # type: (Optional[List[int]]) -> int
14344 b = 0
14345 if isinstance(a, (int, (float,), list, str)):
14346 b += 1
14347 if isinstance(a, (int, str)):
14348 b += 1
14349 if isinstance(a, List[int]):
14350 b += 1
14351 return b
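        # [3, 4] matches `list` and `List[int]` but not (int, str), so b == 2;
        # None matches none of the checks, so b == 0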
14352 self.assertEqual(foo([3, 4]), 2)
14353 self.assertEqual(foo(None), 0)
14354
Elias Ellison451fc512019-08-07 19:11:55 -070014355 def test_function_overloads(self):
14356         # TODO: pyflakes currently does not compose the @overload annotation with other
14357         # decorators. This is fixed on master but not in version 2.1.1.
14358         # On the next version update, remove the noqa comments and add @typing.overload annotations.
14359
14360 @torch.jit._overload # noqa: F811
14361 def test_simple(x1): # noqa: F811
14362 # type: (int) -> int
14363 pass
14364
14365 @torch.jit._overload # noqa: F811
14366 def test_simple(x1): # noqa: F811
14367 # type: (float) -> float
14368 pass
14369
14370 def test_simple(x1): # noqa: F811
Elias Ellison56de8852019-12-12 07:52:00 -080014371 return x1
Elias Ellison451fc512019-08-07 19:11:55 -070014372
14373 def invoke_function():
Shen Li10224432021-08-12 11:39:31 -070014374 return test_simple(1.0), test_simple(.5)
Elias Ellison451fc512019-08-07 19:11:55 -070014375
14376 self.checkScript(invoke_function, ())
14377
14378 # testing that the functions are cached
Michael Suoc93e96f2020-07-08 11:35:52 -070014379 compiled_fns_1 = torch.jit._script._get_overloads(test_simple)
14380 compiled_fns_2 = torch.jit._script._get_overloads(test_simple)
Elias Ellison451fc512019-08-07 19:11:55 -070014381 for a, b in zip(compiled_fns_1, compiled_fns_2):
Elias Ellison56de8852019-12-12 07:52:00 -080014382 self.assertIs(a.graph, b.graph)
14383
14384 old_func = test_simple
14385
14386         # testing that newly added overloads work with caching
14387 @torch.jit._overload # noqa: F811
14388 def test_simple(x1): # noqa: F811
14389 # type: (str) -> str
14390 pass
14391
14392 @torch.jit.script
14393 def my_func():
14394 return old_func("hi")
14395
14396         # testing a new function with the same qualified name
14397 @torch.jit._overload # noqa: F811
14398 def test_simple(a, b): # noqa: F811
14399 # type: (int, int) -> int
14400 pass
14401
14402 def test_simple(a, b):
14403 return a + b
14404
14405 @torch.jit.script
14406 def fn():
14407 return test_simple(3, 4)
14408
14409 self.assertEqual(fn(), 7)
Elias Ellison451fc512019-08-07 19:11:55 -070014410
14411         # currently the default values have to be specified in the
14412         # overload as well - TODO: take them from the implementation and apply
14413         # them where the type is valid.
14414 @torch.jit._overload # noqa: F811
14415 def identity(x1): # noqa: F811
14416 # type: (str) -> str
14417 pass
Edward Yangda2004e2020-06-04 12:53:53 -070014418
Elias Ellison451fc512019-08-07 19:11:55 -070014419 @torch.jit._overload # noqa: F811
Elias Ellison56de8852019-12-12 07:52:00 -080014420 def identity(x1): # noqa: F811
Elias Ellison451fc512019-08-07 19:11:55 -070014421 # type: (float) -> float
14422 pass
14423
14424 def identity(x1=1.0): # noqa: F811
14425 return x1
14426
14427 def invoke():
Shen Li10224432021-08-12 11:39:31 -070014428 return identity(), identity(.5), identity("hi")
Elias Ellison451fc512019-08-07 19:11:55 -070014429
14430 self.checkScript(invoke, ())
14431
14432 def schema_match_failure():
14433 return identity((1, 2))
14434
14435 thrown = False
14436 try:
14437 torch.jit.script(schema_match_failure)
14438 except Exception as e:
14439 thrown = True
14440 self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e))
14441 self.assertTrue(thrown)
14442
14443 with self.assertRaisesRegex(Exception, "cannot be directly compiled"):
14444 torch.jit.script(identity)
14445
14446 @torch.jit._overload # noqa: F811
14447 def impl_compile_failure(x, y): # noqa: F811
14448 # type: (str, str) -> (str)
14449 pass
14450
14451 @torch.jit._overload # noqa: F811
14452 def impl_compile_failure(x, y): # noqa: F811
14453 # type: (int, int) -> (int)
14454 pass
14455
14456 def impl_compile_failure(x, y): # noqa: F811
14457 return x - y
14458
14459 def test():
14460 impl_compile_failure("one", "two")
14461
Shen Li10224432021-08-12 11:39:31 -070014462
Elias Ellison451fc512019-08-07 19:11:55 -070014463 with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
14464 torch.jit.script(test)
14465
Elias Ellison56de8852019-12-12 07:52:00 -080014466 @torch.jit._overload # noqa: F811
14467 def good_overload(x=1): # noqa: F811
14468 # type: (int) -> (int)
14469 pass
14470
14471 def good_overload(x=1): # noqa: F811
14472 return x
14473
14474 @torch.jit.script
14475 def foo():
14476 return good_overload()
14477
14478 self.assertEqual(foo(), 1)
14479
Zsolt Dollensteinb0043072021-08-12 10:56:55 -070014480
Shen Li10224432021-08-12 11:39:31 -070014481 with self.assertRaisesRegex(Exception, "must equal to the default parameter"):
Elias Ellison56de8852019-12-12 07:52:00 -080014482 @torch.jit._overload # noqa: F811
14483 def bad_default_on_overload(x, y=2): # noqa: F811
14484 # type: (int, int) -> (int)
14485 pass
14486
14487 def bad_default_on_overload(x, y=1): # noqa: F811
14488 # type: (int, int) -> (int)
14489 pass
14490
14491 @torch.jit.script
14492 def test():
14493 return bad_default_on_overload(1, 2)
14494
14495 @torch.jit._overload # noqa: F811
14496 def diff_default(x): # noqa: F811
14497 # type: (int) -> int
14498 pass
14499
14500 @torch.jit._overload # noqa: F811
14501 def diff_default(x): # noqa: F811
14502 # type: (str) -> str
14503 pass
14504
14505 def diff_default(x="hi"): # noqa: F811
14506 return x
14507
14508 def test():
14509 return diff_default(), diff_default(2), diff_default("abc")
14510
14511 self.assertEqual(test(), torch.jit.script(test)())
14512
14513 @torch.jit._overload # noqa: F811
14514 def diff_num_params(x): # noqa: F811
14515 # type: (float) -> float
14516 pass
14517
14518 @torch.jit._overload # noqa: F811
14519 def diff_num_params(x, y): # noqa: F811
14520 # type: (int, int) -> int
14521 pass
14522
14523 def diff_num_params(x, y=2, z=3): # noqa: F811
14524 # type: (Union[float, int], int, int)
14525 return x + y + z
14526
14527 def test():
Shen Li10224432021-08-12 11:39:31 -070014528 return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3)
Elias Ellison56de8852019-12-12 07:52:00 -080014529
14530 self.assertEqual(test(), torch.jit.script(test)())
14531
14532 @torch.jit._overload # noqa: F811
14533 def diff_num_params_no_annot():
14534 # type: () -> int
14535 pass
14536
Shen Li10224432021-08-12 11:39:31 -070014537 def diff_num_params_no_annot(x=1): # noqa: F811
Elias Ellison56de8852019-12-12 07:52:00 -080014538 return x
14539
14540 def test():
14541 return diff_num_params_no_annot(1.0)
14542
14543 with self.assertRaisesRegex(Exception, "Parameters not specified"):
14544 torch.jit.script(test)
14545
Zhengxu Chene62189a2021-08-05 14:19:56 -070014546 def test_function_overload_misuse(self):
Shen Li10224432021-08-12 11:39:31 -070014547 with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
Zhengxu Chene62189a2021-08-05 14:19:56 -070014548 @torch.jit._overload
14549 def wrong_decl_body(x: str) -> str:
14550 return x + "0"
14551
Shen Li10224432021-08-12 11:39:31 -070014552 with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
Zhengxu Chene62189a2021-08-05 14:19:56 -070014553 class MyClass:
14554 @torch.jit._overload_method
14555 def method(self):
14556 return 0
14557
14558 @torch.jit._overload
Shen Li10224432021-08-12 11:39:31 -070014559 def null_overload(x: int) -> int: ... # noqa: E704
Zhengxu Chene62189a2021-08-05 14:19:56 -070014560
Michael Suo5c3529a2021-11-16 10:19:57 -080014561 @torch.jit._overload # noqa: F811
Zhengxu Chene62189a2021-08-05 14:19:56 -070014562 def null_overload(x: str) -> str: # noqa: F811
14563 pass
14564
14565 def null_overload_driver():
14566 return null_overload(0)
14567
Shen Li10224432021-08-12 11:39:31 -070014568 with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
Zhengxu Chene62189a2021-08-05 14:19:56 -070014569 torch.jit.script(null_overload_driver)
14570
14571 class OverloadMisuse(torch.nn.Module):
Zhengxu Chene62189a2021-08-05 14:19:56 -070014572 @torch.jit._overload_method
14573 def forward(self, x: int):
14574 pass
14575
Michael Suo5c3529a2021-11-16 10:19:57 -080014576 @torch.jit._overload_method # noqa: F811
Zhengxu Chene62189a2021-08-05 14:19:56 -070014577 def forward(self, x: Tensor): # noqa: F811
14578 pass
14579
Shen Li10224432021-08-12 11:39:31 -070014580 with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
Zhengxu Chene62189a2021-08-05 14:19:56 -070014581 m = torch.jit.script(OverloadMisuse())
14582
Shen Li10224432021-08-12 11:39:31 -070014583
James Reed71a53142021-04-15 08:42:12 -070014584 def test_script_method_torch_function_overload(self):
14585 class MyCustomTensor(torch.Tensor):
14586 pass
14587
14588 class MyCustomModule(torch.nn.Module):
14589 def forward(self, x):
14590 return torch.relu(x)
14591
14592 scripted_mod = torch.jit.script(MyCustomModule())
14593 t = torch.tensor([3.0])
14594 ref_out = scripted_mod(t)
14595
14596 t_custom = MyCustomTensor([3.0])
14597 out1 = scripted_mod(t_custom)
14598 self.assertEqual(out1, ref_out)
14599
14600 out2 = scripted_mod.forward(t_custom)
14601 self.assertEqual(out2, ref_out)
Elias Ellison56de8852019-12-12 07:52:00 -080014602
Elias Ellison451fc512019-08-07 19:11:55 -070014603 def test_function_overloading_isinstance(self):
14604 @torch.jit._overload # noqa: F811
14605 def my_conv(x, y): # noqa: F811
14606 # type: (float, str) -> (float)
14607 pass
14608
14609 @torch.jit._overload # noqa: F811
Elias Ellison56de8852019-12-12 07:52:00 -080014610 def my_conv(x, y): # noqa: F811
Elias Ellison451fc512019-08-07 19:11:55 -070014611 # type: (float, float) -> (float)
14612 pass
14613
14614 def my_conv(x, y=2.0): # noqa: F811
14615 if isinstance(y, str):
14616 if y == "hi":
14617 return 4.0 - x
14618 else:
14619 return 5.0 - x
14620 else:
14621 return 2.0 + x
14622
14623 def test_uses():
14624 return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0)
14625
14626 self.checkScript(test_uses, ())
14627
Elias Ellison8e3c0212019-08-20 16:45:55 -070014628 def test_method_overloading(self):
14629 class Over(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014630 @torch.jit._overload_method # noqa: F811
14631 def forward(self, x): # noqa: F811
David Riazati31de19f2019-01-23 18:11:04 -080014632 # type: (Tuple[Tensor, Tensor]) -> Tensor
Elias Ellison8e3c0212019-08-20 16:45:55 -070014633 pass
David Riazati31de19f2019-01-23 18:11:04 -080014634
Elias Ellison8e3c0212019-08-20 16:45:55 -070014635 @torch.jit._overload_method # noqa: F811
14636 def forward(self, x): # noqa: F811
David Riazati31de19f2019-01-23 18:11:04 -080014637 # type: (Tensor) -> Tensor
Elias Ellison8e3c0212019-08-20 16:45:55 -070014638 pass
14639
14640 def forward(self, x): # noqa: F811
14641 if isinstance(x, Tensor):
14642 return x + 20
14643 else:
14644 return x[0] + 5
David Riazati31de19f2019-01-23 18:11:04 -080014645
14646 class S(torch.jit.ScriptModule):
14647 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014648 super().__init__()
Elias Ellison8e3c0212019-08-20 16:45:55 -070014649 self.weak = Over()
David Riazati31de19f2019-01-23 18:11:04 -080014650
14651 @torch.jit.script_method
14652 def forward(self, x):
14653 return self.weak(x) + self.weak((x, x))
14654
Elias Ellison8e3c0212019-08-20 16:45:55 -070014655 s_mod = S()
David Riazati31de19f2019-01-23 18:11:04 -080014656 x = torch.ones(1)
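        # the Tensor overload adds 20 and the tuple overload adds x[0] + 5,
        # so S.forward returns (x + 20) + (x + 5)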
Elias Ellison8e3c0212019-08-20 16:45:55 -070014657 self.assertEqual(s_mod(x), x + 20 + 5 + x)
David Riazati31de19f2019-01-23 18:11:04 -080014658
Elias Ellison8e3c0212019-08-20 16:45:55 -070014659 over = Over()
14660 self.assertEqual(over((x, x)), x + 5)
14661             self.assertEqual(over(x), x + 20)
14662
14663 class Unannotated(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014664 @torch.jit._overload_method # noqa: F811
14665 def hello(self, x): # noqa: F811
14666 pass
14667
14668 @torch.jit._overload_method # noqa: F811
14669 def hello(self, x): # noqa: F811
14670 # type: (int) -> (int)
14671 pass
14672
14673 def hello(self, x): # noqa: F811
14674 return x + 3
14675
14676 def forward(self):
Shen Li10224432021-08-12 11:39:31 -070014677 return self.hello(1), self.hello(.5)
Elias Ellison8e3c0212019-08-20 16:45:55 -070014678
14679 w = Unannotated()
Shen Li10224432021-08-12 11:39:31 -070014680 with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014681 torch.jit.script(w)
14682
14683 class CompileOverloadError(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014684 @torch.jit._overload_method # noqa: F811
14685 def hello(self, x): # noqa: F811
14686 # type: (str) -> (int)
14687 pass
14688
14689 @torch.jit._overload_method # noqa: F811
14690 def hello(self, x): # noqa: F811
14691 # type: (int) -> (int)
14692 pass
14693
14694 def hello(self, x): # noqa: F811
14695 return x + 1
14696
14697 def forward(self):
Shen Li10224432021-08-12 11:39:31 -070014698 return self.hello("hi"), self.hello(.5)
Elias Ellison8e3c0212019-08-20 16:45:55 -070014699
14700 w = CompileOverloadError()
Shen Li10224432021-08-12 11:39:31 -070014701 with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014702 torch.jit.script(w)
14703
14704 # testing overload declared first, then non-overload
Shen Li10224432021-08-12 11:39:31 -070014705 with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014706 class W3(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014707 @torch.jit._overload_method # noqa: F811
14708 def forward(self, x): # noqa: F811
14709 # type: (int) -> int
14710 pass
14711
14712 @torch.jit._overload_method # noqa: F811
14713 def forward(self, x): # noqa: F811
14714 # type: (Tensor) -> Tensor
14715 pass
14716
14717 def forward(self, x): # noqa: F811
14718 return x + 5
14719
14720 a = W3()
14721 b = torch.jit.script(a)
14722
14723 class W3(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014724 def forward(self, x): # noqa: F811
14725 return x + 5 + 10
14726
14727 a = W3()
14728 b = torch.jit.script(a)
14729
14730 # testing non-overload declared first, then overload
14731 class W2(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014732 def hello(self, x1, x2):
14733 return x1 + x2
14734
14735 def forward(self, x):
14736 return self.hello(x, x)
14737
14738 a = torch.jit.script(W2())
14739 self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
14740
14741 class W2(torch.nn.Module):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014742 @torch.jit._overload_method # noqa: F811
14743 def hello(self, x): # noqa: F811
14744 pass
14745
14746 @torch.jit._overload_method # noqa: F811
14747 def hello(self, x): # noqa: F811
14748 # type: (int) -> (int)
14749 pass
14750
14751 def hello(self, x): # noqa: F811
14752 return x + 5 + 10
14753
14754 def forward(self, x):
14755 return self.hello(1), self.hello(x)
14756
Shen Li10224432021-08-12 11:39:31 -070014757 with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
Elias Ellison8e3c0212019-08-20 16:45:55 -070014758 a = torch.jit.script(W2())
David Riazati31de19f2019-01-23 18:11:04 -080014759
Nikolay Korovaiko5177f952022-03-31 14:51:36 -070014760 def test_narrow_copy(self):
14761 def foo(a):
14762 return a.narrow_copy(0, 0, 5)
14763
14764 self.checkScript(foo, [torch.rand(10)])
14765
Michael Suod86cc3e2019-01-02 14:32:00 -080014766 def test_select_after_chunk(self):
14767 def foo(x):
14768 chunked = torch.chunk(x, 1)
14769 foo = chunked[0]
14770 foo.add_(5)
14771 return x
14772
14773 self.checkScript(foo, [torch.rand(2, 3)])
14774
David Riazatif5435632019-04-18 11:07:45 -070014775 def test_nn_LSTM_with_layers(self):
14776 class M(torch.jit.ScriptModule):
14777 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014778 super().__init__()
James Reed02fd1872019-06-03 16:55:52 -070014779 self.rnn = nn.LSTM(2, 3, 2, dropout=0)
David Riazatif5435632019-04-18 11:07:45 -070014780
14781 @torch.jit.script_method
14782 def forward(self, x, lengths, h0, c0):
14783 return self.rnn(x, (h0, c0))[0]
14784
14785 class Eager(torch.nn.Module):
14786 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014787 super().__init__()
James Reed02fd1872019-06-03 16:55:52 -070014788 self.rnn = nn.LSTM(2, 3, 2, dropout=0)
David Riazatif5435632019-04-18 11:07:45 -070014789
14790 def forward(self, x, lengths, h0, c0):
14791 return self.rnn(x, (h0, c0))[0]
14792
Shen Li10224432021-08-12 11:39:31 -070014793 inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3))
David Riazatif5435632019-04-18 11:07:45 -070014794 eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0]
14795 script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0]
14796
14797 self.assertEqual(eager_out, script_out)
14798
David Riazati2370c982019-02-21 16:11:37 -080014799 def test_nn_LSTM(self):
14800 input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14801
14802 class S(torch.jit.ScriptModule):
14803 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014804 super().__init__()
David Riazati2370c982019-02-21 16:11:37 -080014805 self.x = torch.nn.LSTM(5, 5)
14806
14807 @torch.jit.script_method
Shen Li10224432021-08-12 11:39:31 -070014808 def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
David Riazati2370c982019-02-21 16:11:37 -080014809 return self.x(input)
14810
14811 eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
14812 script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0]
14813
14814 self.assertEqual(eager_out, script_out)
14815
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014816 def test_nn_GRU(self):
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014817 seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14818 tensor_input = torch.randn(5, 5, 5)
14819
14820 class SeqLengthGRU(torch.jit.ScriptModule):
14821 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014822 super().__init__()
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014823 self.x = torch.nn.GRU(5, 5)
14824
14825 @torch.jit.script_method
Shen Li10224432021-08-12 11:39:31 -070014826 def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014827 return self.x(input)
14828
14829 class TensorGRU(torch.jit.ScriptModule):
14830 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014831 super().__init__()
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014832 self.x = torch.nn.GRU(5, 5)
14833
14834 @torch.jit.script_method
Richard Barnesec6d29d2021-01-07 12:07:49 -080014835 def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014836 return self.x(input)
14837
Shen Li10224432021-08-12 11:39:31 -070014838 seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0]
14839 seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0]
14840 tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0]
14841 tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0]
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014842
14843 self.assertEqual(seq_eager_out, seq_script_out)
14844 self.assertEqual(tensor_eager_out, tensor_script_out)
14845
Vitaly Fedyunin5f510372019-11-18 05:32:23 -080014846 def test_torchscript_memoryformat(self):
14847 @torch.jit.script
14848 def fn(x):
14849 return x.contiguous(memory_format=torch.channels_last)
14850 x = torch.randn(4, 3, 6, 6)
14851 y = fn(x)
14852 self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
Wanchao Liang9d2cc2c2019-08-01 17:12:18 -070014853
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014854 def test_torchscript_multi_head_attn(self):
14855 @torch.jit.script
Shen Li10224432021-08-12 11:39:31 -070014856 def jit_multihead_attn_forward(query, # type: Tensor
14857 key, # type: Tensor
14858 value, # type: Tensor
14859 embed_dim_to_check, # type: int
14860 num_heads, # type: int
14861 in_proj_weight, # type: Tensor
14862 in_proj_bias, # type: Tensor
14863 bias_k, # type: Optional[Tensor]
14864 bias_v, # type: Optional[Tensor]
14865 add_zero_attn, # type: bool
14866 dropout, # type: float
14867 out_proj_weight, # type: Tensor
14868 out_proj_bias, # type: Tensor
14869 training=True, # type: bool
14870 key_padding_mask=None, # type: Optional[Tensor]
14871 need_weights=True, # type: bool
14872 attn_mask=None # type: Optional[Tensor]
14873 ):
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014874 # type: (...) -> Tuple[Tensor, Optional[Tensor]]
Shen Li10224432021-08-12 11:39:31 -070014875 return torch.nn.functional.multi_head_attention_forward(query, key, value,
14876 embed_dim_to_check, num_heads,
14877 in_proj_weight, in_proj_bias,
14878 bias_k, bias_v,
14879 add_zero_attn, dropout,
14880 out_proj_weight, out_proj_bias,
14881 training, key_padding_mask,
14882 need_weights, attn_mask)
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014883
14884 src_l = 3
14885 bsz = 5
14886 embed_size = 8
14887 nhead = 2
14888 multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead)
14889 query = torch.rand((src_l, bsz, embed_size))
14890 key = torch.rand((src_l, bsz, embed_size))
14891 value = torch.rand((src_l, bsz, embed_size))
14892
14893 mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1)
Shen Li10224432021-08-12 11:39:31 -070014894 mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).double()
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014895
Shen Li10224432021-08-12 11:39:31 -070014896 jit_out = jit_multihead_attn_forward(query, key, value,
14897 embed_size, nhead,
14898 multi_head_attn.in_proj_weight,
14899 multi_head_attn.in_proj_bias,
14900 multi_head_attn.bias_k, multi_head_attn.bias_v,
14901 multi_head_attn.add_zero_attn, multi_head_attn.dropout,
14902 multi_head_attn.out_proj.weight,
14903 multi_head_attn.out_proj.bias, attn_mask=mask)[0]
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014904
Shen Li10224432021-08-12 11:39:31 -070014905 py_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
14906 embed_size, nhead,
14907 multi_head_attn.in_proj_weight,
14908 multi_head_attn.in_proj_bias,
14909 multi_head_attn.bias_k,
14910 multi_head_attn.bias_v,
14911 multi_head_attn.add_zero_attn,
14912 multi_head_attn.dropout,
14913 multi_head_attn.out_proj.weight,
14914 multi_head_attn.out_proj.bias,
14915 attn_mask=mask)[0]
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014916 # print("rel. error: ")
14917 # print(jit_out / py_out - 1)
Philip Meier57d4c6c2021-08-25 16:42:14 -070014918 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014919
Scott Wolchokb182c222022-04-26 16:12:02 -070014920 def test_torchscript_multi_head_attn_fast_path(self):
14921 src_l = 3
14922 bsz = 5
14923 embed_size = 8
14924 nhead = 2
14925 multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
14926 multi_head_attn = multi_head_attn.eval()
14927
14928 query = key = value = torch.rand((bsz, src_l, embed_size))
14929
14930 with torch.no_grad():
14931 py_out = multi_head_attn(query, key, value)
14932 mha = torch.jit.script(multi_head_attn)
14933 jit_out = mha(query, key, value)
14934 torch.testing.assert_close(jit_out, py_out)
14935
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014936 @unittest.skipIf(not RUN_CUDA, "no CUDA")
14937 def test_scriptmodule_multi_head_attn_cuda(self):
Shen Li10224432021-08-12 11:39:31 -070014938
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014939 class MyModule(torch.jit.ScriptModule):
14940 def __init__(self, embed_dim, num_heads):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014941 super().__init__()
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014942 sample_q = torch.randn(3, 2, embed_dim)
14943 sample_kv = torch.randn(3, 2, embed_dim)
14944 attention = nn.MultiheadAttention(embed_dim, num_heads)
14945 attention.eval()
14946
Shen Li10224432021-08-12 11:39:31 -070014947 self.mod = torch.jit.trace(attention,
14948 (sample_q, sample_kv, sample_kv))
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014949
14950 @torch.jit.script_method
14951 def forward(self, q, k, v):
14952 return self.mod(q, k, v)
14953
14954 embed_dim = 8
14955 num_heads = 2
14956 sl = 3
14957 bs = 2
14958 model = MyModule(embed_dim, num_heads).cuda()
14959 q = torch.randn(sl, bs, embed_dim, device="cuda")
14960 kv = torch.randn(sl, bs, embed_dim, device="cuda")
14961
14962 jit_out = model(q, kv, kv)[0]
Shen Li10224432021-08-12 11:39:31 -070014963 py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv,
14964 embed_dim, num_heads,
14965 model.mod.in_proj_weight,
14966 model.mod.in_proj_bias,
14967 None, None, None, 0.0,
14968 model.mod.out_proj.weight,
14969 model.mod.out_proj.bias)[0]
Philip Meier57d4c6c2021-08-25 16:42:14 -070014970 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
Guanheng Zhang8e3311c2019-05-27 15:09:50 -070014971
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070014972 @unittest.skipIf(not RUN_CUDA, "no CUDA")
14973 def test_scriptmodule_transformer_cuda(self):
Shen Li10224432021-08-12 11:39:31 -070014974
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070014975 class MyModule(torch.jit.ScriptModule):
14976 def __init__(self, transformer, sample_q, sample_kv):
Xuehai Pan046e88a2023-02-12 22:20:50 +000014977 super().__init__()
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070014978 transformer.eval()
14979
Shen Li10224432021-08-12 11:39:31 -070014980 self.mod = torch.jit.trace(transformer,
14981 (sample_q, sample_kv))
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070014982
14983 @torch.jit.script_method
14984 def forward(self, q, k):
14985 return self.mod(q, k)
14986
14987 d_model = 8
14988 nhead = 2
14989 num_encoder_layers = 2
14990 num_decoder_layers = 2
14991 dim_feedforward = 16
14992 bsz = 2
14993 seq_length = 5
14994 tgt_length = 3
14995
14996 src = torch.randn(seq_length, bsz, d_model)
14997 tgt = torch.randn(tgt_length, bsz, d_model)
Shen Li10224432021-08-12 11:39:31 -070014998 transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
14999 num_decoder_layers, dim_feedforward, dropout=0.0)
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070015000 model = MyModule(transformer, tgt, src)
15001
15002 src = torch.randn(seq_length, bsz, d_model)
15003 tgt = torch.randn(tgt_length, bsz, d_model)
15004 jit_out = model(tgt, src)
15005 py_out = transformer(tgt, src)
15006
15007 # print(jit_out/py_out-1)
15008 # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4))
Philip Meier57d4c6c2021-08-25 16:42:14 -070015009 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
Guanheng Zhang83cec5f2019-06-12 12:01:27 -070015010
David Riazati76feb8c2019-01-07 13:49:20 -080015011 def test_list_python_op(self):
15012 def python_list_op(lst):
15013 # type: (List[Tensor]) -> Tensor
15014 return lst[0]
15015
15016 def fn(lst):
15017 # type: (List[Tensor]) -> Tensor
15018 return python_list_op(lst)
15019
15020 self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
15021
davidriazatie13b4832019-05-31 10:28:39 -070015022 @unittest.skipIf(not RUN_CUDA, "no CUDA")
15023 def test_weak_cuda(self):
15024 class M(torch.jit.ScriptModule):
15025 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015026 super().__init__()
davidriazatie13b4832019-05-31 10:28:39 -070015027 self.lstm = torch.nn.LSTM(5, 5)
15028 self.lstm.cuda()
15029
15030 @torch.jit.script_method
15031 def forward(self, x):
15032 return self.lstm(x)
15033
15034 m = M()
15035 m.cuda()
15036 out = m(torch.ones(5, 5, 5).cuda())
15037 self.assertTrue(out[0].is_cuda)
15038
David Riazatic865d462019-02-01 16:24:36 -080015039 def test_ignore_decorator(self):
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015040 with warnings.catch_warnings(record=True) as warns:
15041 class M(torch.jit.ScriptModule):
15042 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015043 super().__init__()
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015044 tensor = torch.zeros(1, requires_grad=False)
Shen Li10224432021-08-12 11:39:31 -070015045 self.register_buffer('some_state', torch.nn.Parameter(tensor))
David Riazatic865d462019-02-01 16:24:36 -080015046
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015047 @torch.jit.script_method
15048 def forward(self, x):
15049 self.ignored_code(x)
15050 return x
David Riazatic865d462019-02-01 16:24:36 -080015051
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015052 @torch.jit.ignore(drop_on_export=True)
15053 def ignored_code(self, x):
15054 self.some_state = torch.tensor((100,))
15055
David Reisse75fb432020-04-22 09:20:13 -070015056 FileCheck().check("TorchScript will now drop the function").run(str(warns[0]))
David Riazatic865d462019-02-01 16:24:36 -080015057
15058 # Assert ignored code is run
15059 m = M()
David Riazatic865d462019-02-01 16:24:36 -080015060
Zachary DeVito8c57ce82019-06-12 17:09:29 -070015061 m2 = self.getExportImportCopy(m)
15062 pp = str(m2.forward.code)
Shen Li10224432021-08-12 11:39:31 -070015063 self.assertNotIn('ignored_code', pp)
David Riazatic865d462019-02-01 16:24:36 -080015064
Shen Li10224432021-08-12 11:39:31 -070015065 with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
Zachary DeVito8c57ce82019-06-12 17:09:29 -070015066 m2.forward(torch.ones(1))
David Riazatic865d462019-02-01 16:24:36 -080015067
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015068 def test_ignored_as_value(self):
15069 class Model(nn.Module):
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015070 @torch.jit.unused
15071 def tuple_ignored(self, x):
15072 # type: (Tensor) -> Tuple[Tensor, Tensor]
15073 return x, x
15074
15075 @torch.jit.unused
15076 def single_val_ignored(self, x, y):
15077 # type: (Tensor, Tensor) -> Tensor
15078 return x
15079
15080 def forward(self, x, use_ignore_path):
15081 # type: (Tensor, bool) -> Tuple[Tensor, Tensor]
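                # this branch is never taken at runtime; it only exercises calling
                # the @torch.jit.unused method from compiled code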
Elias Ellisond1b8da72020-11-20 11:14:59 -080015082 if 1 == 2:
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015083 return self.tuple_ignored(x)
15084 if use_ignore_path:
15085 return self.single_val_ignored(x, x), self.single_val_ignored(x, x)
15086 return x, x
15087
15088 original = Model()
15089 scripted = torch.jit.script(original)
Shen Li10224432021-08-12 11:39:31 -070015090 self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5)))
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015091
15092 buffer = io.BytesIO()
15093 torch.jit.save(scripted, buffer)
15094 buffer.seek(0)
15095 loaded = torch.jit.load(buffer)
15096
Shen Li10224432021-08-12 11:39:31 -070015097 with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
15098 loaded(torch.tensor(.5), True)
Elias Ellison7ab4ad72019-09-09 20:22:54 -070015099
davidriazati48ca64d2019-07-24 14:14:38 -070015100 def test_module_error(self):
15101 class MyModule(torch.nn.Module):
davidriazati48ca64d2019-07-24 14:14:38 -070015102 def forward(self, foo):
15103 return foo
15104
Shen Li10224432021-08-12 11:39:31 -070015105 with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"):
davidriazati48ca64d2019-07-24 14:14:38 -070015106 torch.jit.script(MyModule)
15107
Michael Suo431a34f2019-01-17 14:38:42 -080015108 def test_view_write(self):
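        # l[0] aliases x, so the in-place add through the view must be visible
        # when b is computed; checkScript verifies scripting matches eager here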
15109 def fn(x, y):
15110 l = []
15111 l.append(x)
15112 x_view = l[0]
15113 a = x + x
15114 x_view.add_(y)
15115 b = x + x
15116 return a == b
15117 self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
15118
David Riazatia2381fa2019-03-07 10:41:13 -080015119 def test_module_attrs(self):
15120 class M(torch.jit.ScriptModule):
15121 def __init__(self, table):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015122 super().__init__()
David Riazatia2381fa2019-03-07 10:41:13 -080015123 self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
15124 self.x = torch.nn.Parameter(torch.tensor([100.0]))
15125
15126 @torch.jit.script_method
15127 def forward(self, key):
15128 # type: (str) -> Tensor
15129 return self.table[key] + self.x
15130
Michael Suoca1b8eb2020-07-13 16:57:41 -070015131 with torch._jit_internal._disable_emit_hooks():
David Riazatia2381fa2019-03-07 10:41:13 -080015132 # TODO: re-enable module hook when Python printing of attributes is
15133 # supported
Shen Li10224432021-08-12 11:39:31 -070015134 m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
Sergii Dymchenko58d1cf72022-08-03 22:45:39 +000015135 self.assertEqual(m("c"), torch.tensor([103.]))
David Riazatia2381fa2019-03-07 10:41:13 -080015136
Elias Ellisonbf166882020-02-13 11:59:35 -080015137 def test_module_none_attrs(self):
15138 class MyMod(torch.jit.ScriptModule):
15139 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015140 super().__init__()
Elias Ellisonbf166882020-02-13 11:59:35 -080015141 self.optional_value = None
15142
15143 @torch.jit.script_method
15144 def forward(self):
15145 return self.optional_value
15146
15147 graph = MyMod().forward.graph
15148 FileCheck().check("prim::GetAttr").run(graph)
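        # the peephole pass should fold away the access to the always-None attribute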
Shen Li10224432021-08-12 11:39:31 -070015149 self.run_pass('peephole', graph)
Elias Ellisonbf166882020-02-13 11:59:35 -080015150 FileCheck().check_not("prim::GetAttr").run(graph)
15151
Michael Suo2c302b62019-02-21 00:15:59 -080015152 def test_tensor_import_export(self):
15153 @torch.jit.script
15154 def foo(x):
15155 a = torch.tensor(1)
15156 b = torch.tensor([1, 2])
15157 c = [a, b]
15158 return c
15159
Shen Li10224432021-08-12 11:39:31 -070015160 self.run_pass('constant_propagation', foo.graph)
Zachary DeVito6cb1b992019-04-25 15:43:53 -070015161 m = self.createFunctionFromGraph(foo.graph)
Michael Suo2c302b62019-02-21 00:15:59 -080015162 self.getExportImportCopy(m)
15163
davidriazati00d0ddb2019-05-10 17:01:17 -070015164 def get_pickle_values(self):
Shen Li10224432021-08-12 11:39:31 -070015165 return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]),
15166 ('float', 2.3, float),
15167 ('int', 99, int),
15168 ('bool', False, bool),
15169 ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]),
15170 ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]),
15171 ('tensor', torch.randn(2, 2), torch.Tensor),
15172 ('int_list', [1, 2, 3, 4], List[int]),
15173 ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]),
15174 ('bool_list', [True, True, False, True], List[bool]),
15175 ('float_list', [1., 2., 3., 4.], List[float]),
15176 ('str_list', ['hello', 'bye'], List[str]),
15177 ('none', None, Optional[int]),
15178 ('a_device', torch.device('cpu'), torch.device),
15179 ('another_device', torch.device('cuda:1'), torch.device))
davidriazati00d0ddb2019-05-10 17:01:17 -070015180
David Riazati3d443052019-03-18 18:15:17 -070015181 def test_attribute_serialization(self):
davidriazati00d0ddb2019-05-10 17:01:17 -070015182 tester = self
15183
David Riazati3d443052019-03-18 18:15:17 -070015184 class M(torch.jit.ScriptModule):
15185 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015186 super().__init__()
davidriazati00d0ddb2019-05-10 17:01:17 -070015187 for name, value, the_type in tester.get_pickle_values():
15188 setattr(self, name, torch.jit.Attribute(value, the_type))
David Riazati3d443052019-03-18 18:15:17 -070015189
15190 @torch.jit.script_method
15191 def forward(self):
Shen Li10224432021-08-12 11:39:31 -070015192 return (self.dict, self.float, self.int, self.bool, self.tuple,
15193 self.list, self.int_list, self.tensor_list, self.bool_list,
15194 self.float_list, self.str_list, self.none)
David Riazati3d443052019-03-18 18:15:17 -070015195
15196 m = M()
15197 imported_m = self.getExportImportCopy(m)
15198 self.assertEqual(m(), imported_m())
15199
David Riazati78f589e2019-04-16 15:03:47 -070015200 def test_string_len(self):
15201 def fn(x):
15202 # type: (str) -> int
15203 return len(x)
15204
15205 self.checkScript(fn, ("",))
15206 self.checkScript(fn, ("h",))
15207 self.checkScript(fn, ("hello",))
15208
nikithamalgi14d529a2021-04-15 14:05:47 -070015209 def test_multiline_optional_future_refinement(self):
15210 @torch.jit.script
15211 def fun() -> int:
Shen Li10224432021-08-12 11:39:31 -070015212 future: Optional[
15213 torch.jit.Future[Tuple[torch.Tensor]]
15214 ] = None
nikithamalgi14d529a2021-04-15 14:05:47 -070015215
15216 return 1
15217 self.assertEqual(fun(), 1)
15218
Shen Li10224432021-08-12 11:39:31 -070015219 @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
David Riazati3d443052019-03-18 18:15:17 -070015220 def test_attribute_unpickling(self):
davidriazati4294dba2019-04-29 13:37:57 -070015221 tensor = torch.randn(2, 2)
davidriazati00d0ddb2019-05-10 17:01:17 -070015222 tester = self
davidriazati4294dba2019-04-29 13:37:57 -070015223
David Riazati3d443052019-03-18 18:15:17 -070015224 class M(torch.jit.ScriptModule):
15225 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015226 super().__init__()
davidriazati00d0ddb2019-05-10 17:01:17 -070015227 for name, value, the_type in tester.get_pickle_values():
Michael Suo77c08aa2019-08-11 15:43:28 -070015228 setattr(self, "_" + name, torch.jit.Attribute(value, the_type))
David Riazati3d443052019-03-18 18:15:17 -070015229
15230 @torch.jit.script_method
15231 def forward(self):
Shen Li10224432021-08-12 11:39:31 -070015232 return (self._dict, self._float, self._int, self._bool, self._tuple,
15233 self._list, self._int_list, self._tensor_list, self._bool_list,
15234 self._float_list, self._str_list, self._none)
David Riazati3d443052019-03-18 18:15:17 -070015235
David Riazati3d443052019-03-18 18:15:17 -070015236 with TemporaryFileName() as fname:
15237 M().save(fname)
Michael Suo77c08aa2019-08-11 15:43:28 -070015238 loaded = torch.jit.load(fname)
davidriazati4294dba2019-04-29 13:37:57 -070015239
davidriazati00d0ddb2019-05-10 17:01:17 -070015240 def is_tensor_value(item):
15241 if isinstance(item, torch.Tensor):
15242 return True
15243 if isinstance(item, list):
15244 return is_tensor_value(item[0])
15245 return False
Michael Suo77c08aa2019-08-11 15:43:28 -070015246 for name, value, the_type in self.get_pickle_values():
15247 if is_tensor_value(value):
davidriazati00d0ddb2019-05-10 17:01:17 -070015248 continue
Michael Suo77c08aa2019-08-11 15:43:28 -070015249 self.assertEqual(value, getattr(loaded, "_" + name))
David Riazati3d443052019-03-18 18:15:17 -070015250
Shen Li10224432021-08-12 11:39:31 -070015251 @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
Nikita Shulga77becca2021-10-21 20:31:01 -070015252 @unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support")
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015253 def test_old_models_bc(self):
15254 model = {
Shen Li10224432021-08-12 11:39:31 -070015255 'archive/version': b'1',
15256 'archive/code/archive.py':
15257 b'''
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015258 op_version_set = 0
15259 def forward(self,
15260 _0: Tensor) -> Tensor:
15261 _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))
Tongzhou Wang93201d02019-04-16 20:25:11 -070015262 result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"),
15263 non_blocking=False, copy=False)
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015264 result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu"))
15265 result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu"))
15266 _2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1)
15267 return _2
Shen Li10224432021-08-12 11:39:31 -070015268 ''',
15269 'archive/attributes.pkl': b'\x80\x02](e.',
15270 'archive/libs.py': b'op_version_set = 0\n',
15271 'archive/model.json':
15272 b'''
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015273 {
15274 "protoVersion":"2",
15275 "mainModule":{
15276 "torchscriptArena":{
15277 "key":"code/archive.py"
15278 },
15279 "name":"archive",
15280 "optimize":true
15281 },
15282 "producerName":"pytorch",
15283 "producerVersion":"1.0",
15284 "libs":{
15285 "torchscriptArena":{
15286 "key":"libs.py"
15287 }
15288 }
Shen Li10224432021-08-12 11:39:31 -070015289 }'''}
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015290 with TemporaryFileName() as fname:
15291 archive_name = os.path.basename(os.path.normpath(fname))
Shen Li10224432021-08-12 11:39:31 -070015292 with zipfile.ZipFile(fname, 'w') as archive:
Vitaly Fedyunin86619b82019-04-15 09:13:49 -070015293 for k, v in model.items():
15294 archive.writestr(k, v)
15295
15296 with open(fname, "rb") as f:
15297 fn = torch.jit.load(f)
15298
15299 x = torch.zeros(10)
15300 fn(x)
15301
David Riazati3d443052019-03-18 18:15:17 -070015302 def test_submodule_attribute_serialization(self):
15303 class S(torch.jit.ScriptModule):
15304 def __init__(self, list_data):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015305 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070015306 self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
David Riazati3d443052019-03-18 18:15:17 -070015307 self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]])
15308
15309 @torch.jit.script_method
15310 def forward(self):
15311 return (self.table, self.list)
15312
15313 class M(torch.jit.ScriptModule):
15314 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015315 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070015316 self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str])
David Riazati3d443052019-03-18 18:15:17 -070015317 self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
15318 self.s1 = S([(1, 2)])
15319 self.s2 = S([(4, 5)])
15320
15321 @torch.jit.script_method
15322 def forward(self):
Shen Li10224432021-08-12 11:39:31 -070015323 return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list)
David Riazati3d443052019-03-18 18:15:17 -070015324
15325 m = M()
15326 imported_m = self.getExportImportCopy(m)
15327 self.assertEqual(m(), imported_m())
15328
David Riazati24db1662019-03-29 19:06:06 -070015329 def test_serialization_big_ints(self):
15330 class M(torch.jit.ScriptModule):
15331 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015332 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070015333 self.int32_max = torch.jit.Attribute(2**31 - 1, int)
15334 self.int32_min = torch.jit.Attribute(-2**31, int)
15335 self.uint32_max = torch.jit.Attribute(2**32, int)
David Riazati24db1662019-03-29 19:06:06 -070015336
Shen Li10224432021-08-12 11:39:31 -070015337 self.int64_max = torch.jit.Attribute(2**63 - 1, int)
15338 self.int64_min = torch.jit.Attribute(-2**63, int)
David Riazati24db1662019-03-29 19:06:06 -070015339
15340 self.tensor = torch.nn.Parameter(torch.ones(2, 2))
15341
15342 @torch.jit.script_method
15343 def forward(self, x):
15344 # type: (int) -> (int)
Shen Li10224432021-08-12 11:39:31 -070015345 return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
David Riazati24db1662019-03-29 19:06:06 -070015346
15347 m = M()
15348 imported = self.getExportImportCopy(m)
15349 self.assertEqual(m(10), imported(10))
15350
15351 self.assertEqual(m.int32_max, imported.int32_max)
15352 self.assertEqual(m.int32_min, imported.int32_min)
15353 self.assertEqual(m.uint32_max, imported.uint32_max)
15354 self.assertEqual(m.int64_max, imported.int64_max)
15355 self.assertEqual(m.int64_min, imported.int64_min)
15356
davidriazati28917842019-07-22 11:36:23 -070015357 def test_script_scope(self):
Peter Bellcb37e7a2022-04-22 17:34:59 +010015358 scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss)
davidriazati28917842019-07-22 11:36:23 -070015359
Michael Suo0c6ee942019-09-04 12:20:51 -070015360 @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
David Riazati24db1662019-03-29 19:06:06 -070015361 def test_serialization_sharing(self):
15362 class M(torch.jit.ScriptModule):
15363 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015364 super().__init__()
David Riazati24db1662019-03-29 19:06:06 -070015365 self.list = torch.jit.Attribute([], List[str])
15366
15367 @torch.jit.script_method
15368 def forward(self, key):
15369 # type: (str) -> List[str]
15370 self.list.append(key)
15371 self.list.append(key)
15372 self.list.append(key)
15373 return self.list
15374
15375 # the text of the string should only appear once in the pickling
15376 m = M()
15377 s1 = "a long string"
15378 s2 = "a different, even longer string"
15379 self.assertEqual(m(s1), [s1] * 3)
15380 self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
15381 with TemporaryFileName() as fname:
15382 m.save(fname)
15383 archive_name = os.path.basename(os.path.normpath(fname))
Shen Li10224432021-08-12 11:39:31 -070015384 archive = zipfile.ZipFile(fname, 'r')
15385 pickled_data = archive.read(os.path.join(archive_name, 'data.pkl'))
David Riazati24db1662019-03-29 19:06:06 -070015386
Chester Liu58eb2332021-02-08 13:56:12 -080015387 out = io.StringIO()
David Riazati24db1662019-03-29 19:06:06 -070015388 pickletools.dis(pickled_data, out=out)
15389 disassembled = out.getvalue()
15390
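            # pickle memoizes repeated objects: each string should be written once
            # and later uses fetched via BINGET memo lookups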
Shen Li10224432021-08-12 11:39:31 -070015391 FileCheck().check_count(s1, 1, exactly=True) \
15392 .check_count("BINGET", 2, exactly=True) \
15393 .check_count(s2, 1, exactly=True) \
15394 .check_count("BINGET", 2, exactly=True).run(out.getvalue())
David Riazati24db1662019-03-29 19:06:06 -070015395
James Reedc2a18a62019-06-11 22:55:16 -070015396 def test_sys_stdout_override(self):
15397 @torch.jit.script
15398 def foo():
Shen Li10224432021-08-12 11:39:31 -070015399 print('foo')
James Reedc2a18a62019-06-11 22:55:16 -070015400
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +000015401 class Redirect:
James Reedc2a18a62019-06-11 22:55:16 -070015402 def __init__(self):
Shen Li10224432021-08-12 11:39:31 -070015403 self.s = ''
James Reedc2a18a62019-06-11 22:55:16 -070015404
15405 def write(self, s):
15406 self.s += s
15407
15408 old_stdout = sys.stdout
15409 redirect = Redirect()
15410 try:
15411 sys.stdout = redirect
15412 foo()
15413 finally:
15414 sys.stdout = old_stdout
15415
Shen Li10224432021-08-12 11:39:31 -070015416 FileCheck().check('foo').run(redirect.s)
James Reedc2a18a62019-06-11 22:55:16 -070015417
James Reed489cc462019-08-05 19:25:58 -070015418 def test_dtype_attr(self):
15419 class Foo(torch.nn.Module):
15420 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015421 super().__init__()
James Reed489cc462019-08-05 19:25:58 -070015422 self.dtype = torch.zeros([]).dtype
15423
15424 def forward(self):
15425 return torch.zeros(3, 4, dtype=self.dtype)
15426
15427 f = Foo()
15428 torch.jit.script(f)
15429
Shen Li10224432021-08-12 11:39:31 -070015430
Michael Voznesenskyf6f13842020-05-18 23:21:27 -070015431 def test_named_buffers_are_iterable(self):
15432 class MyMod(torch.nn.Module):
15433 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015434 super().__init__()
Shen Li10224432021-08-12 11:39:31 -070015435 self.mod = (torch.nn.ReLU())
15436 self.mod2 = (torch.nn.ReLU())
Michael Voznesenskyf6f13842020-05-18 23:21:27 -070015437 self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU()))
Shen Li10224432021-08-12 11:39:31 -070015438 self.register_buffer('x', torch.zeros(3))
15439 self.register_buffer('y', torch.zeros(3))
Michael Voznesenskyf6f13842020-05-18 23:21:27 -070015440 self.z = torch.zeros(3)
15441
15442 def bleh(self):
15443 return self.z + 4
15444
15445 @torch.jit.export
15446 def method(self):
15447 names = [""]
15448 vals = []
15449 for name, buffer in self.named_buffers():
15450 names.append(name)
15451 vals.append(buffer + 2)
15452
15453 return names, vals
15454
15455 def forward(self, x):
15456 return x
15457
15458 model = MyMod()
15459 x = torch.jit.script(model)
15460 z = self.getExportImportCopy(x)
15461
15462 self.assertEqual(z.method(), x.method())
15463 self.assertEqual(z.method(), model.method())
15464 self.assertEqual(x.method(), model.method())
15465 names = x.method()
15466 for name in names:
Shen Li10224432021-08-12 11:39:31 -070015467 self.assertNotEqual('z', name)
15468
Michael Voznesenskyf6f13842020-05-18 23:21:27 -070015469
Zachary DeVito9097b552020-04-02 11:41:26 -070015470 def test_static_if_prop(self):
15471 class MaybeHasAttr(torch.nn.Module):
15472 def __init__(self, add_attr):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015473 super().__init__()
Zachary DeVito9097b552020-04-02 11:41:26 -070015474 if add_attr:
15475 self.maybe_attr = 1
15476
15477 def forward(self):
15478 if hasattr(self, "maybe_attr") and True:
15479 return self.maybe_attr
15480 else:
15481 return 0
15482
15483 class MaybeHasAttr2(torch.nn.Module):
15484 def __init__(self, add_attr):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015485 super().__init__()
Zachary DeVito9097b552020-04-02 11:41:26 -070015486 if add_attr:
15487 self.maybe_attr = 1
15488
15489 def forward(self):
15490 if not hasattr(self, "maybe_attr") or False:
15491 return 0
15492 else:
15493 return self.maybe_attr
15494
15495 torch.jit.script(MaybeHasAttr(True))
15496 torch.jit.script(MaybeHasAttr(False))
15497 torch.jit.script(MaybeHasAttr2(True))
15498 torch.jit.script(MaybeHasAttr2(False))
15499
Elias Ellison23d04412020-05-05 09:04:50 -070015500 class MyMod(torch.nn.Module):
15501 def forward(self):
15502 if hasattr(self, "foo"):
15503 return 1
15504 else:
15505 return 0
Zachary DeVito9097b552020-04-02 11:41:26 -070015506
Elias Ellison23d04412020-05-05 09:04:50 -070015507 @torch.jit.export
15508 def fee(self):
15509 return 1
15510
15511 self.checkModule(MyMod(), ())
15512
15513 class HasAttrMod(torch.nn.Module):
15514 __constants__ = ["fee"]
15515
15516 def __init__(self):
15517 super().__init__()
15518 self.fee = 3
15519
15520 def forward(self):
15521 a = hasattr(self, "fee")
15522 b = hasattr(self, "foo")
15523 c = hasattr(self, "hi")
15524                 d = hasattr(self, "nonexistent")
15525 return (a, b, c, d)
15526
15527 def foo(self):
15528 return 1
15529
15530 @torch.jit._overload_method
Shen Li10224432021-08-12 11:39:31 -070015531 def hi(self, x: Tensor): ... # noqa: E704
Elias Ellison23d04412020-05-05 09:04:50 -070015532
15533 def hi(self, x): # noqa: F811
15534 return 2
15535
15536 self.checkModule(HasAttrMod(), ())
15537
15538 @torch.jit.script
Aaron Gokaslan8fce9a02023-02-07 21:10:52 +000015539 class FooTest:
Elias Ellison23d04412020-05-05 09:04:50 -070015540 def __init__(self):
15541 self.x = 1
15542
15543 def foo(self, y):
15544 return self.x + y
15545
15546 def foo():
15547 a = FooTest()
15548 val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla")
15549 val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a")
15550 return val1, val2
15551
15552 self.assertEqual(foo(), torch.jit.script(foo)())
Zachary DeVito9097b552020-04-02 11:41:26 -070015553
davidriazati8ebb86d2019-05-08 16:40:37 -070015554 def _test_pickle_checkpoint(self, device):
15555 with TemporaryFileName() as fname:
15556 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070015557 __constants__ = ['fname']
davidriazati8ebb86d2019-05-08 16:40:37 -070015558
15559 def __init__(self, tensor):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015560 super().__init__()
davidriazati8ebb86d2019-05-08 16:40:37 -070015561 self.fname = fname
15562 self.tensor = torch.nn.Parameter(tensor)
15563
15564 @torch.jit.script_method
15565 def forward(self, x):
15566 y = self.tensor + x
15567 torch.save(y, self.fname)
15568 return y
15569
15570 param = torch.randn(2, 2).to(device)
15571 input = torch.randn(2, 2).to(device)
15572 m = M(param)
15573 m(input)
15574 with open(fname, "rb") as handle:
15575 loaded_tensor = torch.load(fname)
15576 self.assertEqual(loaded_tensor, input + param)
15577
15578 def _test_pickle_checkpoint_views(self, device):
15579 with TemporaryFileName() as fname:
15580 class M(torch.jit.ScriptModule):
Shen Li10224432021-08-12 11:39:31 -070015581 __constants__ = ['fname']
davidriazati8ebb86d2019-05-08 16:40:37 -070015582
15583 def __init__(self, tensor):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015584 super().__init__()
davidriazati8ebb86d2019-05-08 16:40:37 -070015585 self.fname = fname
15586 self.tensor = torch.nn.Parameter(tensor)
15587
15588 @torch.jit.script_method
15589 def forward(self, x):
15590 y = self.tensor + x
15591 y_view = y.view(4)
15592 torch.save((y, y_view, y), self.fname)
15593 return y
15594
15595 param = torch.randn(2, 2).to(device)
15596 input = torch.randn(2, 2).to(device)
15597 m = M(param)
15598 m(input)
15599 with open(fname, "rb") as handle:
15600 loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname)
15601 self.assertEqual(loaded_y, input + param)
15602 with torch.no_grad():
15603 loaded_y_view[1] += 20
15604 # assert that loaded_y changed as well
15605 self.assertEqual(loaded_y.view(4), loaded_y_view)
15606 self.assertEqual(loaded_y_2.view(4), loaded_y_view)
15607
15608 @unittest.skipIf(not RUN_CUDA, "no CUDA")
15609 def test_pickle_checkpoint_cuda(self):
Shen Li10224432021-08-12 11:39:31 -070015610 self._test_pickle_checkpoint('cuda')
15611 self._test_pickle_checkpoint_views('cuda')
davidriazati8ebb86d2019-05-08 16:40:37 -070015612
15613 def test_pickle_checkpoint(self):
Shen Li10224432021-08-12 11:39:31 -070015614 self._test_pickle_checkpoint('cpu')
15615 self._test_pickle_checkpoint_views('cpu')
davidriazati8ebb86d2019-05-08 16:40:37 -070015616
Zachary DeVito93da1032019-07-24 17:10:02 -070015617 def test_pickle_checkpoint_tup(self):
15618 @torch.jit.script
15619 def foo(fname):
15620 # type: (str) -> None
15621 torch.save((3, 4), fname)
15622 with TemporaryFileName() as name:
15623 foo(name)
15624 self.assertEqual(torch.load(name), (3, 4))
15625
davidriazatic819d762019-05-17 14:45:01 -070015626 def test_string_list(self):
15627 def fn(string):
15628 # type: (str) -> List[str]
15629 return list(string)
15630
15631 self.checkScript(fn, ("abcdefgh",))
15632
davidriazatie0e58132019-08-19 16:31:36 -070015633 def test_unicode_comments(self):
15634 @torch.jit.script
15635 def test(self, a):
15636 # 🤷🤷🤷🤷
15637 return torch.nn.functional.relu(a)
15638
davidriazati9d1acd62019-08-07 12:29:26 -070015639 def test_get_set_state_with_tensors(self):
15640 class M(torch.nn.Module):
15641 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000015642 super().__init__()
davidriazati9d1acd62019-08-07 12:29:26 -070015643 self.tensor = torch.randn(2, 2)
15644
15645 @torch.jit.export
15646 def __getstate__(self):
davidriazati00460922019-10-07 13:50:23 -070015647 return (self.tensor, self.training)
davidriazati9d1acd62019-08-07 12:29:26 -070015648
15649 @torch.jit.export
15650 def __setstate__(self, state):
davidriazati9d1acd62019-08-07 12:29:26 -070015651 self.tensor = state[0]
davidriazati00460922019-10-07 13:50:23 -070015652 self.training = state[1]
davidriazati9d1acd62019-08-07 12:29:26 -070015653
15654 def forward(self, x):
15655 return x + self.tensor
15656
15657 with TemporaryFileName() as fname:
15658 m = torch.jit.script(M())
15659 m.save(fname)
15660 loaded = torch.jit.load(fname)
15661 self.assertEqual(loaded.tensor, m.tensor)
15662
davidriazati5eb25c32019-06-18 09:43:45 -070015663 def test_in_for_and_comp_expr(self):
15664 def fn(d):
15665 # type: (Dict[str, int]) -> List[int]
15666 out = [1]
15667 for i in range(d["hi"] if "hi" in d else 6):
15668 out.append(i)
15669 return out
15670
Shen Li10224432021-08-12 11:39:31 -070015671 self.checkScript(fn, ({'hi': 2, 'bye': 3},))
15672 self.checkScript(fn, ({'bye': 3},))
davidriazati5eb25c32019-06-18 09:43:45 -070015673
Nikitha Malgib955da32021-01-28 08:14:35 -080015674 def test_for_else(self):
15675 def fn():
15676 c = 0
15677 for i in range(4):
15678 c += 10
15679 else:
15680 print("In else block of for...else")
15681
Shen Li10224432021-08-12 11:39:31 -070015682 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"):
Nikitha Malgib955da32021-01-28 08:14:35 -080015683 torch.jit.script(fn)
15684
Michael Suo066d1582019-03-11 15:09:00 -070015685 def test_split(self):
15686 def split_two(tensor):
15687 a, b, c = torch.split(tensor, 2, dim=1)
15688 return a, b, c
15689 x = torch.randn(3, 6)
15690 y = torch.randn(3, 6)
15691 self.checkScript(split_two, [(x + y)])
15692
davidriazati3858e162019-05-24 14:55:13 -070015693 def test_conv_error(self):
15694 @torch.jit.script
15695 def fn(x, y):
15696 return F.conv2d(x, y)
15697
15698 try:
15699 fn(torch.ones(2, 2), torch.ones(4, 4))
15700 except RuntimeError as e:
Shen Li10224432021-08-12 11:39:31 -070015701 self.assertFalse('frame' in str(e))
davidriazati3858e162019-05-24 14:55:13 -070015702
davidriazatic267d0c2019-05-17 15:23:58 -070015703 def test_python_op_name(self):
15704 import random
15705
15706 with self.assertRaisesRegex(RuntimeError, "randint"):
15707 @torch.jit.script
15708 def fn():
15709 return random.randint()
15710
David Riazati2c18bf22019-07-01 19:19:44 -070015711 def test_dir(self):
15712 class M(torch.jit.ScriptModule):
15713 def forward(self, t):
15714 return t
15715
Shen Li10224432021-08-12 11:39:31 -070015716 self.assertTrue('forward' in dir(M()))
David Riazati2c18bf22019-07-01 19:19:44 -070015717
davidriazati756bdcb2019-07-30 17:20:16 -070015718 def test_kwarg_expansion_error(self):
15719 @torch.jit.ignore
15720 def something_else(h, i):
15721 pass
15722
15723 def fn(x):
15724 something_else(**x)
15725
Shen Li10224432021-08-12 11:39:31 -070015726 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
davidriazati756bdcb2019-07-30 17:20:16 -070015727 torch.jit.script(fn)
15728
davidriazati995920a2019-08-02 11:16:51 -070015729 def test_kwargs_error_msg(self):
15730 def other(**kwargs):
15731 print(kwargs)
15732
15733 def fn():
15734 return other()
15735
Shen Li10224432021-08-12 11:39:31 -070015736 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
davidriazati995920a2019-08-02 11:16:51 -070015737 torch.jit.script(fn)
15738
15739 def another_other(*args):
15740 print(args)
15741
15742 def another_fn():
15743 return another_other()
15744
Shen Li10224432021-08-12 11:39:31 -070015745 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
davidriazati995920a2019-08-02 11:16:51 -070015746 torch.jit.script(another_fn)
15747
Michael Suob6d1a722019-05-30 10:46:52 -070015748 def test_inferred_error_msg(self):
15749 """
15750 Test that when we get a type mismatch on a function where we inferred
15751 the type to be tensor, a good error message is given.
15752 """
15753 @torch.jit.script
15754 def foo(a):
15755 return a
15756
Shen Li10224432021-08-12 11:39:31 -070015757 with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'"
15758 r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")):
Edward Z. Yangf2eed942022-05-06 09:24:42 -070015759 foo("1")
Michael Suo52456b22019-11-08 13:54:38 -080015760
Alexander Stantef30b14d2019-12-12 18:17:42 -080015761 def test_type_comments_in_body(self):
15762 @torch.jit.script
Shen Li10224432021-08-12 11:39:31 -070015763 def foo(a, # type: int
15764 b, # type: int
15765 ):
Alexander Stantef30b14d2019-12-12 18:17:42 -080015766 # type: (...) -> int
15767 # type: int
15768 return a + b
15769
15770 class M(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -070015771 def __init__(self,
15772 a, # type: int
15773 b # type: int
15774 ):
Alexander Stantef30b14d2019-12-12 18:17:42 -080015775 # type: (...) -> None
Xuehai Pan046e88a2023-02-12 22:20:50 +000015776 super().__init__()
Alexander Stantef30b14d2019-12-12 18:17:42 -080015777 self.a = a # type: int
15778 self.b = b # type: int
15779
15780 torch.jit.script(M(2, 3))
15781
Tugsbayasgalan (Tugsuu) Manlaibaatare59403f2022-02-24 11:52:31 -080015782 def test_input_keyword_in_schema(self):
15783 def f(x):
15784 return torch.ceil(input=x)
15785
15786 inp = torch.randn(10)
15787 self.checkScript(f, (inp, ))
15788
Michael Suo167a9782020-05-12 23:17:57 -070015789 def test_module_method_reassignment(self):
15790 class Foo(torch.nn.Module):
Michael Suo167a9782020-05-12 23:17:57 -070015791 def _forward(self, x):
15792 return x
15793
15794 forward = _forward
15795
15796 sm = torch.jit.script(Foo())
15797 input = torch.ones(2, 2)
15798 self.assertEqual(input, sm(input))
15799
Yanan Cao00311082020-06-05 11:52:20 -070015800 # Tests the case where a torch.Tensor subclass (like Parameter) is used as
15801 # input.
15802 def test_script_module_tensor_subclass_argument(self):
15803 @torch.jit.script
15804 def parameter_script(x: torch.nn.Parameter):
15805 return x
15806
15807 input = torch.ones(2, 2)
15808 self.assertEqual(input, parameter_script(input))
15809
Michael Suo74f18472020-09-03 14:58:02 -070015810 def test_save_load_attr_error(self):
15811 class Inner(nn.Module):
Michael Suo74f18472020-09-03 14:58:02 -070015812 def forward(self, x):
15813 return x
15814
15815 class Wrapper(nn.Module):
15816 def __init__(self, inner):
15817 super().__init__()
15818 self.inner = inner
15819
15820 def forward(self, x):
15821 # this attribute doesn't exist on `Inner`
15822 return self.inner.b(x)
15823
15824 inner_module = torch.jit.script(Inner())
15825 inner_module = self.getExportImportCopy(inner_module)
15826 wrapped = Wrapper(inner_module)
15827 # This should properly complain that `self.inner` doesn't have the attribute `b`
Shen Li10224432021-08-12 11:39:31 -070015828 with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
Michael Suo74f18472020-09-03 14:58:02 -070015829 torch.jit.script(wrapped)
15830
Michael Suo9dd86702020-09-03 14:58:02 -070015831 def test_rescripting_loaded_modules(self):
15832 class InnerSubmod(nn.Module):
Shen Li10224432021-08-12 11:39:31 -070015833 __constants__ = ['my_constant']
Michael Suo9dd86702020-09-03 14:58:02 -070015834
15835 def __init__(self):
15836 super().__init__()
15837 self.register_buffer("foo", torch.ones(1))
15838 self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
15839 self.baz = torch.ones(1)
15840 self.my_constant = 1
15841
15842 def forward(self, x):
15843 return x + x
15844
15845 class Inner(nn.Module):
15846 def __init__(self):
15847 super().__init__()
15848 self.submod = InnerSubmod()
15849
15850 def forward(self, x):
15851 return self.submod(x)
15852
15853 class Wrapper(nn.Module):
15854 def __init__(self, inner):
15855 super().__init__()
15856 self.inner = inner
15857
15858 def forward(self, x):
15859 # access inner elements
Shen Li10224432021-08-12 11:39:31 -070015860 ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
Michael Suo9dd86702020-09-03 14:58:02 -070015861 ret = ret + self.inner.submod.my_constant
15862 return ret
15863
15864 inner_module = torch.jit.script(Inner())
15865 wrapped = Wrapper(inner_module)
15866 self.checkModule(wrapped, torch.ones(1))
15867
15868 inner_module_loaded = self.getExportImportCopy(inner_module)
15869 wrapped_loaded = Wrapper(inner_module_loaded)
15870 self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
15871
Mikhail Zolotukhinc6febc62020-09-11 02:53:51 -070015872 def test_interpret_graph(self):
15873 def fn(x):
15874 return x.unfold(0, 1, 1)
15875
15876 graph_str = """
15877 graph(%a : Tensor, %b : Tensor):
15878 %c : Tensor = aten::mul(%a, %b)
15879 return (%c)
15880 """
15881 graph = parse_ir(graph_str)
15882 a = torch.rand(10)
15883 b = torch.rand(10)
15884 test = torch._C._jit_interpret_graph(graph, (a, b))
15885 ref = a * b
15886 self.assertEqual(test, ref)
Yanan Cao00311082020-06-05 11:52:20 -070015887
Tugsbayasgalan Manlaibaatar29184f82020-11-07 00:36:44 -080015888 def test_signed_float_zero(self):
Shen Li10224432021-08-12 11:39:31 -070015889
Tugsbayasgalan Manlaibaatar29184f82020-11-07 00:36:44 -080015890 class MyModule(torch.nn.Module):
Tugsbayasgalan Manlaibaatar29184f82020-11-07 00:36:44 -080015891 def forward(self, x):
Shen Li10224432021-08-12 11:39:31 -070015892 return torch.div(x, -0.)
Tugsbayasgalan Manlaibaatar29184f82020-11-07 00:36:44 -080015893
15894 inp = torch.ones(1)
15895 self.checkModule(MyModule(), inp)
15896
Michael Suoc10908c2022-06-13 10:04:03 -070015897 def test_index_with_tuple(self):
15898 class MyModule(torch.nn.Module):
Michael Suoc10908c2022-06-13 10:04:03 -070015899 def forward(self, x):
15900 return x[(1,)]
15901
15902 self.checkModule(MyModule(), (torch.ones(2, 3),))
15903
Wei-Sheng Chinf1144682022-07-22 21:34:47 +000015904 def test_context_manager(self):
15905 class MyModule(torch.nn.Module):
Wei-Sheng Chinf1144682022-07-22 21:34:47 +000015906 def forward(self, x, y):
15907 p = x + y
15908 q = p + 2.0
15909 return q
15910
15911 x = torch.randn(3, 2, dtype=torch.float)
15912 y = torch.randn(3, 2, dtype=torch.float)
15913 for fuser_name in ['fuser0', 'fuser1', 'none']:
15914 with torch.jit.fuser(fuser_name):
15915 self.checkModule(MyModule(), (x, y))
15916
Zachary DeVito2f25d1f2018-06-14 11:48:15 -070015917# known to be failing in tracer
15918EXCLUDE_TRACED = {
Richard Zouc8a0b112018-09-26 07:56:45 -070015919 # The following fail due to #12024.
Elias Ellison221eddd2019-02-27 18:59:19 -080015920 # A prim::ListConstruct is involved and the indices get traced as TensorType,
Richard Zouc8a0b112018-09-26 07:56:45 -070015921 # which always have requires_grad=True. This causes a crash in autodiff.
Shen Li10224432021-08-12 11:39:31 -070015922 'test___getitem___adv_index',
15923 'test___getitem___adv_index_beg',
15924 'test___getitem___adv_index_comb',
15925 'test___getitem___adv_index_dup',
15926 'test___getitem___adv_index_sub',
15927 'test___getitem___adv_index_sub_2',
15928 'test___getitem___adv_index_sub_3',
15929 'test___getitem___adv_index_var',
15930
Brian Vaughan8a9ea552019-06-06 13:55:47 -070015931 # jit doesn't support sparse tensors.
Shen Li10224432021-08-12 11:39:31 -070015932 'test_to_sparse',
15933 'test_to_sparse_dim',
Zachary DeVito2f25d1f2018-06-14 11:48:15 -070015934}
15935
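# Illustrative sketch (not used by the suite) of the advanced-indexing pattern
# excluded above: under tracing, the Python index list is recorded as a
# prim::ListConstruct of index tensors that always require grad, which is the
# source of the autodiff crash referenced in #12024.
def _example_adv_index(x):
    # plain advanced indexing; the tracer records the indices as tensors
    return x[[0, 2], :]
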
Adam Paszke0ddbe662018-09-11 05:56:17 -070015936EXCLUDE_TYPE_CHECK = {
15937 # slogdet tests use itemgetter to select the op's only differentiable output,
15938 # but this happens outside of the graph we handle, so there are fewer
15939 # reference outputs than graph outputs.
Shen Li10224432021-08-12 11:39:31 -070015940 'test_slogdet_1x1_neg_det',
15941 'test_slogdet_1x1_pos_det',
15942 'test_slogdet_distinct_singular_values',
15943 'test_slogdet_neg_det',
15944 'test_slogdet_pos_det',
15945 'test_slogdet_symmetric',
15946 'test_slogdet_symmetric_pd',
15947 'test_slogdet_batched_1x1_neg_det',
15948 'test_slogdet_batched_pos_det',
15949 'test_slogdet_batched_symmetric',
15950 'test_slogdet_batched_symmetric_pd',
15951 'test_slogdet_batched_distinct_singular_values'
Adam Paszke0ddbe662018-09-11 05:56:17 -070015952}
15953
Ailing Zhang9c875432019-03-31 08:41:46 -070015954# chunk returns a list in scripting and we don't unpack the list,
15955# thus it won't be replaced by ConstantChunk and AD won't run.
15956# This is explicitly checked in test_chunk_constant_script_ad.
Zachary DeVito1abbee02019-04-10 18:12:38 -070015957# Similarly for split: it's replaced by split_with_sizes in tracing,
15958# but we don't have an AD formula for aten::split(Tensor, int[], int),
15959# an op registered in JIT, so AD is not triggered in scripting.
Ailing Zhang9c875432019-03-31 08:41:46 -070015960EXCLUDE_SCRIPT_AD_CHECK = {
Shen Li10224432021-08-12 11:39:31 -070015961 'test_chunk',
15962 'test_chunk_dim',
15963 'test_chunk_dim_neg0',
15964 'test_split_size_list',
15965 'test_split_size_list_dim',
15966 'test_split_size_list_dim_neg0',
15967 'test_tensor_indices_sections',
15968 'test_tensor_indices_sections_dim',
15969 'test_tensor_indices_sections_dim_neg0',
15970 'test_tensor_split_sections',
15971 'test_tensor_split_sections_dim',
15972 'test_tensor_split_sections_dim_neg0'
Ailing Zhang9c875432019-03-31 08:41:46 -070015973}
15974
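# Illustrative sketch (not part of the test suite) of the chunk behavior noted
# above: unpacking the outputs lets scripting substitute prim::ConstantChunk so
# AD can run, while keeping the result as a list leaves a plain aten::chunk.
def _example_chunk_unpacked(x):
    a, b = torch.chunk(x, 2)  # unpacked -> eligible for prim::ConstantChunk
    return a + b

def _example_chunk_as_list(x):
    ys = torch.chunk(x, 2)  # kept as a list -> stays aten::chunk when scripted
    return ys[0]
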
David Riazati666d3832018-11-29 15:13:45 -080015975EXCLUDE_PYTHON_PRINT = {
David Riazatia23863f2018-12-03 23:49:39 -080015976 # no support for BroadcastingList in python printer
Shen Li10224432021-08-12 11:39:31 -070015977 'test_nn_max_unpool1d',
15978 'test_nn_max_unpool2d',
15979 'test_nn_max_unpool3d',
15980 'test_nn_max_pool1d',
15981 'test_nn_max_pool2d',
15982 'test_nn_max_pool3d',
15983 'test_nn_max_pool1d_with_indices',
Zachary DeVito2f25d1f2018-06-14 11:48:15 -070015984}
15985
Mike Ruberry686e2812020-09-14 15:43:21 -070015986EXCLUDE_ALIAS = {
15987 # aliases, which may appear in method_tests but are tested elsewhere
Shen Li10224432021-08-12 11:39:31 -070015988 'true_divide',
15989
Nikita Vedeneevc31ced42020-10-23 10:11:05 -070015990 # Disable tests for lu from common_methods_invocations.py
15991 # TODO(@nikitaved) Enable jit tests once autograd.Function supports scripting
Shen Li10224432021-08-12 11:39:31 -070015992 'lu'
Mike Ruberry686e2812020-09-14 15:43:21 -070015993}
15994
Zachary DeVito2f25d1f2018-06-14 11:48:15 -070015995
Jason Anselae57bd62023-02-14 19:06:50 +000015996@skipIfTorchDynamo()
David Riazati3270e4d2019-01-03 14:31:09 -080015997class TestJitGeneratedModule(JitTestCase):
15998 pass
15999
16000
Jason Anselae57bd62023-02-14 19:06:50 +000016001@skipIfTorchDynamo()
David Riazati3270e4d2019-01-03 14:31:09 -080016002class TestJitGeneratedFunctional(JitTestCase):
16003 pass
16004
Will Fengff501c32018-07-05 15:43:29 -070016005# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
16006# and we have to disable the failing tests here instead.
Meghan Lelef85a27e2020-07-17 11:27:07 -070016007UBSAN_DISABLED_TESTS = [
Mike Ruberry4f761f32020-06-19 18:29:36 -070016008 "test___rdiv___constant",
16009 "test___rdiv___scalar_constant",
Will Fengff501c32018-07-05 15:43:29 -070016010 "test_addcdiv",
16011 "test_addcdiv_broadcast_all",
16012 "test_addcdiv_broadcast_rhs",
16013 "test_addcdiv_scalar",
16014 "test_addcdiv_scalar_broadcast_lhs",
16015 "test_addcdiv_scalar_broadcast_rhs",
16016 "test_addcdiv_scalar_scale",
16017 "test_addcdiv_scalar_scale_broadcast_lhs",
16018 "test_addcdiv_scalar_scale_broadcast_rhs",
16019 "test_addcdiv_scale",
16020 "test_addcdiv_scale_broadcast_all",
16021 "test_addcdiv_scale_broadcast_rhs",
16022 "test_add_broadcast_all",
16023 "test_add_broadcast_lhs",
16024 "test_add_broadcast_rhs",
16025 "test_add_constant",
16026 "test_add_scalar",
16027 "test_add_scalar_broadcast_lhs",
16028 "test_add_scalar_broadcast_rhs",
16029 "test_div",
16030 "test_div_broadcast_all",
16031 "test_div_broadcast_lhs",
16032 "test_div_broadcast_rhs",
16033 "test_div_scalar",
16034 "test_div_scalar_broadcast_lhs",
16035 "test_div_scalar_broadcast_rhs",
16036 "test_rsqrt",
16037 "test_rsqrt_scalar",
16038 "test_add",
16039 "test_reciprocal",
16040 "test_reciprocal_scalar",
16041]
16042
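# Size constants (large/medium/small), presumably used when specifying input
# shapes for the generated tests.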
Wanchao Liang52058202018-08-17 11:05:01 -070016043L = 20
16044M = 10
16045S = 5
16046
David Riazati9e93a022018-11-28 23:28:59 -080016047def add_nn_module_test(*args, **kwargs):
Shen Li10224432021-08-12 11:39:31 -070016048 no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
Elias Ellisonba70cf22018-12-04 12:27:22 -080016049
Shen Li10224432021-08-12 11:39:31 -070016050 if 'desc' in kwargs and 'eval' in kwargs['desc']:
David Riazati9e93a022018-11-28 23:28:59 -080016051 # eval() is not supported, so skip these tests
16052 return
16053
Gary Miguel543b7fb2021-10-14 10:57:26 -070016054 test_name = get_nn_mod_test_name(**kwargs)
David Riazati9e93a022018-11-28 23:28:59 -080016055
David Riazati7408ce22018-12-11 13:49:59 -080016056 @suppress_warnings
David Riazati14ea4bf2018-10-25 13:57:40 -070016057 def do_test(self):
David Riazati9e93a022018-11-28 23:28:59 -080016058 if test_name in EXCLUDE_SCRIPT_MODULES:
16059 return
Shen Li10224432021-08-12 11:39:31 -070016060 if not kwargs.get('check_jit', True):
16061 raise unittest.SkipTest('module test skipped on JIT')
Gregory Chananfa158c42020-09-10 08:19:16 -070016062
Gary Miguel543b7fb2021-10-14 10:57:26 -070016063 module_name = get_nn_module_name_from_kwargs(**kwargs)
16064
Shen Li10224432021-08-12 11:39:31 -070016065 if 'constructor' in kwargs:
16066 nn_module = kwargs['constructor']
David Riazati9e93a022018-11-28 23:28:59 -080016067 else:
Gary Miguel543b7fb2021-10-14 10:57:26 -070016068 nn_module = getattr(torch.nn, module_name)
Elias Ellisonba70cf22018-12-04 12:27:22 -080016069
16070 if "FunctionalModule" in str(nn_module):
16071 return
16072
Shen Li10224432021-08-12 11:39:31 -070016073 if 'constructor_args_fn' in kwargs:
16074 constructor_args = kwargs['constructor_args_fn']()
David Riazatia66669a2018-12-04 18:32:05 -080016075 else:
Shen Li10224432021-08-12 11:39:31 -070016076 constructor_args = kwargs.get('constructor_args', ())
David Riazati14ea4bf2018-10-25 13:57:40 -070016077
Wanchao Liangd6bfc532018-11-20 14:09:27 -080016078 def create_script_module(*args, **kwargs):
Gary Miguel543b7fb2021-10-14 10:57:26 -070016079 """Construct a script module that passes arguments through to self.submodule"""
David Riazati14ea4bf2018-10-25 13:57:40 -070016080 formals, tensors, actuals = get_script_args(args)
16081
Shen Li10224432021-08-12 11:39:31 -070016082 method_args = ', '.join(['self'] + actuals)
16083 call_args_str = ', '.join(actuals)
David Riazati14ea4bf2018-10-25 13:57:40 -070016084 call = "self.submodule({})".format(call_args_str)
16085 script = script_method_template.format(method_args, call)
David Riazati14ea4bf2018-10-25 13:57:40 -070016086
David Riazati53bc5fb2018-11-13 13:46:46 -080016087 submodule_constants = []
Shen Li10224432021-08-12 11:39:31 -070016088 if kwargs.get('is_constant'):
16089 submodule_constants = ['submodule']
David Riazati53bc5fb2018-11-13 13:46:46 -080016090
David Riazati14ea4bf2018-10-25 13:57:40 -070016091 # Create module to use the script method
16092 class TheModule(torch.jit.ScriptModule):
David Riazati53bc5fb2018-11-13 13:46:46 -080016093 __constants__ = submodule_constants
16094
David Riazati14ea4bf2018-10-25 13:57:40 -070016095 def __init__(self):
Xuehai Pan046e88a2023-02-12 22:20:50 +000016096 super().__init__()
David Riazati14ea4bf2018-10-25 13:57:40 -070016097 self.submodule = nn_module(*constructor_args)
davidriazati736bf7b2019-05-29 13:48:34 -070016098
16099 def make_module(script):
16100 module = TheModule()
16101 # check __repr__
16102 str(module)
16103 module.define(script)
16104 return module
16105
Elias Ellison4fae5a62020-03-23 11:53:18 -070016106 module = make_module(script)
16107 self.assertExportImportModule(module, tensors)
16108 create_script_module.last_graph = module.graph
16109 mod = module(*args)
Elias Ellison6d63e9d2018-11-28 19:14:16 -080016110 return mod
Wanchao Liangd6bfc532018-11-20 14:09:27 -080016111
16112 # Construct a normal nn module to stay consistent with create_script_module
16113 # and make use of a single global rng_state in module initialization
16114 def create_nn_module(*args, **kwargs):
16115 module = nn_module(*constructor_args)
David Riazati14ea4bf2018-10-25 13:57:40 -070016116 return module(*args)
16117
David Riazatia66669a2018-12-04 18:32:05 -080016118 # Set up inputs from tuple of sizes or constructor fn
Peter Bell47f0bda2021-01-22 09:34:29 -080016119 dtype = torch.double
Shen Li10224432021-08-12 11:39:31 -070016120 if 'input_fn' in kwargs:
16121 input = kwargs['input_fn']()
Peter Bell47f0bda2021-01-22 09:34:29 -080016122 if isinstance(input, Tensor):
16123 input = (input,)
16124
16125 if all(tensor.is_complex() for tensor in input):
16126 dtype = torch.cdouble
David Riazati9e93a022018-11-28 23:28:59 -080016127 else:
Shen Li10224432021-08-12 11:39:31 -070016128 input = (kwargs['input_size'],)
David Riazatib8da44d2018-12-03 13:58:49 -080016129
Shen Li10224432021-08-12 11:39:31 -070016130 if 'target_size' in kwargs:
16131 input = input + (kwargs['target_size'],)
16132 elif 'target_fn' in kwargs:
David Riazatia66669a2018-12-04 18:32:05 -080016133 if torch.is_tensor(input):
16134 input = (input,)
Shen Li10224432021-08-12 11:39:31 -070016135 input = input + (kwargs['target_fn'](),)
16136 elif 'target' in kwargs:
16137 input = input + (kwargs['target'],)
Gregory Chananfa158c42020-09-10 08:19:16 -070016138
16139 # Extra parameters to forward()
Shen Li10224432021-08-12 11:39:31 -070016140 if 'extra_args' in kwargs:
16141 input = input + kwargs['extra_args']
David Riazatia66669a2018-12-04 18:32:05 -080016142
Peter Bell47f0bda2021-01-22 09:34:29 -080016143 args_variable, kwargs_variable = create_input(input, dtype=dtype)
David Riazati14ea4bf2018-10-25 13:57:40 -070016144 f_args_variable = deepcopy(unpack_variables(args_variable))
David Riazati14ea4bf2018-10-25 13:57:40 -070016145
Michael Dagitses91451362021-06-22 10:09:25 -070016146 # TODO(issue#52052) Neither this nor no_grad should be required
16147 # if check_against_reference() is updated to check gradients
16148 # w.r.t. weights and then only check w.r.t. inputs if any
16149 # inputs require it.
16150 any_requires_grad = any(input.requires_grad for input in f_args_variable)
16151
David Riazatia66669a2018-12-04 18:32:05 -080016152 # Check against Python module as reference
Shen Li10224432021-08-12 11:39:31 -070016153 check_against_reference(self, create_script_module, create_nn_module,
16154 lambda x: x, f_args_variable,
16155 no_grad=no_grad or not any_requires_grad)
David Riazati14ea4bf2018-10-25 13:57:40 -070016156
Shen Li10224432021-08-12 11:39:31 -070016157 if 'slowTest' in kwargs:
davidriazati60642232019-12-30 11:43:04 -080016158 do_test = slowTest(do_test)
16159
David Riazati3270e4d2019-01-03 14:31:09 -080016160 post_add_test(test_name, (), do_test, TestJitGeneratedModule)
David Riazati14ea4bf2018-10-25 13:57:40 -070016161
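# Sketch of the kind of entry that drives add_nn_module_test (field names are
# illustrative; the real entries come from module_tests / new_module_tests /
# additional_module_tests iterated at the bottom of this file):
#   {'module_name': 'Linear', 'constructor_args': (10, 8), 'input_size': (4, 10)}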
16162
David Riazati3270e4d2019-01-03 14:31:09 -080016163def post_add_test(test_name, skipTestIf, do_test, test_class):
Shen Li10224432021-08-12 11:39:31 -070016164 assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name
Wanchao Liang52058202018-08-17 11:05:01 -070016165
16166 for skip in skipTestIf:
16167 do_test = skip(do_test)
16168
Meghan Lelef85a27e2020-07-17 11:27:07 -070016169 if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS):
David Riazati3270e4d2019-01-03 14:31:09 -080016170 setattr(test_class, test_name, do_test)
Wanchao Liang52058202018-08-17 11:05:01 -070016171
Zachary DeVito2f25d1f2018-06-14 11:48:15 -070016172
Ailing Zhang9c875432019-03-31 08:41:46 -070016173def normalize_check_ad(check_ad, name):
16174 # the normalized check_ad is a 3-element tuple: (bool, List[str], List[str])
16175 if len(check_ad) == 0:
Shen Li10224432021-08-12 11:39:31 -070016176 check_ad = [False, ['aten::' + name], []]
Ailing Zhang9c875432019-03-31 08:41:46 -070016177 elif len(check_ad) == 1:
Shen Li10224432021-08-12 11:39:31 -070016178 check_ad = [check_ad[0], ['aten::' + name], []]
Ailing Zhang9c875432019-03-31 08:41:46 -070016179 elif len(check_ad) == 2:
16180 check_ad = [check_ad[0], check_ad[1], []]
16181 elif len(check_ad) == 3:
16182 check_ad = list(check_ad)
16183 else:
Shen Li10224432021-08-12 11:39:31 -070016184 raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')
Ailing Zhang9c875432019-03-31 08:41:46 -070016185
16186 check_ad = [[t] if isinstance(t, str) else t for t in check_ad]
16187
16188 return check_ad
16189
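# Worked examples of the normalization above (illustration only):
#   normalize_check_ad((), 'mul')                  -> [False, ['aten::mul'], []]
#   normalize_check_ad((True,), 'mul')             -> [True, ['aten::mul'], []]
#   normalize_check_ad((True, 'aten::mul'), 'mul') -> [True, ['aten::mul'], []]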
16190
Jane Xu32e30032021-10-19 16:52:27 -070016191class TestProducerVersion(TestCase):
Shen Li10224432021-08-12 11:39:31 -070016192
mattipec8006c2020-04-27 10:58:01 -070016193 def test_version(self):
16194 # issue gh-32561
Nikita Shulgaf1ce7f42021-06-03 07:54:18 -070016195 self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
mattipec8006c2020-04-27 10:58:01 -070016196
David Riazatia66669a2018-12-04 18:32:05 -080016197for test in module_tests + new_module_tests + additional_module_tests:
16198 add_nn_module_test(**test)
16199
Gregory Chananc8914af2020-09-10 08:19:16 -070016200for test in criterion_tests:
Shen Li10224432021-08-12 11:39:31 -070016201 test['no_grad'] = True
David Riazati9e93a022018-11-28 23:28:59 -080016202 add_nn_module_test(**test)
Wanchao Liang52058202018-08-17 11:05:01 -070016203
Shen Li10224432021-08-12 11:39:31 -070016204if __name__ == '__main__':
Edward Z. Yang2ced9182017-07-17 07:49:10 -070016205 run_tests()
Shen Li10224432021-08-12 11:39:31 -070016206 import jit.test_module_interface
Nikita Shulga47c4dca2020-04-24 17:39:53 -070016207 suite = unittest.findTestCases(jit.test_module_interface)
16208 unittest.TextTestRunner().run(suite)