| # Owner(s): ["module: inductor"] |
| |
| import copy |
| import os |
| import unittest |
| |
| import torch |
| from torch import nn |
| from torch._dynamo.utils import counters, same |
| from torch._inductor import metrics |
| from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench |
| from torch._inductor.test_case import TestCase |
| from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU |
| |
| DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" |
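
# These tests exercise inductor's "scatter upon const tensor" optimization,
# which (roughly) rewrites
#     y = torch.full([M, N], background_val)
#     y.scatter_(dim, index, scatter_val)
# into a single pointwise kernel along the lines of
#     y[i, j] = scatter_val if index[i] == j else background_val  # for dim=1
# so the constant-filled tensor is never materialized separately: the only
# memory traffic is writing y and reading index.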
| |
| |
| class TestScatterOpt(TestCase): |
| def setUp(self): |
| super().setUp() |
| metrics.reset() |
| counters.clear() |
| |
| def check_metric(self, val=1): |
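        # metrics.num_matches_for_scatter_upon_const_tensor counts how many
        # times the full+scatter pattern was matched and rewritten.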
| self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor) |
| |
| def do_acc_test(self, f, *args): |
| expect = f(*args) |
| actual = torch.compile(f)(*args) |
| self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n") |
| |
| def test_3d_tensor(self): |
| L, M, N = 2, 1024, 2048 |
| |
| def f(x): |
| y = torch.full([L, M, N], 3.14, dtype=torch.float) |
| y.scatter_(2, x.unsqueeze(2), 2.718) |
| return y |
| |
| x = torch.randint(0, N, (L, M), dtype=torch.int64) |
| self.do_acc_test(f, x) |
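        # With the optimization, the only expected memory traffic is writing
        # the L * M * N float32 output and reading the L * M int64 indices.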
| expected_num_bytes = ( |
| L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize |
| ) |
| self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
| def test_non_last_dim(self): |
| """ |
        Test the case where the scatter dimension is not the last one.
| """ |
| M, N = 1024, 2048 |
| |
| def f(x): |
| y = torch.full([M, N], 3.14, dtype=torch.float) |
| y.scatter_(0, x.unsqueeze(0), 2.718) |
| return y |
| |
| x = torch.randint(0, M, (N,), dtype=torch.int64) |
| self.do_acc_test(f, x) |
| expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize |
| self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
| def test_neg_scatter_dim(self): |
| M, N = 1024, 2048 |
| |
| def f(x): |
| y = torch.full([M, N], 3.14, dtype=torch.float) |
| y.scatter_(-1, x.unsqueeze(1), 2.718) |
| return y |
| |
| x = torch.randint(0, N, (M,), dtype=torch.int64) |
| self.do_acc_test(f, x) |
| expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize |
| self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
| def test_shorter_index_tensor(self): |
| M, N = 1024, 2048 |
| |
| def f(x): |
| y = torch.full([M, N], 3.14, dtype=torch.float) |
| y.scatter_(1, x.unsqueeze(1), 2.718) |
| return y |
| |
| x = torch.randint(0, N, (M // 2,), dtype=torch.int64) |
| self.do_acc_test(f, x) |
| |
        # No match since the index tensor is shorter than the output.
        # We may support this in the future.
| self.assertEqual(0, counters["inductor"]["pattern_matcher_count"]) |
| |
| def test_nonzero_const_tensor(self): |
| M, N = 1024, 2048 |
| |
| def f(x): |
| y = torch.full([M, N], 3.14, dtype=torch.float) |
| y.scatter_(1, x.unsqueeze(1), 2.718) |
| return y |
| |
| x = torch.randint(0, N, (M,), dtype=torch.int64) |
| self.do_acc_test(f, x) |
| expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize |
| self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
| def test_can_not_optimize_due_to_dense(self): |
| M, N = 1024, 2048 |
| |
| def f(x): |
| y = torch.full([M, N], 0, dtype=torch.float) |
| y.scatter_(1, x, 0.618) |
| return y |
| |
| x = torch.randint(0, N, (M, N // 2), dtype=torch.int64) |
| self.do_acc_test(f, x) |
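        # Without the rewrite, the full tensor is still written (M * N floats),
        # and the scatter kernel reads M * (N // 2) int64 indices and writes
        # as many floats.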
| expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * ( |
| torch.int64.itemsize + torch.float.itemsize |
| ) |
        # Use assertGreaterEqual rather than assertEqual due to the StarDep
        # issue discussed here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706
| self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
| def test_can_not_optimize_due_to_non_const(self): |
| M, N = 1024, 2048 |
| |
| def f(x, y): |
| y.scatter_(1, x, 0.618) |
| return y |
| |
| x = torch.randint(0, N, (M, 1), dtype=torch.int64) |
| y = torch.randn([M, N]) |
| self.do_acc_test(f, x, y) |
| |
        # The generated code is quite inefficient: there are 3 kernels
        #   1. copy from the arg to a buf
        #   2. scatter upon the buf
        #   3. copy the buf back to the arg
        # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40
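        # Kernels 1 and 3 each read and write M * N floats (2 * M * N each);
        # kernel 2 reads M int64 indices and writes M floats.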
| expected_num_bytes = 4 * M * N * torch.float.itemsize + M * ( |
| torch.int64.itemsize + torch.float.itemsize |
| ) |
| self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) |
| |
        # The second and third kernels are both mutation kernels, so we
        # overestimate the memory accessed.
        # Update the test once the overestimation is fixed.
| over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize |
| self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate) |
| |
| def test_cross_entropy_loss(self): |
| """ |
        Match the full+scatter pattern in CrossEntropyLoss (CEL) and replace
        it with a pointwise kernel.
| |
| Perf data on an A100 GPU: |
| Without the scatter optimization: |
| ms=47.340, peak_mem=10.524 GB |
| With the scatter optimization: |
| ms=42.768, peak_mem=7.227 GB |
| """ |
| B, T, D, V = 32, 1024, 768, 50257 |
| if not DO_PERF_TEST: |
            # use a smaller V when not doing a perf test to avoid OOM in CI
| V = V // 100 |
| ref_model = nn.Linear(D, V).to(torch.bfloat16) |
| opt_model = copy.deepcopy(ref_model) |
| ce = nn.CrossEntropyLoss() |
| |
| def f(m, x, label): |
| ce(m(x).view(-1, V), label.view(-1)).backward() |
| |
| opt_f = torch.compile(f) |
| |
| x = torch.randn(B, T, D).to(torch.bfloat16) |
| label = torch.randint(0, V, (B, T)).to(torch.int64) |
| |
| f(ref_model, x, label) |
| ref_grad = ref_model.weight.grad |
| opt_f(opt_model, x, label) |
| act_grad = opt_model.weight.grad |
| assert torch.allclose( |
| ref_grad, act_grad, atol=1e-3, rtol=1e-3 |
| ), f"{ref_grad=}\n{act_grad=}" |
| |
| self.check_metric() |
| |
| if DO_PERF_TEST: |
| if GPU_TYPE == "xpu": |
| raise unittest.SkipTest( |
| "torch.xpu.reset_peak_memory_stats not implemented." |
| ) |
| torch.cuda.reset_peak_memory_stats() |
| for _ in range(3): |
| opt_f(opt_model, x, label) |
| ms = do_bench(lambda: opt_f(opt_model, x, label)) |
| peak_mem = torch.cuda.max_memory_allocated() / 10**9 |
| print(f"{ms=:.3f}, {peak_mem=:.3f} GB") |
| |
| |
| if HAS_GPU: |
| torch.set_default_device(GPU_TYPE) |
| |
| if __name__ == "__main__": |
| from torch._inductor.test_case import run_tests |
| |
| if HAS_GPU: |
| run_tests() |