blob: 0f0e0ae8ddbac7312792dc215f3ab3fa9ba495a1 [file] [log] [blame]
# Owner(s): ["module: functorch"]
import contextlib
import functools
import unittest
import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI,
FunctionalTensor,
FunctionalTensorMode,
PythonFunctionalizeAPI,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
decorateIf,
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
run_tests,
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
TestCase,
xfailIfTorchDynamo,
)
# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
if isinstance(t, torch.Tensor):
return FunctionalTensor.to_functional(t)
return t
def from_fun(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t.elem)
def to_fun_old(t):
if isinstance(t, torch.Tensor) and not torch._is_functional_tensor(t):
out = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, out)
return out
return t
def from_fun_old(t):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert torch._is_functional_tensor(t)
torch._sync(t)
return torch._from_functional_tensor(t)
return t
def _fake_map(f, x, *args):
from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree
x_pytrees = _unstack_pytree(x)
zs = []
for xp in x_pytrees:
zs.append(f(xp, *args))
return _stack_pytree(zs)
def _fake_while_loop(cond_fn, body_fn, operands):
while cond_fn(*operands):
operands = body_fn(*operands)
return operands
def _fake_associative_scan(combine_fn, input, dim, reverse=False):
inp_leaves, spec = pytree.tree_flatten(input)
result_flat = []
num_leaves = len(inp_leaves)
op = reversed if reverse else lambda x: x
for ind in op(range(inp_leaves[0].size(dim))):
r = [
inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
for leave_ind in range(num_leaves)
]
if (ind > 0 and not reverse) or (
ind < (inp_leaves[0].size(dim) - 1) and reverse
):
r = combine_fn(
pytree.tree_unflatten(result_flat[-1], spec),
pytree.tree_unflatten(r, spec),
)
r_flat, _ = pytree.tree_flatten(r)
result_flat.append(r_flat)
results = [
torch.stack([e[leave_ind] for e in op(result_flat)], dim)
for leave_ind in range(num_leaves)
]
return pytree.tree_unflatten(results, spec)
def _while_loop_tests():
def simple(x):
def cond_fn(x):
return x.sum() < 10
def body_fn(x):
return (x + 1,)
return while_loop(cond_fn, body_fn, (x,))
def simple_with_mutation(x):
def cond_fn(x):
y = x.clone().add_(1).add_(-1)
return y.sum() < 10
def body_fn(x):
y = x.clone().add_(1).add_(-1)
return (y + 1,)
return while_loop(cond_fn, body_fn, (x,))
def nested(out_iter, it, y):
def cond_fn(out_iter, it, y):
return it.sum() < 10
def body_fn(out_iter, it, y):
return (out_iter.clone(), it + y, y + 1)
def outer_cond_fn(out_iter, it, y):
return out_iter.sum() < 2
def outer_body_fn(out_iter, it, y):
out_iter, it, y = while_loop(cond_fn, body_fn, (out_iter, it, y))
return (out_iter + 1, it, y)
return while_loop(outer_cond_fn, outer_body_fn, (out_iter, it, y))
class Nested(torch.nn.Module):
def forward(self, ci, cj, a, b):
def cond_fn(i1, j1, x1, y1):
return i1 > 0
def body_fn(i1, j1, x1, y1):
def cond_fn_nested(i2, j2, x2, y2):
return j2 > 0
def body_fn_nested(i2, j2, x2, y2):
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
i1, j1, x1, y1 = while_loop(
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
)
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
return while_loop(cond_fn, body_fn, (ci, cj, a, b))
class SimpleWithLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.dec = torch.nn.Buffer(torch.tensor(1))
def forward(self, iter, x):
def cond_fn(it, x):
return it - self.dec > 0
def body_fn(it, x):
return it - 1, self.linear(x)
return while_loop(cond_fn, body_fn, (iter, x))
class NestedWithLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod = SimpleWithLinear()
self.outer_linear = torch.nn.Linear(2, 2)
self.dec = torch.nn.Buffer(torch.tensor(1))
def forward(self, iter, x):
def cond_fn(it, x):
return it - self.dec > 0
def body_fn(it, x):
return it - 1, self.outer_linear(self.mod(it, x)[1])
return while_loop(cond_fn, body_fn, (iter, x))
nested2 = Nested()
simple_with_linear = SimpleWithLinear()
nested_with_linear = NestedWithLinear()
x = torch.zeros(1)
y = torch.zeros(1)
z = torch.zeros(1)
return {
"simple": (simple, (x,)),
"nested": (nested, (x, y, z)),
"nested2": (
nested2,
(torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2)),
),
"simple_with_mutation": (simple_with_mutation, (x,)),
"simple_with_linear": (
simple_with_linear,
(torch.tensor(3), torch.randn(2, 2)),
),
"nested_with_linear": (
nested_with_linear,
(torch.tensor(3), torch.randn(2, 2)),
),
}
WHILE_LOOP_TESTS = _while_loop_tests()
def collect_meta_for_filtered_nodes(
gm: torch.fx.GraphModule, node_names, meta_field_name
):
ret = []
for mod in gm.modules():
for node in mod.graph.nodes:
if node.name in node_names:
for field_name in meta_field_name:
ret.append(node.meta.get(field_name))
return ret
def reduce_func(*operands):
acc = 0
for operand in operands:
acc += operand
return acc
class ReduceObj:
def __call__(self, *operands):
return reduce_func(*operands)
class ReduceMod(torch.nn.Module):
def _reduce(self, *operands):
return reduce_func(*operands)
def forward(self, *operands):
return self._reduce(*operands)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfNoDynamoSupport
class TestControlFlow(TestCase):
def setUp(self):
torch._dynamo.reset()
super().setUp()
def test_cond_no_trace(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4)
result = cond(False, true_fn, false_fn, [x])
self.assertEqual(result, torch.cos(x))
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_cond_gpu(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4, device="cuda")
pred = torch.tensor(False, device="cuda")
result = cond(pred, true_fn, false_fn, [x])
self.assertEqual(result, torch.cos(x))
def test_cond_autograd_simple(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_complex(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
return (x + 42).cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_nested(self):
class Nested(torch.nn.Module):
def forward(self, p0, p1, p2, a, b, c):
def true_fn(x0, y0, z0):
def true_true_fn(x1, y1, z1):
return (x1 - y1 * z1) * 3.14
def true_false_fn(x1, y1, z1):
def true_false_true_fn(x2, y2, z2):
return (x2 * y2 * z2) / 2.71
def true_false_false_fn(x2, y2, z2):
return (x2 + y2 + z2) * 1.23
return torch.cond(
p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
)
return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])
def false_fn(x0, y0, z0):
def false_true_fn(x1, y1, z1):
def false_true_true_fn(x2, y2, z2):
return (x2 - y2 - z2) + 1.23
def false_true_false_fn(x2, y2, z2):
return (x2 / y2 / z2) - 3.14
return torch.cond(
p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
)
def false_false_fn(x1, y1, z1):
return (x1 - y1 * z1) / 2.71
return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])
return torch.cond(p0, true_fn, false_fn, [a, b, c])
nn_module = Nested()
def true_fn(x):
return nn_module(
torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
)
def false_fn(x):
return nn_module(
torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
)
x = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_mixed_require_grad(self):
def true_fn(x, y, z):
return x * y * z
def false_fn(x, y, z):
return x + y + z
x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=False)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x, y, x))
self.assertEqual(result, fn(x, y, x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x, y, z):
result = cond(pred, true_fn, false_fn, (x, y, z))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x, y, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1, y_1, z_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_grad_through_cond(self):
nn_module = torch.nn.Linear(4, 4)
def true_fn(x):
return nn_module(x)
def false_fn(X):
return x * nn_module(x)
x = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
expected_grads = torch.autograd.grad(
fn(
x,
),
(nn_module.weight,),
grad_out,
)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (nn_module.weight,), grad_out)
# need to set _allow_non_fake_inputs = True because model parameters don't
# get fakified.
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_param_constant0 = self._param_constant0
_param_constant1 = self._param_constant1
_tensor_constant0 = self._tensor_constant0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_tensor_constant0_1 = self._tensor_constant0
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
getitem_1 = cond_1[0]; getitem_1 = None
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
return (getitem_2,)""", # noqa: B950
)
def test_cond_in_forloop(self):
def for_loop_fake(x):
for i in range(3):
x = x * x + 1
return x
def for_loop_test(x):
for i in range(3):
pred = i < 3
def true_fn(x):
return x * x + 1
def false_fn(x):
return x
x = cond(pred, true_fn, false_fn, (x,))
return x
x = torch.ones(4, requires_grad=True)
x_new = for_loop_test(x)
x_exp = for_loop_fake(x)
self.assertEqual(x_new, x_exp)
grad_out = torch.ones_like(x_new)
grads = torch.autograd.grad(x_new, (x,), grad_out)
expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(x):
x_new = for_loop_test(x)
grad_out = torch.ones_like(x_new)
return torch.autograd.grad(x_new, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
mul = torch.ops.aten.mul.Tensor(x_1, x_1)
add = torch.ops.aten.add.Tensor(mul, 1); mul = None
mul_1 = torch.ops.aten.mul.Tensor(add, add)
add_1 = torch.ops.aten.add.Tensor(mul_1, 1); mul_1 = None
mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
add_2 = torch.ops.aten.add.Tensor(mul_2, 1); mul_2 = None
ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False); add_2 = None
mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1); ones_like = add_1 = None
add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3); mul_4 = mul_3 = None
mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
mul_6 = torch.ops.aten.mul.Tensor(add_3, add); add_3 = add = None
add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5); mul_6 = mul_5 = None
mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None
return (add_5,)""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_pytree_not_all_inputs_used(self):
def true_fn(x):
return x["t"][0] + x["t"][1]["b"]
def false_fn(x):
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
grad_out = torch.ones_like(result)
if pred:
with self.assertRaisesRegex(Exception, r"."):
grads = torch.autograd.grad(result, (a, b, c), grad_out)
expected_grads = torch.autograd.grad(
fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
)
self.assertEqual(expected_grads, grads)
def f(pred, a, b, c):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (a, b), grad_out)
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
pred, a, b, c
)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, a_1, b_1, c_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
# Forward
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
# Backward
self.assertExpectedInline(
gm.true_graph_1.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None
clone = torch.ops.aten.clone.default(arg0_1)
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
return [clone, clone_1, None]""",
)
def test_cond_autograd_pytree_input(self):
def true_fn(x):
return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]
def false_fn(x):
return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (a, b), grad_out)
expected_grads = torch.autograd.grad(
fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
)
self.assertEqual(expected_grads, grads)
def f(pred):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (a, b), grad_out)
# need to set _allow_non_fake_inputs = True because model parameters don't
# get fakified.
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
_tensor_constant2 = self._tensor_constant2
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_tensor_constant0_1 = self._tensor_constant0
_tensor_constant1_1 = self._tensor_constant1
_tensor_constant2_1 = self._tensor_constant2
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
return (getitem_1, getitem_2)""", # noqa: B950
)
def test_cond_autograd_different_pytree_output(self):
def true_fn(x):
return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]
def false_fn(x):
return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_same_pytree_output(self):
def true_fn(x):
return {"res": [x["t"][0], (x["t"][2][0],)]}
def false_fn(x):
return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
result_exp = fn({"t": [a, {"b": b}, (c,)]})
self.assertEqual(result, result_exp)
result_flat, _ = pytree.tree_flatten(result)
result_exp_flat, _ = pytree.tree_flatten(result_exp)
grad_out = [torch.ones_like(g) for g in result_flat]
expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
grads = torch.autograd.grad(result_flat, (c,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred):
result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
return result
gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
_tensor_constant2 = self._tensor_constant2
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
getitem = cond[0]
getitem_1 = cond[1]; cond = None
view = torch.ops.aten.view.default(getitem, [4]); getitem = None
view_1 = torch.ops.aten.view.default(getitem_1, [4]); getitem_1 = None
return {'res': [view, (view_1,)]}""", # noqa: B950
)
@skipIfTorchDynamo("Skip due to graph break when run with dynamo")
def test_cond_autograd_torch_nn_module(self):
nn_module_true = torch.nn.Linear(4, 4)
def true_fn(x):
return nn_module_true(torch.abs((x**2).sin()))
nn_module_false = torch.nn.GRUCell(4, 4)
def false_fn(x):
return nn_module_false((x + 42).cos())
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_param_constant0 = self._param_constant0
_param_constant1 = self._param_constant1
_param_constant2 = self._param_constant2
_param_constant3 = self._param_constant3
_param_constant4 = self._param_constant4
_param_constant5 = self._param_constant5
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5)); true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_param_constant2_1 = self._param_constant2
_param_constant3_1 = self._param_constant3
_param_constant4_1 = self._param_constant4
_param_constant5_1 = self._param_constant5
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; getitem_2 = None
getitem_3 = cond_1[2]; getitem_3 = None
getitem_4 = cond_1[3]; getitem_4 = None
getitem_5 = cond_1[4]; getitem_5 = None
getitem_6 = cond_1[5]; getitem_6 = None
getitem_7 = cond_1[6]; cond_1 = getitem_7 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_user_nn_module(self):
class User_nn_module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, input):
return input * input
nn_module_true = User_nn_module()
def true_fn(x):
return nn_module_true(torch.abs((x**2).sin()))
nn_module_false = torch.nn.ReLU(inplace=False)
def false_fn(x):
return nn_module_false((x + 42).cos())
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_inner_fn(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
def inner_fn(x):
return x**2
return torch.abs(inner_fn(x).sin())
x = torch.randn(4, requires_grad=True)
pred = torch.tensor(False)
fn = false_fn
result_false = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result_false, fn(x))
grad_out = torch.ones_like(result_false)
grads_false = torch.autograd.grad(result_false, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads_false)
pred = torch.tensor(True)
fn = true_fn
result_true = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result_true, fn(x))
self.assertEqual(result_false, result_true)
grad_out = torch.ones_like(result_true)
grads_true = torch.autograd.grad(result_true, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads_true)
self.assertEqual(grads_false, grads_true)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f)(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
def test_cond_autograd_inner_tensor(self):
def true_fn(x):
return torch.abs((x**2).sin())
def false_fn(x):
y = torch.ones(4, requires_grad=False) * 42
return (x * y).cos()
for pred, fn in zip(
[torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
):
x = torch.randn(4, requires_grad=True)
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
gm = make_fx(f, tracing_mode="symbolic")(pred, x)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_cond_autograd_gpu(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
for pred, fn in zip(
[torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
[false_fn, true_fn],
):
x = torch.randn(4, requires_grad=True, device="cuda")
result = cond(pred, true_fn, false_fn, (x,))
self.assertEqual(result, fn(x))
grad_out = torch.ones_like(result)
grads = torch.autograd.grad(result, (x,), grad_out)
expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
self.assertEqual(expected_grads, grads)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_map_gpu(self):
def f(x, y):
return x + y
xs = torch.ones(3, 2, 2, device="cuda")
y = torch.ones(2, device="cuda")
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(expected, res)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
def test_while_loop_gpu(self):
def cond_fn(x):
return x.sum() < 10
def body_fn(x):
return (x + 1,)
x = torch.zeros(1, device="cuda")
res = while_loop(cond_fn, body_fn, (x,))
expected = _fake_while_loop(cond_fn, body_fn, (x,))
self.assertEqual(expected, res)
def test_map_illegal_inputs(self):
def f(x, y):
return x[0] + x[1] + y
with self.assertRaisesRegex(
RuntimeError,
r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\.",
):
_ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))
with self.assertRaisesRegex(
RuntimeError, r"Leading dimensions of mapped xs cannot be 0\."
):
_ = control_flow.map(
f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2)
)
with self.assertRaisesRegex(
RuntimeError,
r"Leading dimensions of mapped xs must be consistent\. "
r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\.",
):
_ = control_flow.map(
f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5)
)
def test_map_illegal_outputs(self):
def f(x, y):
return x.item()
def f1(x, y):
return y.size()
def f2(x, y):
return None
x = torch.ones([3])
y = torch.ones([1, 2, 3])
with self.assertRaisesRegex(
RuntimeError, r"Expect outputs of map only contains tensors or None\."
):
_ = control_flow.map(f, x, y)
with self.assertRaisesRegex(
RuntimeError, r"Expect outputs of map only contains tensors or None\."
):
out = control_flow.map(f1, x, y)
# return None is OK
_ = control_flow.map(f2, x, y)
def test_map_list_in_out(self):
def f(x, y):
return [[x[0][0] + y]]
xs = [[torch.ones(3, 2, 2)]]
y = torch.ones(2)
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), 1)
self.assertEqual(expected, res)
def test_map_dict_in_out(self):
def f(x, y):
return {"c": x["a"]["b"] + y}
xs = {"a": {"b": torch.ones(3, 2, 2)}}
y = torch.ones(2)
res = control_flow.map(f, xs, y)
expected = _fake_map(f, xs, y)
self.assertEqual(len(res), 1)
self.assertTrue("c" in res)
self.assertEqual(expected, res)
def test_map_autograd_simple(self):
def f(x, y):
return x.sin().cos() * y.cos().sin()
xs = torch.ones(3, 2, 2, requires_grad=True)
y = torch.ones(2, requires_grad=True)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res)
grads = torch.autograd.grad(res, (xs, y), grad_out)
expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_simple_partial_grad(self):
def f(x, y):
return x.sin().cos() * y.cos().sin()
xs = torch.ones(3, 2, 2, requires_grad=True)
# Disable the gradient computation for y
y = torch.ones(2, requires_grad=False)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res)
grads = torch.autograd.grad(res, (xs,), grad_out)
expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_no_grad_output(self):
def f(x, y):
return x[0].sin().cos() + y, y.cos().sin()
xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
# Disable the gradient computation for y
y = torch.ones(2, requires_grad=False)
res = control_flow.map(f, xs, y)
expected_res = _fake_map(f, xs, y)
grad_out = torch.ones_like(res[0])
grads = torch.autograd.grad(res[0], (xs[0],), grad_out)
expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out)
self.assertEqual(expected_res, res)
self.assertEqual(expected_grads, grads)
def test_map_autograd_nested_list(self):
import torch.utils._pytree as pytree
def f(x, y):
a, b = x
c, d = a
return [[b.sin() * c.cos()], d.sin() * y.cos()]
def fwbw(map_op, f, x, y):
z = map_op(f, x, y)
flat_x = pytree.tree_leaves(x)
flat_z = pytree.tree_leaves(z)
grads = torch.autograd.grad(
flat_z, flat_x, [torch.ones_like(z) for z in flat_z]
)
return z, grads
x = [
[
torch.randn(3, 2, 2, requires_grad=True),
torch.randn(3, 2, 1, requires_grad=True),
],
torch.ones(3, 1, 2, requires_grad=True),
]
y = torch.ones(1, requires_grad=True)
true_outs = fwbw(control_flow.map, f, x, y)
fake_outs = fwbw(_fake_map, f, x, y)
self.assertEqual(true_outs, fake_outs)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_simple(self, reverse, combine_mode, device):
def add(x: torch.Tensor, y: torch.Tensor):
return x + y
def mul(x: torch.Tensor, y: torch.Tensor):
return x * y
x = torch.randn(3, 10, 2, device=device)
for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
result = associative_scan(
op, x, 0, reverse=reverse, combine_mode=combine_mode
)
result_exp = _fake_associative_scan(op, x, 0, reverse=reverse)
self.assertEqual(result, result_exp)
# Jax Examples
x = torch.arange(0, 4, device=device)
cumsum1 = associative_scan(
add, x, 0, reverse=reverse, combine_mode=combine_mode
)
cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse)
if not reverse:
self.assertEqual(
cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
)
else:
self.assertEqual(
cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
)
self.assertEqual(cumsum1, cumsum_exp)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_dim(self, reverse, combine_mode, device):
import random
def add(x: torch.Tensor, y: torch.Tensor):
return x + y
def mul(x: torch.Tensor, y: torch.Tensor):
return x * y
num_dims = [random.randint(2, 5) for _ in range(10)]
for num_dim in num_dims:
shapes = [random.randint(1, 10) for _ in range(num_dim)]
rnd_scan_dim = random.randint(0, num_dim - 1)
x = torch.randn(*shapes, device=device)
for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
result = associative_scan(
op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode
)
result_exp = _fake_associative_scan(
op, x, rnd_scan_dim, reverse=reverse
)
self.assertEqual(result, result_exp)
if not reverse:
result_exp_PT = op_pt(x, rnd_scan_dim)
self.assertEqual(result, result_exp_PT)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("compile_mode", ["compile", "compile_dynamic_shape"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_compile(
self, reverse, combine_mode, compile_mode, device
):
def add(x: torch.Tensor, y: torch.Tensor):
return x + y
def mul(x: torch.Tensor, y: torch.Tensor):
return x * y
x = torch.randn(3, 10, 2, device=device)
torch.compiler.reset()
if compile_mode == "compile":
associative_scan_fct = torch.compile(
associative_scan, fullgraph=True, dynamic=False
)
else:
associative_scan_fct = torch.compile(
associative_scan, fullgraph=True, dynamic=True
)
for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
result = associative_scan_fct(
op, x, 0, reverse=reverse, combine_mode=combine_mode
)
result_exp = _fake_associative_scan(op, x, 0, reverse=reverse)
self.assertEqual(result, result_exp)
if not reverse:
result_exp_PT = op_pt(x, 0)
self.assertEqual(result, result_exp_PT)
# Jax Examples
x = torch.arange(0, 4, device=device)
cumsum1 = associative_scan(
add, x, 0, reverse=reverse, combine_mode=combine_mode
)
cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse)
if not reverse:
self.assertEqual(
cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
)
else:
self.assertEqual(
cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
)
self.assertEqual(cumsum1, cumsum_exp)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_binary_operator(
self, reverse, combine_mode, device
):
def fct(x, y):
A_i, Bu_i = x
A_j, Bu_j = y
return A_j * A_i, A_j * Bu_i + Bu_j
torch.compiler.reset()
associative_scan1 = torch.compile(associative_scan, fullgraph=True)
associative_scan2 = associative_scan
state_dim = 20
timesteps = 10
projected_inputs = torch.randn(
timesteps, state_dim, requires_grad=True, device=device
)
A = torch.randn(state_dim, requires_grad=True, device=device)
elements = (A.repeat((timesteps, 1)), projected_inputs)
result1 = associative_scan1(
fct, elements, 0, combine_mode=combine_mode, reverse=reverse
)
result2 = associative_scan2(
fct, elements, 0, combine_mode=combine_mode, reverse=reverse
)
expected_result = _fake_associative_scan(fct, elements, 0, reverse=reverse)
self.assertEqual(
result1,
expected_result,
)
self.assertEqual([r.device.type for r in result1], [device.type] * len(result1))
self.assertEqual(
result2,
expected_result,
)
self.assertEqual([r.device.type for r in result2], [device.type] * len(result2))
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_tuple(self, reverse, combine_mode, device):
def fct(x, y):
return (x[0] + y[0], x[1] * y[1])
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
inp = (x, y)
result1 = associative_scan(
fct, inp, 0, reverse=reverse, combine_mode=combine_mode
)
expected_result = _fake_associative_scan(fct, inp, 0, reverse=reverse)
self.assertEqual(result1, expected_result)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
# Skipping the combination of combine_mode=pointwise and device=cpu
# as the current implementation of pointwise does only support CUDA device
@decorateIf(
unittest.skip,
lambda params: (
params["combine_mode"] == "pointwise"
and (params["device"] == torch.device("cpu") or torch.version.hip)
),
)
def test_pointwise_associative_scan_complex_pytree(
self, reverse, combine_mode, device
):
def fct_wrong_pytree(x, y):
return {
"i": x["i"] * y["j"][0][0],
"k": 0.0,
"j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
}
def fct_pointwise(x, y):
return {
"i": x["i"] * y["i"],
"j": (
[x["j"][0][0] * y["j"][0][0]],
[{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
),
}
x = torch.randn(3, 2, 2, device=device, requires_grad=True)
y = torch.randn(3, 2, 2, device=device, requires_grad=True)
z = torch.randn(3, 2, 2, device=device, requires_grad=True)
inp = {"i": x, "j": ([y], [{"o": z}])}
with self.assertRaisesRegex(Exception, r"."):
result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic")
torch.compiler.reset()
associative_scan1 = torch.compile(associative_scan, fullgraph=True)
associative_scan2 = associative_scan
result1 = associative_scan1(
fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse
)
result2 = associative_scan2(
fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse
)
expected_result = _fake_associative_scan(fct_pointwise, inp, 0, reverse=reverse)
self.assertEqual(result1, expected_result)
self.assertEqual(result2, expected_result)
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("device", [torch.device("cpu"), torch.device("cuda")])
def test_generic_associative_scan_generic_simple(self, reverse, device):
def non_pointwise(x: torch.Tensor, y: torch.Tensor):
W = torch.diag(torch.ones(2, device=device))
return x @ W + y @ W
x = torch.randn(3, 10, 2, device=device)
with self.assertRaisesRegex(Exception, ".*"):
out = associative_scan(
non_pointwise, x, 0, reverse=reverse, combine_mode="pointwise"
)
result1 = associative_scan(
non_pointwise, x, 0, reverse=reverse, combine_mode="generic"
)
result_expected = _fake_associative_scan(non_pointwise, x, 0, reverse=reverse)
self.assertEqual(result1, result_expected)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfNoDynamoSupport
class TestControlFlowTraced(TestCase):
def setUp(self):
torch._dynamo.reset()
super().setUp()
def _check_tracing(self, fn, args, allow_non_fake_inputs=False):
graphs = {}
eager_res = fn(*args)
for tracing_mode in ["symbolic", "real", "fake"]:
graph = make_fx(
fn,
tracing_mode=tracing_mode,
_allow_non_fake_inputs=allow_non_fake_inputs,
)(*args)
graphs[tracing_mode] = graph
self.assertEqual(graph(*args), eager_res)
return graphs
def _check_compile(self, fn, args, *, backend="eager"):
eager_res = fn(*args)
compiled_fn = torch.compile(fn, backend=backend)
self.assertEqual(compiled_fn(*args), eager_res)
def test_cond_traced_not_nested(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False))
result_true = graph.forward(x, torch.tensor(True))
result_false = graph.forward(x, torch.tensor(False))
self.assertFalse(torch.allclose(result_true, result_false))
self.assertEqual(result_true, torch.sin(x))
self.assertEqual(result_false, torch.cos(x))
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
def test_cond_simple_with_linear_compile_check_graph(self):
from torch._dynamo.testing import EagerAndRecordGraphs
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4, requires_grad=True)
def f(pred, x):
result = cond(pred, true_fn, false_fn, (x,))
grad_out = torch.ones_like(result)
return torch.autograd.grad(result, (x,), grad_out)
backend = EagerAndRecordGraphs()
torch.compile(f, backend=backend)(torch.tensor(False), x)
self.assertEqual(len(backend.graphs), 2)
gm = backend.graphs[0]
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
l_pred_ = L_pred_
l_x_ = L_x_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_x_]); l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
result = cond[0]; cond = None
grad_out = torch.ones_like(result)
return (result, grad_out)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_true_0.code.strip(),
"""\
def forward(self, l_x_):
l_x__1 = l_x_
sin = l_x__1.sin(); l_x__1 = None
return (sin,)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_false_0.code.strip(),
"""\
def forward(self, l_x_):
l_x__1 = l_x_
cos = l_x__1.cos(); l_x__1 = None
return (cos,)""", # noqa: B950
)
backward_gm = backend.graphs[1]
self.assertExpectedInline(
backward_gm.code.strip(),
"""\
def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
l_ctx_pred = L_ctx_pred
l_flat_grads_0_ = L_flat_grads_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
def test_while_loop_nested_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested"]
graphs = self._check_tracing(fn, inp)
self.assertExpectedInline(
graphs["symbolic"].code.strip("\n"),
"""\
def forward(self, out_iter_1, it_1, y_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (out_iter_1, it_1, y_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = out_iter_1 = it_1 = y_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]; while_loop = None
return (getitem, getitem_1, getitem_2)
""", # noqa: B950
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
lt = torch.ops.aten.lt.Scalar(sum_1, 2); sum_1 = None
return lt
""",
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]; while_loop = None
add = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
return (add, getitem_1, getitem_2)
""", # noqa: B950
)
def _wrap_with_functionalize(self, fn, func_type):
mode = None
if func_type == "cpp":
fn = CppFunctionalizeAPI().functionalize(fn)
elif func_type == "python":
fn = PythonFunctionalizeAPI().functionalize(fn)
mode = FunctionalTensorMode()
elif func_type == "functorch":
fn = torch.func.functionalize(fn)
else:
assert func_type == "no"
return fn, mode
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
def test_while_loop_simple_functionalize_check_graph(self, func_type):
fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
fn, mode = self._wrap_with_functionalize(fn, func_type)
mode = mode if mode is not None else contextlib.nullcontext()
with mode:
graphs = self._check_tracing(fn, inp)
if func_type == "no":
self.assertExpectedInline(
graphs["symbolic"].code.strip("\n"),
"""\
def forward(self, x_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
getitem = while_loop[0]; while_loop = None
return (getitem,)
""", # noqa: B950
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
sum_1 = torch.ops.aten.sum.default(add__1); add__1 = None
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
return lt
""",
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add_ = torch.ops.aten.add_.Tensor(clone, 1); clone = None
add__1 = torch.ops.aten.add_.Tensor(add_, -1); add_ = None
add = torch.ops.aten.add.Tensor(add__1, 1); add__1 = None
return (add,)
""",
)
elif func_type == "python":
self.assertExpectedInline(
graphs["symbolic"].code.strip("\n"),
"""\
def forward(self, arg0_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None
getitem = while_loop[0]; while_loop = None
return (getitem,)
""", # noqa: B950
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
return lt
""",
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
return (add_2,)
""",
)
else:
self.assertExpectedInline(
graphs["symbolic"].code.strip("\n"),
"""\
def forward(self, x_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
getitem = while_loop[0]; while_loop = None
return (getitem,)
""", # noqa: B950
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
sum_1 = torch.ops.aten.sum.default(add_1); add_1 = None
lt = torch.ops.aten.lt.Scalar(sum_1, 10); sum_1 = None
return lt
""",
)
self.assertExpectedInline(
graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
"""\
def forward(self, arg0_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(clone, 1); clone = None
add_1 = torch.ops.aten.add.Tensor(add, -1); add = None
add_2 = torch.ops.aten.add.Tensor(add_1, 1); add_1 = None
return (add_2,)
""",
)
@parametrize("func_type", ["no", "cpp", "python", "functorch"])
@parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
def test_while_loop_functionalize(self, func_type, while_loop_test):
# simple_with_linear doesn't work becaue parameters and buffers
# are not inputs so they're not wrapped by functionalization and tracing.
if while_loop_test not in ("simple_with_linear", "nested_with_linear"):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
fn, mode = self._wrap_with_functionalize(fn, func_type)
mode = mode if mode is not None else contextlib.nullcontext()
with mode:
self._check_tracing(fn, inp)
@parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
def test_while_loop_tracing(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
allow_non_fake_inputs = (
False
if while_loop_test not in ("simple_with_linear", "nested_with_linear")
else True
)
self._check_tracing(fn, inp, allow_non_fake_inputs)
@parametrize("backend", ["eager", "aot_eager"])
@parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
def test_while_loop_compile(self, backend, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
self._check_compile(fn, inp, backend=backend)
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
def test_while_loop_simple_with_linear_compile_check_graph(self):
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
from torch._dynamo.testing import EagerAndRecordGraphs
backend = EagerAndRecordGraphs()
torch.compile(fn, backend=backend)(*inp)
self.assertEqual(len(backend.graphs), 1)
gm = backend.graphs[0]
if torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
l_iter_ = L_iter_
l_x_ = L_x_
l_self_buffers_dec_ = L_self_buffers_dec_
l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
child = l_iter_ - 1; l_iter_ = None
child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
return (child, child_1)""", # noqa: B950
)
else:
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
l_iter_ = L_iter_
l_x_ = L_x_
l__self___dec = self.L__self___dec
l__self___linear_weight = self.L__self___linear_weight
l__self___linear_bias = self.L__self___linear_bias
cond_fn_0 = self.cond_fn_0
body_fn_0 = self.body_fn_0
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
getitem = while_loop[0]
getitem_1 = while_loop[1]; while_loop = None
return (getitem, getitem_1)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None
gt = sub > 0; sub = None
return gt""", # noqa: B950
)
self.assertExpectedInline(
gm.body_fn_0.code.strip(),
"""\
def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
child = l_iter_ - 1; l_iter_ = None
child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
return (child, child_1)""", # noqa: B950
)
def test_while_loop_nested2_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested2"]
graphs = self._check_tracing(fn, inp)
gm = graphs["symbolic"]
outer_body = gm.while_loop_body_graph_0
outer_cond = gm.while_loop_cond_graph_0
inner_body = outer_body.while_loop_body_graph_0
inner_cond = outer_body.while_loop_cond_graph_0
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
return (getitem, getitem_1, getitem_2, getitem_3)
""", # noqa: B950
)
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
return (sub, clone, mul, div)
""", # noqa: B950
)
self.assertExpectedInline(
outer_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
sub = torch.ops.aten.sub.Tensor(getitem, 1); getitem = None
clone = torch.ops.aten.clone.default(getitem_1); getitem_1 = None
mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None
div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None
return (sub, clone, mul, div)
""", # noqa: B950
)
self.assertExpectedInline(
inner_body.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
add = torch.ops.aten.add.Tensor(arg2_1, 3.14); arg2_1 = None
sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71); arg3_1 = None
return (clone, sub, add, sub_1)
""",
)
self.assertExpectedInline(
inner_cond.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
gt = torch.ops.aten.gt.Scalar(arg1_1, 0); arg1_1 = None
return gt
""",
)
def test_cond_nested_traced(self):
def true_nested(y):
return y * y
def false_nested(y):
return y + y
def true_fn(x, pred2):
z = cond(pred2, true_nested, false_nested, [x])
return x + z
def false_fn(x, _):
return x.cos()
def f(x, pred, pred2):
return cond(pred, true_fn, false_fn, [x, pred2])
x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
result_true_true = graph.forward(
x, torch.tensor(True), torch.tensor(True)
) # True + True -> x * x
result_true_false = graph.forward(
x, torch.tensor(True), torch.tensor(False)
) # True + True -> x + x
result_false_true = graph.forward(
x, torch.tensor(False), torch.tensor(True)
) # False + either -> cos
result_false_false = graph.forward(
x, torch.tensor(False), torch.tensor(False)
) # False + either -> cos
self.assertNotEqual(result_true_true, result_true_false)
self.assertFalse(torch.allclose(result_false_true, result_true_true))
self.assertEqual(result_false_true, result_false_false)
self.assertEqual(result_true_true, (x * x) + x)
self.assertEqual(result_true_false, x + x + x)
self.assertEqual(result_false_true, torch.cos(x))
graph = make_fx(f, tracing_mode="symbolic")(
x, torch.tensor(False), torch.tensor(False)
)
self.assertEqual(
graph(x, torch.tensor(True), torch.tensor(True)),
f(x, torch.tensor(True), torch.tensor(True)),
)
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
y.add_(4)
return x.sin().max() + y.sum()
def false_fn(x):
return x.cos().min()
def f(x):
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
all_ops_in_true_branch = []
for node in graph_module.true_graph_0.graph.nodes:
if node.op == "call_function":
all_ops_in_true_branch.append(node.target)
self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
def test_cond_accepts_torch_function_as_inputs(self):
a = torch.randn(3, 4)
b = torch.randn(3, 4)
def f(a, b):
return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
gm = self._check_tracing(f, (a, b))["symbolic"]
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, a_1, b_1):
sum_1 = torch.ops.aten.sum.default(a_1)
gt = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
self.assertExpectedInline(
gm.false_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (mul,)""",
)
def test_cond_retrace_functionalized(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x):
return cond(x.all(), true_fn, false_fn, (x,))
inp = torch.ones(1, 2)
gm_non_functional = make_fx(f, tracing_mode="real")(inp)
gm_functional = make_fx(
torch.func.functionalize(gm_non_functional), tracing_mode="real"
)(inp)
self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
def test_cond_subgraph_same_shape_env_as_parent(self):
def true_fn(x):
return x.sin() + 10
def false_fn(x):
return x.cos() - 20
def f(x, pred):
y = cond(pred, true_fn, false_fn, [x])
z = torch.add(y, y)
return z
symbolic_traced_graph = self._check_tracing(
f, (torch.ones(4), torch.Tensor([True]))
)["symbolic"]
graph_shape_env = symbolic_traced_graph.shape_env
def _node_shape_env_iter(gm):
for node in symbolic_traced_graph.graph.nodes:
if node.op == "call_function":
val = node.meta.get("val")
if isinstance(val, tuple):
for v in val:
yield v.fake_mode.shape_env
else:
yield val.fake_mode.shape_env
for shape_env in _node_shape_env_iter(symbolic_traced_graph):
self.assertTrue(shape_env is graph_shape_env)
for shape_env in _node_shape_env_iter(symbolic_traced_graph.true_graph_0):
self.assertTrue(shape_env is graph_shape_env)
for shape_env in _node_shape_env_iter(symbolic_traced_graph.false_graph_0):
self.assertTrue(shape_env is graph_shape_env)
def test_cond_functionalized_nested(self):
def true_true_fn(x):
y = x.cos()
y.add_(4)
return x.sin().max() + y.sin().max()
def true_false_fn(x):
return x.cos().min()
def true_fn(x):
pred = x.shape[0] == 1
return cond(pred, true_true_fn, true_false_fn, [x])
def false_fn(x):
return x.sum()
def f(x):
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
gm_true_true_branch = graph_module.true_graph_0.true_graph_0
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
all_ops = []
for node in gm_true_true_branch.graph.nodes:
if node.op == "call_function":
all_ops.append(node.target)
self.assertFalse(any(op._schema.is_mutable for op in all_ops))
def test_cond_functionalized_data_dependent_pred(self):
def true_fn(x):
return x.sin().sum()
def false_fn(x):
return x.cos().sum()
def f(x):
pred = x.nonzero().shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_input_mutation_on_true_brancte(self):
def true_fn(x):
view_x = x.view(x.shape)
view_x.add_(1)
return view_x.sin().sum()
def false_fn(x):
return x.cos().sum()
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [4, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [4, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [4, 5])
sin = torch.ops.aten.sin.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_input_mutation_on_false_branch(self):
def true_fn(x):
return x.sin().sum()
def false_fn(x):
view_x = x.view(x.shape)
view_x.add_(1)
return view_x.cos().sum()
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5])
add = torch.ops.aten.add.Tensor(view, 1); view = None
view_1 = torch.ops.aten.view.default(add, [5, 5]); add = None
view_2 = torch.ops.aten.view.default(view_1, [5, 5])
cos = torch.ops.aten.cos.default(view_2); view_2 = None
sum_1 = torch.ops.aten.sum.default(cos); cos = None
copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = view_1 = copy_ = None
return sum_1""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_output_alias_input(self):
def true_fn(x):
return x
def false_fn(x):
view_x = x.view(x.shape)
return view_x
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(5, 5),)
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
# torch.cond inlines into one of the branches because the predicate
# is a constant.
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
view = torch.ops.aten.view.default(x_1, [5, 5]); x_1 = None
return view""",
)
# torch.cond triggers the check of the branches because the predicate
# is a SymBool.
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_nested_input_mutation(self):
def true_true_fn(x):
x.add_(4)
return x.sin().max()
def true_false_fn(x):
return x.cos().min()
def true_fn(x):
pred = x.shape[0] == 1
return cond(pred, true_true_fn, true_false_fn, [x])
def false_fn(x):
return x.sum()
def f(x):
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
# https://github.com/pytorch/pytorch/issues/126988
def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
def true_true_fn(x):
x.add_(4)
return x.sin().max()
def true_false_fn(x):
return x.cos().min()
def true_fn(x):
pred = x.shape[0] == 1
return cond(pred, true_true_fn, true_false_fn, [x])
def false_fn(x):
return x.sum()
def f(x):
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(4, 5)
try:
example_input_func = to_fun_old(example_input)
torch._enable_functionalization(reapply_views=False)
f(example_input_func)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f, tracing_mode="symbolic")(example_input_func)
finally:
torch._disable_functionalization()
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return func(*args, **kwargs)
finally:
torch._disable_functionalization()
return wrapper
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "One of torch.cond branch"
):
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_cond_functionalized_input_aliasing_with_aot_func(self):
def true_fn(x):
return x
def false_fn(x):
view_x = x.view(x.shape)
return view_x
def f(x):
pred = x.sum() > 0
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
try:
example_input_func = to_fun_old(example_input)
torch._enable_functionalization(reapply_views=False)
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
f(example_input_func)
finally:
torch._disable_functionalization()
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(
lambda x: torch._to_functional_tensor(x)
if isinstance(x, torch.Tensor)
else x,
args,
)
func_kwargs = pytree.tree_map(
lambda x: torch._to_functional_tensor(x)
if isinstance(x, torch.Tensor)
else x,
kwargs,
)
return func(*func_args, **func_kwargs)
finally:
torch._disable_functionalization()
return wrapper
with self.assertRaisesRegex(
UnsupportedAliasMutationException,
"One of torch.cond branch might be aliasing",
):
make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
def test_cond_functionalized_aot_func_check_functional(self):
def true_fn(x):
return x.cos()
def false_fn(x):
y = x.sin()
y.add_(5)
return y
def f(x):
pred = x.shape[0] == 4
return cond(pred, true_fn, false_fn, [x])
example_input = torch.ones(5, 5)
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
args,
)
func_kwargs = pytree.tree_map(
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
kwargs,
)
return pytree.tree_map(
from_fun_old, func(*func_args, **func_kwargs)
)
finally:
torch._disable_functionalization()
return wrapper
result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
for node in result_gm.true_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
for node in result_gm.false_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
def test_cond_nested_traced_other_inputs(self):
def true_nested(y):
return y * y
def false_nested(y):
return y + y
def true_fn(k, pred2):
z = cond(pred2, true_nested, false_nested, [k])
return torch.add(torch.tensor([0.25, 0.25]), z)
def false_fn(k, _):
return k.cos()
def f(k, pred, pred2):
return cond(pred, true_fn, false_fn, [k, pred2])
x = torch.tensor([0.5, 0.5])
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
a = torch.tensor([1.0, 1.0])
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
b = torch.tensor([2.0, 2.0])
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
def test_cond_nested_traced_multi(self):
def true_a(y):
return y * y
def false_a(y):
return y + y
def true_b(y, z):
return y + z
def false_b(y, z):
return y * z
def f(x, pred, pred2):
a_out = cond(pred, true_a, false_a, [x])
b_out = cond(pred2, true_b, false_b, [x, x])
return a_out + b_out
x = torch.randn(4)
graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
self.assertExpectedInline(
graph.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return (mul,)""",
)
def test_raise_error_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return (x, x)
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.CondOpArgsMismatchError,
"Expected to return same number of outputs but got:",
):
make_fx(f)(x, torch.tensor(False))
def test_raise_error_on_mismatch_tensor_size(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return torch.zeros([10, 10])
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
make_fx(f)(x, torch.tensor(False))
def test_cond_traced_not_nested_fake_tensor(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
result_true = graph.forward(x, torch.tensor(True))
result_false = graph.forward(x, torch.tensor(False))
self.assertFalse(torch.allclose(result_true, result_false))
self.assertEqual(result_true, torch.sin(x))
self.assertEqual(result_false, torch.cos(x))
def test_cond_nested_traced_fake_tensor(self):
def true_nested(y):
return y * y
def false_nested(y):
return y + y
def true_fn(x, pred2):
z = cond(pred2, true_nested, false_nested, [x])
return x + z
def false_fn(x, _):
return x.cos()
def f(x, pred, pred2):
return cond(pred, true_fn, false_fn, [x, pred2])
x = torch.randn(4)
graph = make_fx(f, tracing_mode="fake")(
x, torch.tensor(False), torch.tensor(False)
)
result_true_true = graph.forward(
x, torch.tensor(True), torch.tensor(True)
) # True + True -> x * x
result_true_false = graph.forward(
x, torch.tensor(True), torch.tensor(False)
) # True + True -> x + x
result_false_true = graph.forward(
x, torch.tensor(False), torch.tensor(True)
) # False + either -> cos
result_false_false = graph.forward(
x, torch.tensor(False), torch.tensor(False)
) # False + either -> cos
self.assertNotEqual(result_true_true, result_true_false)
self.assertFalse(torch.allclose(result_false_true, result_true_true))
self.assertEqual(result_false_true, result_false_false)
self.assertEqual(result_true_true, (x * x) + x)
self.assertEqual(result_true_false, x + x + x)
self.assertEqual(result_false_true, torch.cos(x))
def test_cond_nested_traced_other_inputs_fake_tensor(self):
def true_nested(y):
return y * y
def false_nested(y):
return y + y
def true_fn(k, pred2):
z = cond(pred2, true_nested, false_nested, [k])
return torch.add(torch.tensor([0.25, 0.25]), z)
def false_fn(k, _):
return k.cos()
def f(k, pred, pred2):
return cond(pred, true_fn, false_fn, [k, pred2])
x = torch.tensor([0.5, 0.5])
graph = make_fx(f, tracing_mode="fake")(
x, torch.tensor(False), torch.tensor(False)
)
a = torch.tensor([1.0, 1.0])
result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
b = torch.tensor([2.0, 2.0])
result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
def test_cond_nested_traced_multi_fake_tensor(self):
def true_a(y):
return y * y
def false_a(y):
return y + y
def true_b(y, z):
return y + z
def false_b(y, z):
return y * z
def f(x, pred, pred2):
a_out = cond(pred, true_a, false_a, [x])
b_out = cond(pred2, true_b, false_b, [x, x])
return a_out + b_out
x = torch.randn(4)
graph = make_fx(f, tracing_mode="fake")(
x, torch.tensor(False), torch.tensor(False)
)
self.assertExpectedInline(
graph.code.strip(),
"""\
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
)
self.assertExpectedInline(
graph.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return (mul,)""",
)
def test_raise_error_on_mismatch_type_size_fake_tensor(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return (x, x)
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.CondOpArgsMismatchError,
"Expected to return same number of outputs but got:",
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return torch.zeros([10, 10])
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
def check_map_count(self, gm, op_count):
i = 0
for m in gm.modules():
for node in m.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.higher_order.map_impl
):
i += 1
self.assertEqual(i, op_count)
def test_tracing_map_real(self):
def f(x, y):
return x + y
def g(xs, y):
return control_flow.map(f, xs, y)
gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2))
x = torch.randn(3, 2, 2)
y = torch.randn(2)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 1)
def test_tracing_map_symbolic_simple(self):
def f(x, y):
return x + y
def g(xs, y):
return control_flow.map(f, xs, y)
gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4))
x = torch.randn(3, 2, 2)
y = torch.randn(2)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 1)
def test_tracing_map_symbolic_list(self):
def f(x, y):
return [x[0][0] + y, x[1] * y]
def g(xs, y, z):
out = control_flow.map(f, xs, y)
return out[0] + z, out[1] * z
example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
gm = make_fx(g, tracing_mode="symbolic")(
example_x, torch.ones(5), torch.ones(5)
)
x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
y = torch.randn(6)
z = torch.ones(6)
res = gm(x, y, z)
self.assertEqual(res, g(x, y, z))
self.check_map_count(gm, 1)
def test_tracing_map_symbolic_dict(self):
def f(x, y):
return {"d": x["b"]["a"] + y, "e": x["c"] * y}
def g(xs, y, z):
out = control_flow.map(f, xs, y)
return {"f": out["d"] + z, "g": out["e"] * z}
example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
gm = make_fx(g, tracing_mode="symbolic")(
example_x, torch.ones(5), torch.ones(5)
)
x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
y = torch.randn(6)
z = torch.ones(6)
res = gm(x, y, z)
self.assertEqual(res, g(x, y, z))
self.check_map_count(gm, 1)
def test_tracing_map_autograd_symbolic_simple(self):
def f(x, y):
return x + y
def g(xs, y):
out = control_flow.map(f, xs, y)
return torch.autograd.grad(out, (xs, y), torch.ones_like(out))
gm = make_fx(g, tracing_mode="symbolic")(
torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)
)
x = torch.randn(4, 5, 6, requires_grad=True)
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_symbolic_list(self):
import torch.utils._pytree as pytree
def f(x, y):
return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
def g(xs, y):
out = control_flow.map(f, xs, y)
flat_out = pytree.tree_leaves(out)
flat_inp = pytree.tree_leaves((xs, y))
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
return torch.autograd.grad(
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
)
gm = make_fx(g, tracing_mode="symbolic")(
[torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
torch.ones(5, requires_grad=True),
)
x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_symbolic_dict(self):
def f(x, y):
return [x["a"] + y, x["b"] * y]
def g(xs, y):
out = control_flow.map(f, xs, y)
flat_out = pytree.tree_leaves(out)
flat_inp = pytree.tree_leaves((xs, y))
requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
return torch.autograd.grad(
flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
)
traced_x = {
"a": torch.ones(3, 4, 5, requires_grad=True),
"b": torch.ones(3, 4, 5, requires_grad=True),
}
gm = make_fx(g, tracing_mode="symbolic")(
traced_x, torch.ones(5, requires_grad=True)
)
x = {
"a": torch.randn(4, 5, 6, requires_grad=True),
"b": torch.ones(4, 5, 6, requires_grad=True),
}
y = torch.randn(6, requires_grad=True)
res = gm(x, y)
self.assertEqual(res, g(x, y))
self.check_map_count(gm, 2)
def test_tracing_map_autograd_aot_functionalized(self):
def inner(x, y):
z = x - 1
z.add_(1)
return z * y
def f(xs, y):
res = control_flow.map(inner, xs, y)
grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res))
return grads
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
finally:
torch._disable_functionalization()
return wrapper
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(2, 4, requires_grad=True),
)
gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)
xs = torch.ones(3, 4, 5, requires_grad=True)
y = torch.ones(4, 5, requires_grad=True)
self.assertEqual(gm(xs, y), f(xs, y))
def count_mutable(gm):
c = 0
for node in gm.graph.nodes:
if node.op == "call_function":
if node.target == torch.ops.higher_order.map_impl:
c += count_mutable(getattr(gm, str(node.args[0])))
elif schema := getattr(node.target, "_schema", None):
c += int(schema.is_mutable)
return c
self.assertEqual(count_mutable(fgm), 0)
# One for forward, one for recomputation logic in backward
self.assertEqual(count_mutable(gm), 2)
def test_map_functionalized(self):
def map_fn(x, y):
z = x + y
z.add_(4)
return z
def f(xs, y):
return control_flow.map(map_fn, xs, y)
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
gm = make_fx(torch.func.functionalize(f))(*example_inputs)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
*example_inputs
)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
for node in gm.body_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.check_map_count(gm, 1)
def test_map_functionalized_aot_func(self):
def map_fn(x, y):
z = x + y
z.add_(4)
return z
def f(xs, y):
return control_flow.map(map_fn, xs, y)
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
return pytree.tree_map(from_fun_old, func(*args, **kwargs))
finally:
torch._disable_functionalization()
return wrapper
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
gm = make_fx(f_wrapper(f))(*example_inputs)
for node in gm.body_graph_0.graph.nodes:
if node.op == "call_function":
self.assertTrue(not node.target._schema.is_mutable)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_map_functionalized_arg_mutation(self):
def map_fn(x, y):
y.add_(4)
return x + y
def f(xs, y):
return control_flow.map(map_fn, xs, y)
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "torch.map is mutating the input!"
):
functional_f(*example_inputs)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_map_functionalized_elem_mutation(self):
def map_fn(x, y):
x.add_(4)
return x + y
def f(xs, y):
return control_flow.map(map_fn, xs, y)
example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "torch.map is mutating the input!"
):
functional_f(*example_inputs)
def test_cond_autograd_backward(self):
def true_fn(x):
return x.cos()
def false_fn(x):
return x.sin()
def f(x, y):
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
example_inputs = (
torch.ones(3, 2, 4, requires_grad=True),
torch.ones(4, requires_grad=True),
)
f(*example_inputs).sum().backward()
# Ensure no error is thrown when not running backward
res = f(*example_inputs)
# Ensure no error is thrown when not running backward
res_compiled = torch.compile(f)(*example_inputs)
self.assertEqual(res, res_compiled)
# https://github.com/pytorch/pytorch/issues/126988
@xfailIfTorchDynamo
def test_map_functionalized_elem_alias(self):
def map_fn(x):
x.view(x.shape)
return x
def f(xs):
return control_flow.map(map_fn, xs)
example_inputs = (torch.ones(3, 2, 4),)
functional_f = torch.func.functionalize(f)
with self.assertRaisesRegex(
UnsupportedAliasMutationException, "torch.map is aliasing the input!"
):
functional_f(*example_inputs)
def test_nested_map_cond_real(self):
def true_fn(x, y):
return x * y
def false_fn(x, y):
return x + y
def f(x, pred, y):
return cond(pred, true_fn, false_fn, [x, y])
def g(pred, xs, y):
return control_flow.map(f, xs, pred, y)
gm = make_fx(g, tracing_mode="real")(
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
)
pred = torch.tensor(False)
x = torch.randn(3, 2, 4)
y = torch.randn(4)
res = gm(pred, x, y)
self.assertEqual(res, g(pred, x, y))
self.check_map_count(gm, 1)
def test_nested_map_cond_symbolic(self):
def true_fn(x, y):
return x * y
def false_fn(x, y):
return x + y
def f(x, pred, y):
return cond(pred, true_fn, false_fn, [x, y])
def g(pred, xs, y):
return control_flow.map(f, xs, pred, y)
gm = make_fx(g, tracing_mode="symbolic")(
torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
)
pred = torch.tensor(False)
x = torch.randn(3, 2, 2)
y = torch.randn(2)
res = gm(pred, x, y)
self.assertEqual(res, g(pred, x, y))
self.check_map_count(gm, 1)
def test_nested_cond_map_cond_symbolic(self):
def true_fn(x, y):
return x * y
def false_fn(x, y):
return x + y
def f(x, pred, y):
return cond(pred, true_fn, false_fn, [x, y])
def g(pred, xs, y):
return control_flow.map(f, xs, pred, y)
def main_true_fn(pred, xs, y):
return g(pred, xs, y) * 2
def main_false_fn(pred, xs, y):
return g(pred, xs, y) + 1
def main(p, pred, xs, y):
return cond(p, main_true_fn, main_false_fn, [pred, xs, y])
gm = make_fx(main, tracing_mode="symbolic")(
torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
)
p = torch.tensor(False)
pred = torch.tensor(False)
xs = torch.randn(3, 2, 2)
y = torch.randn(2)
res = gm(p, pred, xs, y)
self.assertEqual(res, main(p, pred, xs, y))
self.check_map_count(gm, 2)
def test_cond_with_sym_pred(self):
def true_fn(x):
return x + x
def false_fn(x):
return x * x
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
# The symbols in make_fx's shape_env should not be specialized.
self.assertEqual(len(gm.shape_env.guards), 0)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
# We expect the traced graph module to work even if input size changes.
x = torch.ones(4, 3, 2)
self.assertEqual(gm(x), true_fn(x))
self.assertEqual(foo(x), true_fn(x))
def test_cond_with_unbacked_sym_pred(self):
def foo(x):
def true_fn(x):
return x + x
def false_fn(x):
return x * x
az = x.nonzero()
return cond(az.shape[0] > 3, true_fn, false_fn, (x,))
gm = make_fx(foo, tracing_mode="symbolic")(torch.randn(7))
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
nonzero = torch.ops.aten.nonzero.default(x_1)
sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0); nonzero = None
gt = sym_size_int > 3; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]); gt = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""",
)
def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
assert isinstance(args, (tuple, list))
self.assertEqual(f(*args), exp_res)
gm = make_fx(f)(*args)
self.assertEqual(gm(*args), exp_res)
def cnt_placeholder(gm):
return len([node for node in gm.graph.nodes if node.op == "placeholder"])
placeholder_cnts = [cnt_placeholder(mod) for mod in gm.children()]
self.assertTrue(all(cnt == exp_arg_num for cnt in placeholder_cnts))
def _check_closure_correctly_lifted_with_mutation(
self, f, closures_to_be_mutated, *, args, exp_arg_num
):
exp_res = f(*args)
self._check_closure_correctly_lifted(
f, args=args, exp_res=exp_res, exp_arg_num=exp_arg_num
)
for closure in closures_to_be_mutated:
closure.add(-1)
new_exp_res = f(*args)
self._check_closure_correctly_lifted(
f, args=args, exp_res=new_exp_res, exp_arg_num=exp_arg_num
)
def test_cond_with_tensor_closure(self):
a = torch.ones(2, 3)
b = torch.ones(2, 3) + 1
def true_fn(x):
return x + a
def false_fn(x):
return x + b
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
# expected branches takes [x, a, b] as input
inp = torch.randn(2, 3)
self._check_closure_correctly_lifted_with_mutation(
foo, (a, b), args=(inp,), exp_arg_num=3
)
def test_cond_with_tensor_closure_graph_module(self):
a = torch.ones(2, 3)
b = torch.ones(2, 3) + 1
def true_fn(x):
return x + a
def false_fn(x):
return x + b
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
# expected branches takes [x, a, b] as input
inp = torch.randn(2, 3)
gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)""",
)
def test_cond_with_module_param_closure(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_parameter(
"param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
)
self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1)
my_mode = Mod()
def true_fn(x):
return x + my_mode.param
def false_fn(x):
return x + my_mode.buffer
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
inp = torch.ones(2, 3)
# expected both branches takes (x, param, buffer)
self._check_closure_correctly_lifted_with_mutation(
foo, (my_mode.param, my_mode.buffer), args=(inp,), exp_arg_num=3
)
def test_cond_with_module_python_scalar_closure(self):
def foo(x):
a = torch.ones(1, 1)
b = 1
def true_fn(x):
return x + a
def false_fn(x):
return x + b
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
inp = torch.ones(2, 3)
res = inp + 1
# python scalar b is not lifted as input, so both branches take (x, a)
self._check_closure_correctly_lifted(
foo, args=(inp,), exp_res=res, exp_arg_num=2
)
def test_cond_nested_with_closure(self):
a = torch.ones(1, 1)
b = torch.ones(1, 1) + 1
def inner_true_fn(x):
return x + a
def inner_false_fn(x):
return x + b
def foo(x):
def true_fn(x):
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
def false_fn(x):
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
inp = torch.ones(2, 3)
# For top-level cond, it take 3 arguments (x, a, b). Dynamo should
# realize that the nonlocal variables are same for the true and false
# branches, so it should de-dupe them.
# For second-level conds, it takes (x, a, b)
self._check_closure_correctly_lifted_with_mutation(
foo, (a, b), args=(inp,), exp_arg_num=3
)
def test_cond_nested_with_closure_graph_module(self):
a = torch.ones(1, 1)
b = torch.ones(1, 1) + 1
def inner_true_fn(x):
return x + a
def inner_false_fn(x):
return x + b
def foo(x):
def true_fn(x):
return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
def false_fn(x):
return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
def test_map_unfunc_boolean_tensor_for_nested_map_cond(self):
def map_fn(pred, x):
def fn(x, pred):
return control_flow.cond(pred, lambda x: x * 2, lambda x: x / 2, (x,))
return control_flow.map(fn, x, pred)
def f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
args,
)
func_kwargs = pytree.tree_map(
lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
kwargs,
)
return pytree.tree_map(
from_fun_old, func(*func_args, **func_kwargs)
)
finally:
torch._disable_functionalization()
return wrapper
gm = make_fx(f_wrapper(map_fn))(
torch.tensor(True), torch.ones([2, 3], requires_grad=False)
)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, pred_1, x_1):
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]); body_graph_0 = x_1 = pred_1 = None
getitem = map_impl[0]; map_impl = None
return getitem""",
)
self.assertExpectedInline(
gm.body_graph_0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
return [getitem]""", # noqa: B950
)
def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
def true_fn(x):
return x + x.cos()
def false_fn(x):
return x * x.sin()
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
inp = torch.randn([4, 3])
gm, _ = torch._dynamo.export(foo)(inp)
def run_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(gm).run(*args)
new_gm = make_fx(run_with_interpreter)(inp)
checked_ops = {"add", "mul", "sin", "cos"}
checked_meta = ["source_fn_stack", "stack_trace"]
all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
new_source_fns = collect_meta_for_filtered_nodes(
new_gm, checked_ops, checked_meta
)
self.assertEqual(all_source_fns, new_source_fns)
@unittest.skipIf(
TEST_WITH_TORCHDYNAMO,
"triggers cache limit for foo and changes unique_graphs count.",
)
def test_cond_no_dynamo_cache_limit(self):
torch._dynamo.reset()
counters = torch._dynamo.utils.counters
counters.clear()
def foo(x, true_fn, false_fn):
return cond(x.sum() < 0, true_fn, false_fn, (x,))
inp = torch.ones(3, 4)
exp_out = inp.sin()
iter_n = torch._dynamo.config.cache_size_limit + 1
# Need this because Dynamo checks lambda code ID not object itself.
def make_dummy_fn(op):
exec(f"temp = lambda x: x.{op}()")
return locals()["temp"]
for _ in range(iter_n):
# each lambda has a different object id thus fails the guard
self.assertEqual(
foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out
)
# each iteration captures a cond and a getitem from the tuple output
self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
def test_cond_with_consecutive_make_fx_symbolic(self):
def true_fn(x):
return x - x.cos()
def false_fn(x):
return x + x.sin()
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3))
for inp in inps:
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4))
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 4; sym_size_int = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); eq = true_graph_0 = false_graph_0 = x_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
self.assertExpectedInline(
gm.true_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
cos = torch.ops.aten.cos.default(arg0_1)
sub = torch.ops.aten.sub.Tensor(arg0_1, cos); arg0_1 = cos = None
return (sub,)""",
)
self.assertExpectedInline(
gm.false_graph_0.code.strip(),
"""\
def forward(self, arg0_1):
sin = torch.ops.aten.sin.default(arg0_1)
add = torch.ops.aten.add.Tensor(arg0_1, sin); arg0_1 = sin = None
return (add,)""",
)
def _create_test_fns_for_cond(
self, pred, inner_most_fn, operands, closure_list, nested_level
):
if nested_level == 0:
if len(closure_list) > 0:
def true_fn(*operands):
return inner_most_fn(*operands) + inner_most_fn(*closure_list)
def false_fn(*operands):
return inner_most_fn(*operands) - inner_most_fn(*closure_list)
else:
def true_fn(*operands):
return inner_most_fn(*operands)
def false_fn(*operands):
return inner_most_fn(*operands)
def fn(*operands):
if len(operands) == 0 and len(closure_list) == 0:
return torch.zeros(1)
return cond(pred, true_fn, false_fn, operands)
return operands, fn
else:
args, inner_fn = self._create_test_fns_for_cond(
pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1
)
def true_fn(*operands):
return inner_most_fn(*operands) + inner_fn(*args)
def false_fn(*operands):
return inner_most_fn(*operands) - inner_fn(*args)
def fn(*operands):
if len(operands) == 0 and len(closure_list) == 0:
return torch.ones(1)
return cond(pred, true_fn, false_fn, operands)
return operands, fn
def _init_predicate(self, pred_type):
if pred_type == "bool":
return True
elif pred_type == "intTensor":
return torch.tensor(1)
elif pred_type == "floatTensor":
return torch.tensor(1.0)
elif pred_type == "boolTensor":
return torch.tensor(False)
else:
raise NotImplementedError
def _init_fn(self, inner_fn_type):
if inner_fn_type == "function":
return reduce_func
elif inner_fn_type == "module":
return ReduceMod()
elif inner_fn_type == "object":
return ReduceObj()
else:
raise NotImplementedError
@parametrize("predType", ["bool", "intTensor", "floatTensor", "boolTensor"])
@parametrize("innerFnType", ["function", "module", "object"])
@parametrize("nOperands", [0, 1])
@parametrize("nClosure", [0, 1])
@parametrize("nesting", [0, 2])
def test_cond_tracing_with_valid_inputs(
self, predType, innerFnType, nOperands, nClosure, nesting
):
pred = self._init_predicate(predType)
inner_fn = self._init_fn(innerFnType)
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
args, fn = self._create_test_fns_for_cond(
pred, inner_fn, operands, closure, nesting
)
eager_res = fn(*args)
for tracing_mode in ["symbolic", "fake", "real"]:
# set _allow_non_fake_inputs = True to allow fake prop through closures
with self.subTest(tracing_mode=tracing_mode):
gm = make_fx(
fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True
)(*args)
self.assertEqual(gm(*args), eager_res)
@parametrize("predType", ["boolTensor"])
@parametrize("innerFnType", ["function", "module", "object"])
@parametrize("nOperands", [1, 2])
@parametrize("nClosure", [0, 1])
@parametrize("nesting", [0])
def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
pred = self._init_predicate(predType)
inner_fn = self._init_fn(innerFnType)
operands = [torch.ones(2, 3) + i for i in range(nOperands)]
closure = [torch.ones(2, 3) - i for i in range(nClosure)]
args, fn = self._create_test_fns_for_cond(
pred, inner_fn, operands, closure, nesting
)
eager_res = fn(*args)
out = torch.vmap(fn)(*args)
if nClosure == 0:
self.assertEqual(eager_res, out)
else:
self.assertEqual(eager_res, out[0])
self.assertEqual(eager_res, out[1])
def test_cond_vmap_simple(self):
def fn(x):
return torch.cond(
pred=torch.tensor([True]),
true_fn=lambda x: x + 100,
false_fn=lambda x: x,
operands=(x,),
)
a = torch.arange(15).reshape((3, 5))
res = torch.vmap(fn, in_dims=(0,))(a)
self.assertEqual(res.shape, (3, 5))
self.assertEqual(res, a + 100)
def test_cond_vmap_multiple_inputs(self):
def fn(x, y):
return torch.cond(
pred=x.sum() < y.sum(),
true_fn=lambda x, y: x + 100,
false_fn=lambda x, y: y,
operands=(x, y),
)
a = torch.arange(15).reshape(3, 5)
b = torch.ones_like(a) + 3
res = torch.vmap(fn, in_dims=(0, 0))(a, b)
expected = torch.tensor(
[[100, 101, 102, 103, 104], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]]
)
self.assertEqual(res.shape, (3, 5))
self.assertEqual(expected, res)
def test_cond_vmap_single_input_with_closure(self):
a = torch.ones((3, 5)) + 3
c = torch.arange(5)
def fn(x):
return torch.cond(
pred=torch.tensor([True]),
true_fn=lambda x: x + c,
false_fn=lambda x: x - c,
operands=(x,),
)
res = torch.vmap(fn, in_dims=(0,))(
a,
)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
res = torch.vmap(fn, in_dims=(0,))(
a,
)
self.assertEqual(a + c, res)
def test_cond_vmap_multiple_args_with_closure(self):
a = torch.ones((3, 5), dtype=torch.int64) + 3
b = torch.arange(15).reshape(3, 5)
c = torch.arange(5)
def fn(x, y):
return torch.cond(
pred=torch.tensor([False]),
true_fn=lambda x, y: x + c,
false_fn=lambda x, y: y - c,
operands=(x, y),
)
res = torch.vmap(fn)(a, b)
self.assertEqual(b - c, res)
@parametrize("nClosure", [0, 1])
def test_cond_vmap_multiple_outputs(self, nClosure):
if nClosure:
c = torch.ones(5, dtype=torch.int64) + 5
def fn(x):
return torch.cond(
pred=torch.tensor([True]),
true_fn=lambda x: (x + c, x - c),
false_fn=lambda x: (x, x),
operands=(x,),
)
else:
def fn(x):
return torch.cond(
pred=torch.tensor([True]),
true_fn=lambda x: (x + 1, x - 1),
false_fn=lambda x: (x, x),
operands=(x,),
)
a = torch.arange(15).reshape(3, 5)
res = torch.vmap(fn)(
a,
)
self.assertEqual(len(res), 2)
if nClosure:
self.assertEqual(res, (a + c, a - c))
else:
self.assertEqual(res, (a + 1, a - 1))
def test_vmap_vmap(self):
def fn(x):
return torch.cond(
pred=torch.tensor([True]),
true_fn=lambda x: x + 1,
false_fn=lambda x: x - 1,
operands=(x,),
)
def wrapper(x):
return torch.vmap(fn)(x)
a = torch.ones((3, 4, 5))
res = torch.vmap(wrapper)(a)
self.assertEqual(res, a + 1)
def test_cond_trace_set__and_mutate_input(self):
def f(a, tmp):
a_view = a.view(-1)
with torch.no_grad():
a.set_(tmp)
a_view.mul_(2)
return a + tmp
inp = torch.ones(3, 3, requires_grad=True)
tmp = torch.ones(3, 3, requires_grad=True)
# graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
# due to set_
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
):
torch.cond(inp.sum() > 0, f, f, (inp, tmp))
def test_cond_trace_set__and_mutate_intermediate(self):
def f(a, tmp):
a = a.clone()
a_view = a.view(-1)
tmp = tmp.clone()
with torch.no_grad():
a.set_(tmp)
a_view.mul_(2)
return a + tmp
inp = torch.ones(3, 3, requires_grad=True)
tmp = torch.ones(3, 3, requires_grad=True)
class Mod(torch.nn.Module):
def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
return torch.cond(inp.sum() > 0, f, f, (inp, tmp))
with self.assertRaisesRegex(
RuntimeError, "cannot mutate tensors with frozen storage"
):
out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)
with self.assertRaisesRegex(
RuntimeError, "cannot mutate tensors with frozen storage"
):
out = torch.compile(Mod(), backend="inductor")(inp, tmp)
from torch._dynamo.testing import EagerAndRecordGraphs
backend = EagerAndRecordGraphs()
out = torch.compile(Mod(), backend=backend)(inp, tmp)
self.assertExpectedInline(
backend.graphs[0].cond_true_0.code.strip("\n"),
"""\
def forward(self, l_inp_, l_tmp_):
l_inp__1 = l_inp_
l_tmp__1 = l_tmp_
a = l_inp__1.clone(); l_inp__1 = None
a_view = a.view(-1)
tmp = l_tmp__1.clone(); l_tmp__1 = None
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
set_ = a.set_(tmp); set_ = None
mul_ = a_view.mul_(2); a_view = mul_ = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
add = a + tmp; a = tmp = None
return (add,)
""",
)
self.assertEqual(out, f(inp, tmp))
def test_two_hops_not_sharing_code_obj(self):
pred, args = torch.tensor(True), (torch.ones(3, 3),)
def fn1(x):
return x + 1
def fn2(x):
return x - 1
from torch._dynamo.testing import CompileCounter
# Tests rely on automatic_dynamic = True
with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
cnt = CompileCounter()
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
self.assertEqual(cnt.frame_count, 1)
args = (torch.randn(3, 3),)
# No recompilation
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
self.assertEqual(cnt.frame_count, 1)
def cond_fn(x):
return x.sum() > 0
args = (torch.randn(4, 4),)
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
# recompilation
self.assertEqual(cnt.frame_count, 2)
args = (torch.randn(4, 4),)
torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
self.assertEqual(cnt.frame_count, 2)
# With recompilation due to automatic dynamic
# This also proves that while_loop doesn't share code obj with cond
torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
self.assertEqual(cnt.frame_count, 3)
def test_hop_raises_if_not_overriding_call(self):
class WrongHop(torch._ops.HigherOrderOperator):
pass
with self.assertRaisesRegex(TypeError, "WrongHop"):
wrong_hop = WrongHop("wrong_hop")
_hop_schema_test_schema_types = [
"bool",
"int",
"float",
"str",
"Tensor",
"SymInt",
"SymBool",
"GraphModule",
"ScriptObj",
]
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
class TestHopSchema(TestCase):
def _get_example_val(self, ty: str):
from torch.fx.experimental.sym_node import SymNode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
def create_symtype(cls, pytype, shape_env, val):
from torch._dynamo.source import ConstantSource
symbol = shape_env.create_symbol(
val,
source=ConstantSource(
f"__testing_hop_schema{len(shape_env.var_to_val)}"
),
)
return cls(SymNode(symbol, shape_env, pytype, hint=val))
if ty == "bool":
return True
elif ty == "int":
return 1
elif ty == "float":
return 1.0
elif ty == "str":
return "foo"
elif ty == "Tensor":
return torch.tensor(1)
elif ty == "SymInt":
shape_env = ShapeEnv()
return create_symtype(torch.SymInt, int, shape_env, 1)
elif ty == "SymBool":
shape_env = ShapeEnv()
return create_symtype(torch.SymBool, bool, shape_env, True)
elif ty == "GraphModule":
def f(x):
return x.sin()
return make_fx(f)(torch.ones(1))
elif ty == "ScriptObj":
from torch.testing._internal.torchbind_impls import (
init_torchbind_implementations,
)
init_torchbind_implementations()
foo = torch.classes._TorchScriptTesting._Foo(3, 4)
return foo
else:
raise NotImplementedError(ty)
@parametrize("schema_type", _hop_schema_test_schema_types)
def test_type_gen(self, schema_type):
from torchgen.gen_schema_utils import TypeGen
example_val = self._get_example_val(schema_type)
ty = TypeGen.from_example(example_val)
# Test the generated type can be parsed
self.assertEqual(ty.parse(str(ty)), ty)
@parametrize("schema_type", _hop_schema_test_schema_types)
def test_list_gen(self, schema_type):
from torchgen.gen_schema_utils import TypeGen
example_val = self._get_example_val(schema_type)
li1 = [example_val]
li2 = [example_val, example_val]
ty1 = TypeGen.from_example(li1)
ty2 = TypeGen.from_example(li1)
self.assertEqual(ty1.parse(str(ty1)), ty1)
self.assertEqual(ty2.parse(str(ty2)), ty2)
def test_function_schema_gen(self):
from torchgen.gen_schema_utils import FunctionSchemaGen
inps = [
(schema_type + "_v", self._get_example_val(schema_type))
for schema_type in _hop_schema_test_schema_types
]
op_name = "test_op"
schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
schema2 = FunctionSchemaGen.from_example(
"test_op2",
inps,
[
torch.ones(1),
],
)
schema3 = FunctionSchemaGen.from_example(
"test_op3", inps, [torch.ones(1), torch.ones(1)]
)
self.assertExpectedInline(
str(schema1),
"""test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
)
self.assertExpectedInline(
str(schema2),
"""test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950
)
self.assertExpectedInline(
str(schema3),
"""test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", # noqa: B950,
)
self.assertEqual(schema1.parse(str(schema1)), schema1)
self.assertEqual(schema2.parse(str(schema2)), schema2)
self.assertEqual(schema3.parse(str(schema3)), schema3)
def test_while_loop_schema_gen(self):
fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
graph = make_fx(fn)(*inp).graph
while_loop_node = next(
node
for node in graph.nodes
if node.op == "call_function"
and node.target is torch.ops.higher_order.while_loop
)
schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
self.assertExpectedInline(
str(schema),
"""while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""", # noqa: B950
)
self.assertEqual(schema.parse(str(schema)), schema)
instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)
instantiate_parametrized_tests(TestControlFlow)
if __name__ == "__main__":
run_tests()