# blob: 2fdc4e81f85abd0318d2f55f9e6d84d3d810b73a [file] [log] [blame]
# Owner(s): ["module: dynamo"]
# flake8: noqa
import collections
import functools
import inspect
import itertools
import operator
import unittest
from typing import Any
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import requires_static_shapes
from torch._dynamo.utils import same
from torch.nn import functional as F
# Module-level globals used as fixtures by the test functions in this file.
# NOTE(review): tensor_for_import_testing is not referenced in the visible
# part of this file; presumably it is imported by another test module - confirm.
tensor_for_import_testing = torch.ones(10, 10)
# Global tensor read by FunctionTests.test_globalvar.
d = torch.ones(10, 10)
# Global nn.Module invoked by FunctionTests.test_globalmodule.
e = torch.nn.Linear(10, 10)
# Global bool branched on by FunctionTests.test_load_global_bool.
flag = True
def constant3(a, b):
    """Return ``a - b`` plus the constant ``1.0 + 2``.

    Called by the ``test_methodcall*`` cases with positional, keyword and
    mixed argument passing.
    """
    shift = 1.0 + 2
    return a - b + shift
def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``.

    Called without the third argument by ``test_inline_with_default`` so the
    default value is the one that takes effect.
    """
    if not some_default_arg:
        return None
    return a - b
def make_test(fn):
    """Turn a plain function into a test method taking only ``self``.

    The number of parameters ``fn`` declares is counted once, up front.  The
    returned wrapper forwards the test case instance, ``fn`` and that count
    (as ``nargs``) to ``torch._dynamo.testing.standard_test`` and returns
    whatever that call returns.
    """
    fn_params = inspect.signature(fn).parameters
    arg_count = len(fn_params)

    def test_fn(self):
        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=arg_count)

    return test_fn
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    """Return ``x + 1.2``; fixture decorated with ``script_if_tracing``."""
    shifted = x + 1.2
    return shifted
@torch.jit.ignore
def inline_ignore(x):
    """Return ``x + 3.4``; fixture decorated with ``torch.jit.ignore``."""
    shifted = x + 3.4
    return shifted
@torch.jit.unused
def inline_unused(x):
    """Return ``x + 5.6``; fixture decorated with ``torch.jit.unused``."""
    shifted = x + 5.6
    return shifted
class FunctionTests(torch._dynamo.test_case.TestCase):
    """Test cases written as small plain functions wrapped by ``make_test``.

    Functions decorated with ``@make_test`` take no ``self``: ``make_test``
    counts their parameters and hands the function to
    ``torch._dynamo.testing.standard_test``.  Methods that do take ``self``
    build their function inline (so it can close over local state) and then
    call the wrapper produced by ``make_test`` themselves.

    The bodies are deliberately tiny and each exercises one Python or torch
    construct, so their exact form is part of what is being tested.
    """

    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)
        x = inline_ignore(x)
        x = inline_unused(x)
        # Hand the computed value back; a bare ``return`` would discard it and
        # the harness would only ever see ``None``.
        return x

    @make_test
    def test_add(a, b):
        return a + b

    @make_test
    def test_add_(a, b):
        a_copy = torch.tensor(a)
        return a_copy.add_(b, alpha=5.0)

    @make_test
    def test_addcdiv(a, b, c):
        # dynamo decomposes this to avoid a graph break when
        # the value kwarg is populated
        return torch.addcdiv(a, b, c, value=5.0)

    @make_test
    def test_addcdiv_(a, b, c):
        a_copy = torch.tensor(a)
        return a_copy.addcdiv_(b, c, value=5.0)

    @make_test
    def test_is_not_null(a, b):
        if a is not None and b is not None:
            return a + b

    @make_test
    def test_constant1(a, b, c):
        return a - b * c + 1.0

    @make_test
    def test_constant2(a, b, c):
        return a - b * c + 1

    @make_test
    def test_constant3(a):
        b = 1
        c = 2
        d = 3
        return b + c - d + a

    @make_test
    def test_constant4(a, b):
        c = 2
        d = 3
        if c > d:
            return a - b
        return b - a

    @make_test
    def test_finfo(a, b):
        if torch.iinfo(torch.int32).bits == 32:
            return torch.finfo(a.dtype).min * b

    @make_test
    def test_globalfn(a, b):
        return sub(a, b)

    @make_test
    def test_viatorch(a, b):
        return torch.sub(a, b)

    @make_test
    def test_viamethod(a, b):
        return a.sub(b)

    @make_test
    def test_indirect1(a, b):
        t = a.sub
        return t(b)

    @make_test
    def test_indirect2(a, b):
        t = a.sub
        args = (b,)
        return t(*args)

    @make_test
    def test_indirect3(a, b):
        t = a.sub
        args = (b,)
        kwargs = {}
        return t(*args, **kwargs)

    @make_test
    def test_methodcall1(a, b, c):
        return constant3(a, b) * c

    @make_test
    def test_methodcall2(a, b):
        return constant3(a=b, b=a) + 1

    @make_test
    def test_methodcall3(a, b):
        return constant3(a, b=1.0) + b

    @make_test
    def test_device_constant(a):
        return a + torch.ones(1, device=torch.device("cpu"))

    @make_test
    def test_tuple1(a, b):
        args = (a, b)
        return sub(*args)

    @make_test
    def test_tuple2(a, b):
        args = [a, b]
        return sub(*args)

    @make_test
    def test_is_in_onnx_export(x, y):
        if torch.onnx.is_in_onnx_export():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_is_fx_tracing(x, y):
        if torch.fx._symbolic_trace.is_fx_tracing():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_listarg1(a, b):
        return torch.cat([a, b])

    @make_test
    def test_listarg2(a, b):
        return torch.cat((a, b), dim=0)

    @make_test
    def test_listarg3(a, b):
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)

    @make_test
    def test_listarg4(a, b):
        return torch.cat(tensors=[a, b], dim=0)

    @make_test
    def test_listarg5(a, b):
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)

    @make_test
    def test_slice1(a):
        return a[5]

    @make_test
    def test_slice2(a):
        return a[:5]

    @make_test
    def test_slice3(a):
        return a[5:]

    @make_test
    def test_slice4(a):
        return a[2:5]

    @make_test
    def test_slice5(a):
        return a[::2]

    @make_test
    def test_slice6(a):
        return torch.unsqueeze(a, 0)[:, 2:]

    @make_test
    def test_range1(a):
        return torch.tensor(range(a.size(0)))

    @make_test
    def test_range2(x, y):
        r = x + y
        for i in range(x.size(0) + 2):
            r = r / y
        return r

    @make_test
    def test_unpack1(a):
        a, b = a[:5], a[5:]
        return a - b

    @make_test
    def test_unpack2(a):
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b

    @make_test
    def test_unpack3(a):
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b

    @make_test
    def test_fn_with_self_set(a, b):
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )

    @make_test
    def test_return_tuple1(a, b):
        return (a - b, b - a, a, b)

    @make_test
    def test_globalvar(a, b):
        return a - b + d

    @make_test
    def test_globalmodule(x):
        return e(x)

    @make_test
    def test_inline_with_default(a, b, c):
        return func_with_default(a, b) * c

    @make_test
    def test_inner_function(x):
        def fn(x):
            return torch.add(x, x)

        return fn(x)

    @make_test
    def test_transpose_for_scores(x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

    @make_test
    def test_return_tuple2(x):
        return (torch.add(x, x), x)

    @make_test
    def test_load_global_bool(x):
        if flag:
            return torch.add(x, x)
        else:
            return x

    @make_test
    def test_len_tensor(x):
        z = len(x)
        return torch.add(x, z)

    @make_test
    def test_len_constant_list(x):
        z = len([1, 2, 3])
        return torch.add(x, z)

    @make_test
    def test_len_constant_dict(x):
        z = len({"foo": "bar"})
        return torch.add(x, z)

    @make_test
    def test_dict_copy(x):
        z = dict({"foo": x + 1})
        return z

    @make_test
    def test_len_constant_misc_iterables(x):
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)

    @make_test
    def test_float(x):
        y = float(1.2)
        y += float("1.2")
        return torch.add(x, y)

    @make_test
    def test_dtype(x):
        if x.dtype == torch.float32:
            return x + 1

    @make_test
    def test_get_default_dtype(x):
        if x.dtype == torch.get_default_dtype():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_device(x):
        if not x.is_cuda:
            return x + 1

    @make_test
    def test_tensor_type(a, b):
        m = a.to(torch.float16)
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type2(a, b):
        m = a.to("cuda")
        return m + b.type(m.type())

    @make_test
    def test_ndim(x):
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1

    @make_test
    def test_T(x):
        return torch.ones_like(x.T)

    @make_test
    def test_is_sparse(x):
        if not x.is_sparse:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape1(x):
        if x.shape[0] == 10:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape2(x):
        if x.size(1) == 10:
            return x + 1

    @make_test
    def test_del(a, b):
        c = a + 1
        d = c + 2
        del c, a
        return b + d

    @requires_static_shapes
    @make_test
    def test_chunks1(x):
        chunk_size = 5
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]

    @make_test
    def test_import1(x, y):
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)

    @make_test
    def test_return_dict(x, y):
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}

    @make_test
    def test_return_dict2(x, y):
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp

    @make_test
    def test_funcdef_closure(x, y):
        x = x + y + 1.0

        def inner(z):
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y

    @make_test
    def test_module_constant(x, y):
        r = x + y
        for i in range(torch._dynamo.testing.three):
            r = r / y
        return r

    @make_test
    def test_inline_softmax(x, y):
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)

    @make_test
    def test_dtype_compare(a, b):
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32

    @make_test
    def test_build_list_unpack(a, b):
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)

    @make_test
    def test_tensor_len(a, b):
        return a + b + len(a) + b.__len__()

    @make_test
    def test_pop(a, b):
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        ll.pop(-1)
        ll.pop(0)
        ll.pop()
        v1, v2 = ll
        return v1 - v2

    @make_test
    def test_list_convert(a, b):
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp

    @make_test
    def test_list_add(a, b):
        l1 = (a, b)
        l2 = ()  # being a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]

    @make_test
    def test_startswith(a, b):
        x = a + b
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x

    @make_test
    def test_dict_ops(a, b):
        tmp = {"a": a + 1, "b": b + 2}
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)

    def test_dict_param_keys(self):
        # Built inline so the traced function can close over a Parameter that
        # is used as a dict key.
        a_param = torch.nn.Parameter(torch.ones([4, 4]))

        def fn(a):
            tmp = {"a": a, a_param: 3}
            return tmp["a"] + tmp[a_param]

        test = make_test(fn)
        test(self)

    def test_default_dict(self):
        # Built inline so the traced function can close over the defaultdict
        # and the Parameter used as one of its keys.
        dd = collections.defaultdict(dict)
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            return dd["b"], dd

        test = make_test(fn)
        test(self)

    @make_test
    def test_min_max(a, b):
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c

    @make_test
    def test_map_sum(a, b, c, d):
        return sum(map(lambda x: x + 1, [a, b, c, d]))

    @make_test
    def test_reduce(a, b, c, d):
        return functools.reduce(operator.add, [a, b, c, d])

    @make_test
    def test_tuple_contains(a, b):
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b

    @make_test
    def test_tuple_iadd(a, b):
        output = (a, b)
        output += (a + b, a - b)
        return output

    @make_test
    def test_unpack_ex1(x):
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]

    @make_test
    def test_unpack_ex2(x):
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]

    @make_test
    def test_unpack_ex3(x):
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]

    @make_test
    def test_const_tuple_add1(x):
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]

    @make_test
    def test_const_tuple_add2(x):
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]

    @make_test
    def test_list_truth(a, b):
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b

    @make_test
    def test_list_reversed(a, b):
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))

    @make_test
    def test_list_clear(a, b):
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_is_quantized(a, b):
        if not a.is_quantized:
            return a + b

    @make_test
    def test_fstrings1(a, b):
        x = 1.229
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @requires_static_shapes
    @make_test
    def test_fstrings2(x):
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_size(x):
        y = torch.rand(5, 8)
        z = x.new(y.size())
        assert z.size() == y.size()

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_shape(x):
        y = torch.rand(5, 8)
        z = x.new(y.shape)
        assert z.size() == y.size()

    @make_test
    def test_jit_annotate(x):
        y = torch.jit.annotate(Any, x + 1)
        return y + 2

    @requires_static_shapes
    @make_test
    def test_is_contiguous_memory_format(tensor):
        if torch.jit.is_scripting():
            return None
        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
            return tensor + 1

    @make_test
    def test_list_slice_assignment(x):
        m = [1, 2, 3, 4]
        m[1:] = [6] * (len(m) - 1)
        return x + 1

    @make_test
    def test_distributed_is_available(x):
        if torch.distributed.is_available():
            return x + 1
        else:
            return x - 1

    @unittest.skipIf(
        not torch.distributed.is_available(), "requires distributed package"
    )
    @make_test
    def test_distributed_is_initialized(x):
        if torch.distributed.is_initialized():
            return x + 1
        else:
            return x - 1

    # # This is to test the new syntax for pattern matching
    # # ("match ... case ...") added on python 3.10.
    # # Uncomment these test cases if you run on 3.10+
    # @make_test
    # def test_match_sequence(a):
    #     point = (5, 8)
    #     match point:
    #         case (0, 0):
    #             return a
    #         case (0, y):
    #             return a - y
    #         case (x, 0):
    #             return a + x
    #         case (x, y):
    #             return a + x - y

    # @make_test
    # def test_match_mapping_and_match_keys(x):
    #     param = {"a": 0.5}
    #     match param:
    #         case {"a": param}:
    #             return x * param
    #         case {"b": param}:
    #             return x / param
def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
):
    """Add 1 in place to both arguments and return them as ``(x, kw_x)``.

    The default tensors are created once, when the function is defined, so
    calls that omit the arguments keep mutating the same two tensor objects.
    That sharing is intentional: DefaultsTests.test_func_default_tensor_args
    relies on it and also swaps out ``__defaults__`` / ``__kwdefaults__``.
    """
    x.add_(1)
    kw_x.add_(1)
    return x, kw_x
class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
    """Module whose ``forward`` has tensor-valued default arguments."""

    def __init__(self):
        super().__init__()

    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
        # The default tensors are created once at definition time, so calls
        # that omit the arguments mutate the same two objects every time.
        x.add_(1)
        kw_x.add_(1)
        return x, kw_x
class WrapperModule(torch.nn.Module):
    """Holds a ModuleWithDefaultTensorArgsMethod and calls it with no arguments."""

    def __init__(self):
        super().__init__()
        self.m = ModuleWithDefaultTensorArgsMethod()

    def forward(self):
        # No arguments are passed, so the inner forward uses its defaults.
        return self.m()
class DefaultsTests(torch._dynamo.test_case.TestCase):
    """Tests for tensor-valued default arguments under ``torch.compile``."""

    def test_func_default_tensor_args(self):
        """
        Eager and compiled calls must both read and mutate the single set of
        default tensors stored on the module-level function object, and
        replacing those defaults must trigger a recompile.
        """

        def func():
            return global_func_with_default_tensor_args()

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_func = torch.compile(func, backend=cnts)
        for i in range(4):
            # Even iterations run eagerly, odd iterations run compiled.
            call = compiled_func if i % 2 else func
            x, kw_x = call()
            # Every call adds 1 in place, so after call ``i`` the shared
            # defaults hold ``i + 1``.
            self.assertTrue(same(x, torch.ones_like(x) + i))
            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
        # The two compiled calls above share one compiled frame.
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # Swapping the guarded positional default forces a recompile.
        new_defaults = (torch.ones((3, 4, 5)),)
        with patch.object(
            global_func_with_default_tensor_args, "__defaults__", new_defaults
        ):
            compiled_func()
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

        # The same holds for the keyword-only default.
        new_kwdefaults = {"kw_x": torch.ones((3, 4, 5))}
        with patch.object(
            global_func_with_default_tensor_args, "__kwdefaults__", new_kwdefaults
        ):
            compiled_func()
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 6)

    def test_meth_default_tensor_args(self):
        """
        Same check as the function variant, but the default tensors live on
        ``ModuleWithDefaultTensorArgsMethod.forward`` and are reached through
        ``WrapperModule``.
        """
        mod = WrapperModule()
        cnts = torch._dynamo.testing.CompileCounter()
        compiled_mod = torch.compile(mod, backend=cnts)
        for i in range(4):
            # Even iterations run eagerly, odd iterations run compiled.
            call = compiled_mod if i % 2 else mod
            x, kw_x = call()
            # Every call adds 1 in place, so after call ``i`` the shared
            # defaults hold ``i + 1``.
            self.assertTrue(same(x, torch.ones_like(x) + i))
            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
        # The two compiled calls above share one compiled frame.
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # Swapping the guarded positional default forces a recompile.
        new_defaults = (torch.ones((3, 4, 5)),)
        with patch.object(
            ModuleWithDefaultTensorArgsMethod.forward, "__defaults__", new_defaults
        ):
            compiled_mod()
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

        # The same holds for the keyword-only default.
        new_kwdefaults = {"kw_x": torch.ones((3, 4, 5))}
        with patch.object(
            ModuleWithDefaultTensorArgsMethod.forward, "__kwdefaults__", new_kwdefaults
        ):
            compiled_mod()
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 6)
# Script entry point: run this module's tests when it is executed directly.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()