# Owner(s): ["module: inductor"]
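# Tests for compiling torch.cond control flow through the Inductor backend.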
import itertools

import torch
import torch._dynamo.testing
import torch._inductor.config

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda


def prepend_predicates(inputs, num_predicates=1):
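    """Return one args tuple per combination of predicate values.

    For each of the 2**num_predicates combinations of boolean values, the
    predicates are prepended to the inputs as 0-dim bool tensors on the
    inputs' device. E.g. for num_predicates=2 and inputs [a, b] this yields
    (F, F, a, b), (F, T, a, b), (T, F, a, b), (T, T, a, b).
    """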
    result = []
    device = inputs[0].device
    # iterate over the Cartesian product of predicate values
    for p_values in itertools.product(*([[False, True]] * num_predicates)):
        predicates = [torch.tensor(v, device=device) for v in p_values]
        result.append((*predicates, *inputs))
    return result


class CondModels:
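    # Each model below exercises torch.cond(pred, true_fn, false_fn, operands),
    # where pred is a boolean scalar tensor and both branches must return
    # outputs of matching structure.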
    class Simple(torch.nn.Module):
        def forward(self, p, a, b):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(p, true_fn, false_fn, [a, b])

    class Nested(torch.nn.Module):
        def forward(self, p0, p1, p2, a, b, c):
            def true_fn(x0, y0, z0):
                def true_true_fn(x1, y1, z1):
                    return (x1 - y1 * z1) * 3.14

                def true_false_fn(x1, y1, z1):
                    def true_false_true_fn(x2, y2, z2):
                        return (x2 * y2 * z2) / 2.71

                    def true_false_false_fn(x2, y2, z2):
                        return (x2 + y2 + z2) * 1.23

                    return torch.cond(
                        p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
                    )

                return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])

            def false_fn(x0, y0, z0):
                def false_true_fn(x1, y1, z1):
                    def false_true_true_fn(x2, y2, z2):
                        return (x2 - y2 - z2) + 1.23

                    def false_true_false_fn(x2, y2, z2):
                        return (x2 / y2 / z2) - 3.14

                    return torch.cond(
                        p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
                    )

                def false_false_fn(x1, y1, z1):
                    return (x1 - y1 * z1) / 2.71

                return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])

            return torch.cond(p0, true_fn, false_fn, [a, b, c])

    class Parameters(torch.nn.Module):
        class InnerModel1(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer = torch.nn.Linear(20, 30, device=device)

            def forward(self, x):
                return self.layer(x + 1) * 3.14

        class InnerModel2(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer1 = torch.nn.Linear(20, 10, device=device)
                self.layer2 = torch.nn.Linear(10, 30, device=device)

            def forward(self, x):
                return self.layer2(self.layer1(x - 2)) * 3.14

        def __init__(self, device):
            super().__init__()
            self.true_fn = self.InnerModel1(device)
            self.false_fn = self.InnerModel2(device)

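        # nn.Module instances are callable, so they can be passed to
        # torch.cond directly as the branch functions.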
        def forward(self, p, a):
            return torch.cond(p, self.true_fn, self.false_fn, [a])

    class ReinterpretView(torch.nn.Module):
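        # Slicing the operands and the branch outputs produces views, which
        # Inductor represents as ReinterpretView buffers.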
        def forward(self, p, a, b):
            def true_fn(x, y):
                z1 = x + y
                z2 = x - y
                return z1[2:], z2[:, 4:]

            def false_fn(x, y):
                z1 = x - y
                z2 = x + y
                return z1[2:], z2[:, 4:]

            return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])

    class MultipleOutputs(torch.nn.Module):
        def forward(self, p, a, b, c):
            def true_fn(x, y, z):
                return x * y, z / 2.71, (y - x).sum(dim=1)

            def false_fn(x, y, z):
                return y / x, z * 3.14, (x + y).mean(dim=1)

            return torch.cond(p, true_fn, false_fn, [a, b, c])

    class OuterCode(torch.nn.Module):
        def forward(self, p, a, b):
            c = a * b + 3.14
            d = a / b - 2.71

            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            e = torch.cond(p, true_fn, false_fn, [c, d])

            return e * e / 1.41

    class OuterBuffers(torch.nn.Module):
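        # d and e are free variables captured by the branch closures; Dynamo
        # lifts such tensors into explicit inputs of the subgraphs.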
        def forward(self, p, a, b, c):
            d = a * 2
            e = b / 2

            def true_fn(x):
                return x + d

            def false_fn(x):
                return x - e

            return torch.cond(p, true_fn, false_fn, [c])


class CondTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_predicates=1,
    ):
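        """Compile the model with Inductor and compare against eager mode.

        The model is run over every combination of predicate values and,
        when dynamic=True, over a second input set whose first dim is tiled
        5x, asserting that a single compiled frame covers all of them.
        """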
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = [inp.to(device=device) for inp in inputs]
        input_sets = [inputs]
        if dynamic:
            larger_inputs = []
            for inp in inputs:
                # tile the first dim of each input 5x
                tiling = [5] + [1] * (inp.ndim - 1)
                larger_inputs.append(torch.tile(inp, tiling))
            input_sets.append(larger_inputs)
            for inputs in input_sets:
                for inp in inputs:
                    # mark the first dim of each input as dynamic
                    torch._dynamo.mark_dynamic(inp, 0)

        for inputs in input_sets:
            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
                result = model(*inputs_with_predicates)
                result_compiled = compiled_model(*inputs_with_predicates)
                self.assertEqual(result, result_compiled)

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    @parametrize("dynamic", [False, True])
    def test_simple_control_flow(self, device, dynamic):
        # cond control flow without nesting
        self._run_test(
            model=CondModels.Simple(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    @parametrize("dynamic", [False, True])
    def test_nested_control_flow(self, device, dynamic):
        # cond control flow with nesting
        self._run_test(
            model=CondModels.Nested(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
            num_predicates=3,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    @parametrize("dynamic", [False, True])
    def test_outer_code_before_after(self, device, dynamic):
        # some code before and after the conditional
        self._run_test(
            model=CondModels.OuterCode(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    @parametrize("dynamic", [False, True])
    def test_multiple_outputs(self, device, dynamic):
        # multiple outputs with different shapes
        self._run_test(
            model=CondModels.MultipleOutputs(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(30, 40),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    def test_advanced_dynamic_shapes(self, device):
        # the subgraphs' input shapes include symbolic expressions
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    return torch.cat([x - 3, y * 3], dim=1)

                def false_fn(x, y):
                    return torch.cat([x / 3, y - 3], dim=1)

                c = torch.cat([a, b], dim=0)
                d = c * 2
                e = c / 2

                return torch.cond(p, true_fn, false_fn, [d, e])

        self._run_test(
            model=Model(),
            inputs=(
                torch.randn(2, 3, 3),
                torch.randn(4, 3, 3),
            ),
            device=device,
            dynamic=True,
        )

    @requires_cuda
    def test_use_buffers_from_outer_scope(self):
        # buffers created in the outer scope are used inside the subgraphs
        self._run_test(
            model=CondModels.OuterBuffers(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device="cuda",
            dynamic=False,
        )

    @requires_cuda
    def test_reinterpret_view_inputs_outputs(self):
        # ReinterpretView in inputs and outputs of the subgraphs
        self._run_test(
            model=CondModels.ReinterpretView(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device="cuda",
            dynamic=True,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    @parametrize("dynamic", [False, True])
    def test_subgraphs_with_parameters(self, device, dynamic):
        # nested Modules with parameters
        self._run_test(
            model=CondModels.Parameters(device),
            inputs=(torch.randn(10, 20),),
            device=device,
            dynamic=dynamic,
        )

    @requires_cuda
    def test_aliasing_outputs(self):
        # output aliasing in subgraphs: not supported
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    z = x + y
                    return z, z[1:]

                def false_fn(x, y):
                    z = x - y
                    return z, z[1:]

                return torch.cond(p, true_fn, false_fn, [a, b])

        # AssertionError: Output aliasing is currently not supported...
        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
            torch.compile(Model())(
                torch.tensor(True),
                torch.randn(10, 20),
                torch.randn(10, 20),
            )
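
        # Returning a fresh tensor from the branches (e.g. z[1:].clone()
        # instead of the view z[1:]) avoids the aliasing.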

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    def test_cond_decompose_ops_in_subgraph(self, device):
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.zeros_like(x)

                def false_fn(x):
                    return torch.ones_like(x)

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_cuda
    @parametrize("device", ["cpu", "cuda"])
    def test_cond_decompose_ops_in_subgraph_recursive(self, device):
        def inner_fn1(x):
            return torch.zeros_like(x)

        def inner_fn2(x):
            return torch.ones_like(x)

        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.cond(p, inner_fn2, inner_fn1, [x])

                def false_fn(x):
                    return torch.cond(p, inner_fn1, inner_fn2, [x])

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_cuda
    def test_inductor_fx_passes_recursively_applied(self):
        counters = {"pre_grad": 0, "post_grad": 0}

        def pre_grad_pass_counter(gm):
            counters["pre_grad"] += 1

        def post_grad_pass_counter(gm):
            counters["post_grad"] += 1

        with torch._inductor.config.patch(
            {
                "pre_grad_custom_pass": pre_grad_pass_counter,
                "post_grad_custom_pre_pass": post_grad_pass_counter,
            }
        ):
            self._run_test(
                model=CondModels.Nested(),
                inputs=(
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                ),
                device="cuda",
                dynamic=True,
                num_predicates=3,
            )

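        # 11 = the outer graph plus the ten cond subgraphs of
        # CondModels.Nested, each visited once by each custom pass.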
        self.assertEqual(counters["pre_grad"], 11)
        self.assertEqual(counters["post_grad"], 11)


instantiate_parametrized_tests(CondTests)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")