| # Owner(s): ["module: dynamo"] |
| # flake8: noqa |
| import collections |
| import functools |
| import inspect |
| import itertools |
| import operator |
| import unittest |
| from typing import Any |
| from unittest.mock import patch |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch import sub |
| from torch._dynamo.testing import requires_static_shapes |
| from torch._dynamo.utils import same |
| from torch.nn import functional as F |
| |
# Module-level fixtures that the traced test functions below read as globals.
# Not referenced anywhere in this file; presumably imported by other test
# modules -- TODO confirm before removing.
tensor_for_import_testing = torch.ones((10, 10))
# Global tensor read by FunctionTests.test_globalvar.
d = torch.ones((10, 10))
# Global module called by FunctionTests.test_globalmodule.
e = torch.nn.Linear(in_features=10, out_features=10)
# Global bool branched on by FunctionTests.test_load_global_bool.
flag = True
| |
| |
def constant3(a, b):
    """Return ``a - b`` shifted by the constant ``1.0 + 2``.

    Called from several FunctionTests cases with positional and keyword
    arguments, so the parameter names ``a`` and ``b`` are part of the contract.
    """
    offset = 1.0 + 2
    difference = a - b
    return difference + offset
| |
| |
def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``."""
    if not some_default_arg:
        # The original fell off the end of the function here; keep the None.
        return None
    return a - b
| |
| |
def make_test(fn):
    """Turn a plain function into a unittest-style test method.

    The returned callable takes only ``self`` and forwards to
    ``torch._dynamo.testing.standard_test`` with ``fn`` and the number of
    parameters ``fn`` declares.

    The inner function is deliberately not wrapped with ``functools.wraps``:
    its signature has to stay ``(self)`` rather than mirror ``fn``.
    """
    param_count = len(inspect.signature(fn).parameters)

    def test_fn(self):
        run_standard = torch._dynamo.testing.standard_test
        return run_standard(self, fn=fn, nargs=param_count)

    return test_fn
| |
| |
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    """Add 1.2 to ``x``; decorated with ``torch.jit.script_if_tracing``."""
    shifted = x + 1.2
    return shifted
| |
| |
@torch.jit.ignore
def inline_ignore(x):
    """Add 3.4 to ``x``; decorated with ``torch.jit.ignore``."""
    shifted = x + 3.4
    return shifted
| |
| |
@torch.jit.unused
def inline_unused(x):
    """Add 5.6 to ``x``; decorated with ``torch.jit.unused``."""
    shifted = x + 5.6
    return shifted
| |
| |
class FunctionTests(torch._dynamo.test_case.TestCase):
    """Small single-function cases run through the shared dynamo test harness.

    Functions decorated with ``@make_test`` are written WITHOUT a ``self``
    parameter: their parameters are the inputs of the function under test.
    ``make_test`` replaces each one with a method taking ``self`` that forwards
    the function and its parameter count to
    ``torch._dynamo.testing.standard_test`` (defined outside this file).

    A few cases need objects created inside the test (parameters, shared
    dicts); those are ordinary methods that build a closure and call
    ``make_test`` by hand.
    """

    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)
        x = inline_ignore(x)
        x = inline_unused(x)
        # Return the computed value so the harness receives the result of the
        # three inlined helpers instead of None.
        return x

    @make_test
    def test_add(a, b):
        return a + b

    @make_test
    def test_add_(a, b):
        # Operate on a copy so the in-place op does not mutate the input.
        a_copy = torch.tensor(a)
        return a_copy.add_(b, alpha=5.0)

    @make_test
    def test_addcdiv(a, b, c):
        # dynamo decomposes this to avoid a graph break when
        # the value kwarg is populated
        return torch.addcdiv(a, b, c, value=5.0)

    @make_test
    def test_addcdiv_(a, b, c):
        # Operate on a copy so the in-place op does not mutate the input.
        a_copy = torch.tensor(a)
        return a_copy.addcdiv_(b, c, value=5.0)

    @make_test
    def test_is_not_null(a, b):
        if a is not None and b is not None:
            return a + b

    @make_test
    def test_constant1(a, b, c):
        return a - b * c + 1.0

    @make_test
    def test_constant2(a, b, c):
        return a - b * c + 1

    @make_test
    def test_constant3(a):
        # Note: the local ``d`` shadows the module-level tensor ``d`` here.
        b = 1
        c = 2
        d = 3
        return b + c - d + a

    @make_test
    def test_constant4(a, b):
        c = 2
        d = 3
        # 2 > 3 is False, so the second return is the one taken.
        if c > d:
            return a - b
        return b - a

    @make_test
    def test_finfo(a, b):
        if torch.iinfo(torch.int32).bits == 32:
            return torch.finfo(a.dtype).min * b

    @make_test
    def test_globalfn(a, b):
        # ``sub`` is the name imported at module level via ``from torch import sub``.
        return sub(a, b)

    @make_test
    def test_viatorch(a, b):
        return torch.sub(a, b)

    @make_test
    def test_viamethod(a, b):
        return a.sub(b)

    @make_test
    def test_indirect1(a, b):
        t = a.sub
        return t(b)

    @make_test
    def test_indirect2(a, b):
        t = a.sub
        args = (b,)
        return t(*args)

    @make_test
    def test_indirect3(a, b):
        t = a.sub
        args = (b,)
        kwargs = {}
        return t(*args, **kwargs)

    @make_test
    def test_methodcall1(a, b, c):
        return constant3(a, b) * c

    @make_test
    def test_methodcall2(a, b):
        # Keyword arguments deliberately swap the two tensors.
        return constant3(a=b, b=a) + 1

    @make_test
    def test_methodcall3(a, b):
        return constant3(a, b=1.0) + b

    @make_test
    def test_device_constant(a):
        return a + torch.ones(1, device=torch.device("cpu"))

    @make_test
    def test_tuple1(a, b):
        args = (a, b)
        return sub(*args)

    @make_test
    def test_tuple2(a, b):
        args = [a, b]
        return sub(*args)

    @make_test
    def test_is_in_onnx_export(x, y):
        if torch.onnx.is_in_onnx_export():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_is_fx_tracing(x, y):
        if torch.fx._symbolic_trace.is_fx_tracing():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_listarg1(a, b):
        return torch.cat([a, b])

    @make_test
    def test_listarg2(a, b):
        return torch.cat((a, b), dim=0)

    @make_test
    def test_listarg3(a, b):
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)

    @make_test
    def test_listarg4(a, b):
        return torch.cat(tensors=[a, b], dim=0)

    @make_test
    def test_listarg5(a, b):
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)

    @make_test
    def test_slice1(a):
        return a[5]

    @make_test
    def test_slice2(a):
        return a[:5]

    @make_test
    def test_slice3(a):
        return a[5:]

    @make_test
    def test_slice4(a):
        return a[2:5]

    @make_test
    def test_slice5(a):
        return a[::2]

    @make_test
    def test_slice6(a):
        return torch.unsqueeze(a, 0)[:, 2:]

    @make_test
    def test_range1(a):
        return torch.tensor(range(a.size(0)))

    @make_test
    def test_range2(x, y):
        r = x + y
        for i in range(x.size(0) + 2):
            r = r / y
        return r

    @make_test
    def test_unpack1(a):
        a, b = a[:5], a[5:]
        return a - b

    @make_test
    def test_unpack2(a):
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b

    @make_test
    def test_unpack3(a):
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b

    @make_test
    def test_fn_with_self_set(a, b):
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )

    @make_test
    def test_return_tuple1(a, b):
        return (a - b, b - a, a, b)

    @make_test
    def test_globalvar(a, b):
        # ``d`` is the module-level tensor.
        return a - b + d

    @make_test
    def test_globalmodule(x):
        # ``e`` is the module-level torch.nn.Linear instance.
        return e(x)

    @make_test
    def test_inline_with_default(a, b, c):
        return func_with_default(a, b) * c

    @make_test
    def test_inner_function(x):
        def fn(x):
            return torch.add(x, x)

        return fn(x)

    @make_test
    def test_transpose_for_scores(x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

    @make_test
    def test_return_tuple2(x):
        return (torch.add(x, x), x)

    @make_test
    def test_load_global_bool(x):
        # ``flag`` is the module-level bool.
        if flag:
            return torch.add(x, x)
        else:
            return x

    @make_test
    def test_len_tensor(x):
        z = len(x)
        return torch.add(x, z)

    @make_test
    def test_len_constant_list(x):
        z = len([1, 2, 3])
        return torch.add(x, z)

    @make_test
    def test_len_constant_dict(x):
        z = len({"foo": "bar"})
        return torch.add(x, z)

    @make_test
    def test_dict_copy(x):
        z = dict({"foo": x + 1})
        return z

    @make_test
    def test_len_constant_misc_iterables(x):
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)

    @make_test
    def test_float(x):
        y = float(1.2)
        y += float("1.2")
        return torch.add(x, y)

    @make_test
    def test_dtype(x):
        if x.dtype == torch.float32:
            return x + 1

    @make_test
    def test_get_default_dtype(x):
        if x.dtype == torch.get_default_dtype():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_device(x):
        if not x.is_cuda:
            return x + 1

    @make_test
    def test_tensor_type(a, b):
        m = a.to(torch.float16)
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type2(a, b):
        m = a.to("cuda")
        return m + b.type(m.type())

    @make_test
    def test_ndim(x):
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1

    @make_test
    def test_T(x):
        return torch.ones_like(x.T)

    @make_test
    def test_is_sparse(x):
        if not x.is_sparse:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape1(x):
        if x.shape[0] == 10:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape2(x):
        if x.size(1) == 10:
            return x + 1

    @make_test
    def test_del(a, b):
        c = a + 1
        d = c + 2
        del c, a
        return b + d

    @requires_static_shapes
    @make_test
    def test_chunks1(x):
        chunk_size = 5
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]

    @make_test
    def test_import1(x, y):
        # The function-local imports are the feature being exercised here.
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)

    @make_test
    def test_return_dict(x, y):
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}

    @make_test
    def test_return_dict2(x, y):
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp

    @make_test
    def test_funcdef_closure(x, y):
        x = x + y + 1.0

        def inner(z):
            # Rebinds both variables of the enclosing function.
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y

    @make_test
    def test_module_constant(x, y):
        r = x + y
        for i in range(torch._dynamo.testing.three):
            r = r / y
        return r

    @make_test
    def test_inline_softmax(x, y):
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)

    @make_test
    def test_dtype_compare(a, b):
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32

    @make_test
    def test_build_list_unpack(a, b):
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)

    @make_test
    def test_tensor_len(a, b):
        return a + b + len(a) + b.__len__()

    @make_test
    def test_pop(a, b):
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        # Five elements before the pops, two left for the unpack below.
        ll.pop(-1)
        ll.pop(0)
        ll.pop()
        v1, v2 = ll
        return v1 - v2

    @make_test
    def test_list_convert(a, b):
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp

    @make_test
    def test_list_add(a, b):
        l1 = (a, b)
        l2 = ()  # being a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]

    @make_test
    def test_startswith(a, b):
        x = a + b
        # The second operand depends on the name this module is imported under.
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x

    @make_test
    def test_dict_ops(a, b):
        tmp = {"a": a + 1, "b": b + 2}
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)

    def test_dict_param_keys(self):
        # Plain method: the Parameter used as a dict key is created here and
        # captured by the closure handed to make_test.
        a_param = torch.nn.Parameter(torch.ones([4, 4]))

        def fn(a):
            tmp = {"a": a, a_param: 3}
            return tmp["a"] + tmp[a_param]

        test = make_test(fn)
        test(self)

    def test_default_dict(self):
        # Plain method: the defaultdict and Parameter live outside the traced
        # function and are captured by the closure.
        dd = collections.defaultdict(dict)
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            # "b" is never assigned, so this lookup goes through the
            # defaultdict factory (dict).
            return dd["b"], dd

        test = make_test(fn)
        test(self)

    @make_test
    def test_min_max(a, b):
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c

    @make_test
    def test_map_sum(a, b, c, d):
        return sum(map(lambda x: x + 1, [a, b, c, d]))

    @make_test
    def test_reduce(a, b, c, d):
        return functools.reduce(operator.add, [a, b, c, d])

    @make_test
    def test_tuple_contains(a, b):
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b

    @make_test
    def test_tuple_iadd(a, b):
        output = (a, b)
        output += (a + b, a - b)
        return output

    @make_test
    def test_unpack_ex1(x):
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]

    @make_test
    def test_unpack_ex2(x):
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]

    @make_test
    def test_unpack_ex3(x):
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]

    @make_test
    def test_const_tuple_add1(x):
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]

    @make_test
    def test_const_tuple_add2(x):
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]

    @make_test
    def test_list_truth(a, b):
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b

    @make_test
    def test_list_reversed(a, b):
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))

    @make_test
    def test_list_clear(a, b):
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_is_quantized(a, b):
        if not a.is_quantized:
            return a + b

    @make_test
    def test_fstrings1(a, b):
        x = 1.229
        # Formats to "1.23 bar".
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @requires_static_shapes
    @make_test
    def test_fstrings2(x):
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_size(x):
        y = torch.rand(5, 8)
        z = x.new(y.size())
        assert z.size() == y.size()

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_shape(x):
        y = torch.rand(5, 8)
        z = x.new(y.shape)
        assert z.size() == y.size()

    @make_test
    def test_jit_annotate(x):
        y = torch.jit.annotate(Any, x + 1)
        return y + 2

    @requires_static_shapes
    @make_test
    def test_is_contiguous_memory_format(tensor):
        if torch.jit.is_scripting():
            return None
        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
            return tensor + 1

    @make_test
    def test_list_slice_assignment(x):
        m = [1, 2, 3, 4]
        m[1:] = [6] * (len(m) - 1)
        return x + 1

    @make_test
    def test_distributed_is_available(x):
        if torch.distributed.is_available():
            return x + 1
        else:
            return x - 1

    @unittest.skipIf(
        not torch.distributed.is_available(), "requires distributed package"
    )
    @make_test
    def test_distributed_is_initialized(x):
        if torch.distributed.is_initialized():
            return x + 1
        else:
            return x - 1

    # # This is to test the new syntax for pattern matching
    # # ("match ... case ...") added on python 3.10.
    # # Uncomment these test cases if you run on 3.10+
    # @make_test
    # def test_match_sequence(a):
    #     point = (5, 8)
    #     match point:
    #         case (0, 0):
    #             return a
    #         case (0, y):
    #             return a - y
    #         case (x, 0):
    #             return a + x
    #         case (x, y):
    #             return a + x - y

    # @make_test
    # def test_match_mapping_and_match_keys(x):
    #     param = {"a": 0.5}
    #     match param:
    #         case {"a": param}:
    #             return x * param
    #         case {"b": param}:
    #             return x / param
| |
| |
def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
):
    """Increment both tensors in place by 1 and return them as ``(x, kw_x)``.

    The tensor defaults are created once, at definition time, and are mutated
    on every call that relies on them; DefaultsTests depends on exactly that
    sharing, so the defaults must not be replaced with ``None`` sentinels.
    """
    for tensor in (x, kw_x):
        tensor.add_(1)
    return x, kw_x
| |
| |
class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
    """Module whose ``forward`` takes tensor-valued default arguments."""

    def __init__(self):
        super().__init__()

    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
        """Increment both tensors in place by 1 and return ``(x, kw_x)``.

        The defaults are single tensor objects stored on the function, so
        calls that omit the arguments keep mutating the same tensors.
        """
        for tensor in (x, kw_x):
            tensor.add_(1)
        return x, kw_x
| |
| |
class WrapperModule(torch.nn.Module):
    """Holds a ModuleWithDefaultTensorArgsMethod and calls it with no arguments."""

    def __init__(self):
        super().__init__()
        inner = ModuleWithDefaultTensorArgsMethod()
        self.m = inner

    def forward(self):
        """Call the wrapped module so that its default tensors are used."""
        wrapped = self.m
        return wrapped()
| |
| |
class DefaultsTests(torch._dynamo.test_case.TestCase):
    """Default tensor arguments must be shared by eager and compiled calls."""

    def _check_shared_default_tensors(
        self, eager_call, compiled_call, counter, defaults_owner
    ):
        """Run the checks shared by the function and the method variant.

        ``eager_call`` / ``compiled_call`` take no arguments and return
        ``(x, kw_x)``; ``counter`` is the CompileCounter backing
        ``compiled_call``; ``defaults_owner`` is the function object whose
        ``__defaults__`` / ``__kwdefaults__`` get patched.
        """
        for call_index in range(4):
            # Even iterations run eagerly, odd iterations run compiled.
            runner = compiled_call if call_index % 2 else eager_call
            x, kw_x = runner()
            # the inner func mutates += 1 each call
            self.assertTrue(same(x, torch.ones_like(x) + call_index))
            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + call_index))
        # Calling the compiled callable twice does not recompile
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)

        # But with a change to the guarded default tensor, we do recompile:
        # once for the positional default, once more for the keyword default.
        replacements = (
            ("__defaults__", (torch.ones((3, 4, 5)),)),
            ("__kwdefaults__", {"kw_x": torch.ones((3, 4, 5))}),
        )
        for expected_frames, (attr_name, new_value) in enumerate(
            replacements, start=2
        ):
            with patch.object(defaults_owner, attr_name, new_value):
                _x, _kw_x = compiled_call()
            self.assertEqual(counter.frame_count, expected_frames)
            self.assertEqual(counter.op_count, 2 * expected_frames)

    def test_func_default_tensor_args(self):
        """
        Tests that we indeed reference (and mutate) "the one" default tensor arg
        stored on the globally allocated function object, both from the orig and
        compiled function
        """

        def func():
            return global_func_with_default_tensor_args()

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_func = torch.compile(func, backend=cnts)
        self._check_shared_default_tensors(
            func, compiled_func, cnts, global_func_with_default_tensor_args
        )

    def test_meth_default_tensor_args(self):
        """
        Tests that we indeed reference (and mutate) "the one" default tensor arg
        stored on the globally allocated function object, both from the orig and
        compiled function
        """
        mod = WrapperModule()
        cnts = torch._dynamo.testing.CompileCounter()
        compiled_mod = torch.compile(mod, backend=cnts)
        self._check_shared_default_tensors(
            mod, compiled_mod, cnts, ModuleWithDefaultTensorArgsMethod.forward
        )
| |
| |
# Script entry point: hand control to the dynamo test runner.
if __name__ == "__main__":
    torch._dynamo.test_case.run_tests()