| # Owner(s): ["oncall: pt2"] |
| import random |
| import unittest |
| from math import prod |
| |
| import torch |
| import torch._functorch.config as config |
| from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| from torch.utils._triton import has_triton |
| from torch.utils.flop_counter import FlopCounterMode, register_flop_formula |
| |
| |
| if has_triton(): |
| # note: if we only import triton in the test, the test fails: |
| # def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr): |
| # NameError('tl is not defined') |
| import triton |
| import triton.language as tl |
| |
| |
def compile_with_ac(f, memory_budget):
    # The budget itself is not passed here: the partitioner reads it from
    # torch._functorch.config.activation_memory_budget, which get_mem_and_flops
    # patches around compilation.
    return torch.compile(f, backend="aot_eager_decomp_partition")
| |
| |
def get_act_mem(f):
    # Measures how much memory (in MiB) the forward of f keeps alive for the
    # backward pass.
    out = f()
    out.backward()
    # After the warm-up iteration, gradients have been allocated and all saved
    # activations have been freed, so this is our baseline.
    start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    out = f()
    cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
    # Whatever the second forward allocated and is still holding onto is the
    # activation memory saved for backward.
    act_mem = (cur_mem - start_mem) / (1024 * 1024)
    out.backward()
    return act_mem
| |
| |
| def get_bw_flops(f): |
| # Normalized so that a 512 square matmul returns 1 |
| f().backward() |
| out = f() |
| with FlopCounterMode(display=False) as mode: |
| out.backward() |
| return mode.get_total_flops() / (512**3 * 2) |
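
# Worked example of the normalization above (a sketch, not executed here):
#
#   A = torch.randn(512, 512, requires_grad=True)
#   B = torch.randn(512, 512, requires_grad=True)
#   get_bw_flops(lambda: torch.mm(A, B).sum())
#
# The backward of a single 512-square matmul consists of two 512-square
# matmuls (grad_A = grad_out @ B.T and grad_B = A.T @ grad_out), so the
# helper returns 2.0.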
| |
| |
def create_pair(B_I, O):
    # torch.mm(x, w) produces a B_I * O MiB output and costs B_I * B_I * O
    # normalized matmul FLOPs, i.e. an arithmetic intensity (FLOPs per MiB
    # saved) of B_I.
| x = torch.randn(B_I * 512, B_I * 512, requires_grad=True) |
| w = torch.randn(B_I * 512, O * 512, requires_grad=True) |
| return x, w |
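
# For example (numbers only, nothing is executed here): create_pair(2, 3)
# returns x of shape (1024, 1024) and w of shape (1024, 1536). torch.mm(x, w)
# produces a (1024, 1536) fp32 output, i.e. 2 * 3 = 6 MiB, which is what gets
# saved when the result is fed into `.cos()` as in the tests below, and costs
# 2 * 1024 * 1024 * 1536 FLOPs = 2 * 2 * 3 = 12 in the normalized matmul units
# of get_bw_flops, for an arithmetic intensity of 12 / 6 = 2 = B_I.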
| |
| |
| def get_mem_and_flops(f, memory_budget=None): |
    # Returns activation memory in MiB, rounded to one decimal place, and
    # normalized backward FLOPs (see get_bw_flops).
    # Note that each tensor of size (512, 512, torch.float32) is 1 MiB.
| torch._dynamo.reset() |
| with config.patch(activation_memory_budget=memory_budget): |
| if memory_budget is not None: |
            f = compile_with_ac(f, memory_budget)
| |
        # We round this to the nearest tenth of a MiB.
| return round(get_act_mem(f), 1), get_bw_flops(f) |
| |
| |
| class MemoryBudgetTest(TestCase): |
| def setUp(self): |
| super().setUp() |
| torch.set_default_device("cuda") |
| |
| def test_rematerializes_cheap(self): |
| def f(x, w): |
| x = x.cos() |
| x = torch.mm(x, w) |
| return x.sum() |
| |
| x = torch.randn(512, 512, requires_grad=True) |
| w = torch.randn(512, 512, requires_grad=True) |
| |
| def call(): |
| return f(x, w) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| self.assertEqual(eager_mem, 1.0) |
| mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0) |
        # With the full budget there is nothing to be gained by recomputing:
        # recomputing `.cos()` is not free, so the 1 MiB activation is kept.
| self.assertEqual(mem_10, 1.0) |
| self.assertEqual(eager_flops, flops_10) |
| mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5) |
| # We can just recompute `x.cos()` here to only depend on the inputs |
| self.assertEqual(mem_5, 0.0) |
| self.assertEqual(flops_5, eager_flops) |
| |
| def test_matmul_even_chain(self): |
| def f(x, ws): |
| x = x.cos() |
| for w in ws: |
| x = torch.mm(x, w).cos() |
| return x.sum() |
| |
| x = torch.randn(512, 512, requires_grad=True) |
| ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] |
| |
| def call(): |
| return f(x, ws) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| for budget in range(0, 11): |
| mem, flops = get_mem_and_flops(call, memory_budget=budget / 10) |
| if budget <= 5: |
                # With budget <= 0.5 we can only afford to save `budget` of the
                # 5 matmul outputs; each unsaved matmul costs one extra
                # normalized FLOP to recompute.
| self.assertEqual(mem, budget) |
| self.assertEqual(flops, eager_flops + (5 - budget)) |
| elif budget < 10: |
| # We're only recomputing the `cos` operations |
| self.assertEqual(mem, 5.0) |
| self.assertEqual(flops, eager_flops) |
| elif budget == 10: |
| self.assertEqual(mem, 10.0) |
| self.assertEqual(flops, eager_flops) |
| |
| def test_matmul_uneven_chain(self): |
        # This function is constructed so that for each w we save one
        # intermediate of shape [512, dim * 512] (the matmul output, which is
        # the input to `cos`).
        # In addition, every matmul has the same ratio of compute to "memory
        # saved", so this test is essentially exercising our knapsack solver.
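        # For example, with weight shapes [11, 3, 4, 2] (the first config
        # below) and a budget that allows exactly 15 MiB, the exact knapsack
        # solution keeps the 11 MiB and 4 MiB activations and recomputes the
        # rest. The `exact_solves` lists below are budgets that are achievable
        # exactly, so the test can assert equality on the resulting memory.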
| |
| def f(x, ws): |
| xs = [torch.mm(x, w).cos() for w in ws] |
| return sum(x.sum() for x in xs) |
| |
| x = torch.randn(512, 512, requires_grad=True) |
| |
| def make_weights(w_shapes): |
| ws = [] |
            for dim in w_shapes:
                ws.append(torch.randn(512, dim * 512, requires_grad=True))
| return ws |
| |
| def make_weights_chain(w_shapes): |
| ws = [] |
| for idx, _ in enumerate(w_shapes): |
| old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512 |
| new_dim = w_shapes[idx] * 512 |
| ws.append(torch.randn(old_dim, new_dim, requires_grad=True)) |
| return ws |
| |
| weight_configs = [ |
| ( |
| [11, 3, 4, 2], |
| [ |
| 18, # 11 + 4 + 3 |
| 17, # 11 + 4 + 2 |
| 16, # 11 + 3 + 2 |
| 15, # 11 + 4 |
| 14, # 11 + 3 |
| 13, # 11 + 2 |
| 11, # 11 + 2 |
| 7, # 4 + 3 |
| 6, # 4 + 2 |
| 5, # 3 + 2 |
| ], |
| ), |
| ( |
| [3, 5, 11, 17, 14], |
| [ |
| 42, # 17 + 14 + 9 |
| 30, # 11 + 15 + 5 |
| 19, # 11 + 5 + 3 |
| 8, # 5 + 3 |
| 3, # 3 |
| ], |
| ), |
| ] |
| random.seed(0) |
| random_arr = [random.randint(0, 50) for _ in range(10)] |
| exact_sums = [] |
| for i in range(10): |
| random.shuffle(random_arr) |
| exact_sums.append(sum(random_arr[:i])) |
| weight_configs.append((random_arr, exact_sums)) |
| |
| for weight_shapes, exact_solves in weight_configs: |
| ws = make_weights(weight_shapes) |
| |
| def call(): |
| return f(x, ws) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
            total_mem = sum(weight_shapes)
            self.assertEqual(eager_mem, total_mem)
| for mem_achieved in exact_solves: |
| mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem) |
| self.assertEqual(mem, mem_achieved) |
| |
    # Needs Triton; CUDA is already required by this whole test file.
| @unittest.skipIf(not has_triton(), "test needs triton") |
| def test_custom_triton_kernel(self): |
| @triton.jit |
| def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr): |
| pid = tl.program_id(0) |
| block = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE |
| msk = block < sz |
| inp = tl.load(inp_ptr + block, mask=msk) |
| relu = tl.where(inp < 0, 0, inp) |
| tl.store(out_ptr + block, relu, mask=msk) |
| |
| @torch._library.triton_op("testac::triton_relu", mutates_args=()) |
| def triton_relu(x: torch.Tensor) -> torch.Tensor: |
| y = torch.empty_like(x) |
| sz = y.numel() |
| BLOCK_SIZE = 256 |
| grid = (triton.cdiv(sz, BLOCK_SIZE),) |
| torch._library.capture_triton(relu_kernel_)[grid](x, y, sz, BLOCK_SIZE) |
| return y |
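        # Once registered via triton_op, the forward is callable like any
        # custom op, e.g. (an illustrative call, not part of the test):
        #   torch.ops.testac.triton_relu(torch.randn(1024, device="cuda"))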
| |
| @torch._library.triton_op("testac::triton_relu_backward", mutates_args=()) |
| def triton_relu_backward(grad_out: torch.Tensor) -> torch.Tensor: |
| grad_x = torch.empty_like(grad_out) |
| sz = grad_out.numel() |
| BLOCK_SIZE = 256 |
| grid = (triton.cdiv(sz, BLOCK_SIZE),) |
            # Note: this is not the real ReLU backward (it just applies the
            # relu kernel to grad_out), but gradient values are not checked in
            # this test.
| torch._library.capture_triton(relu_kernel_)[grid]( |
| grad_out, grad_x, sz, BLOCK_SIZE |
| ) |
| return grad_x |
| |
| def _triton_relu_backward(ctx, grad_out: torch.Tensor) -> torch.Tensor: |
| return triton_relu_backward(grad_out) |
| |
| def _triton_relu_setup_context(ctx, inputs, output): |
| pass |
| |
| triton_relu.register_autograd( |
| _triton_relu_backward, |
| setup_context=_triton_relu_setup_context, |
| ) |
| |
| @register_flop_formula( |
| [torch.ops.testac.triton_relu, torch.ops.testac.triton_relu_backward] |
| ) |
| def triton_relu_flops(inp_shape, *args, **kwargs): |
| return prod(inp_shape) |
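        # An illustrative consequence (not itself asserted below): with this
        # formula registered, FlopCounterMode attributes prod(shape) FLOPs to
        # each call of the custom ops, e.g. 512 * 512 = 262144 for a (512, 512)
        # input, presumably giving the partitioner a small but nonzero
        # recompute cost for the Triton relu.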
| |
| def f(x, ws): |
| x = torch.ops.testac.triton_relu(x) |
| for w in ws: |
| x = torch.ops.testac.triton_relu(torch.mm(x, w)) |
| return x.sum() |
| |
| x = torch.randn(512, 512, requires_grad=True, device="cuda") |
| ws = [ |
| torch.randn(512, 512, requires_grad=True, device="cuda") for _ in range(5) |
| ] |
| |
| def call(): |
| return f(x, ws) |
| |
| expected = call() |
| for budget in range(0, 11): |
| memory_budget = budget / 10 |
| torch._dynamo.reset() |
| with config.patch(activation_memory_budget=memory_budget): |
                f_compile = torch.compile(
                    call, backend="aot_eager_decomp_partition"
                )
                self.assertEqual(expected, f_compile())
| |
| def test_prioritize_cheaper_matmul(self): |
| def f(xs, ws): |
| xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] |
| return sum(x.sum() for x in xs) |
| |
| x1, w1 = create_pair(1, 4) |
| x2, w2 = create_pair(2, 2) |
| |
| def call(): |
| return f([x1, x2], [w1, w2]) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| self.assertEqual(eager_mem, 8) |
| self.assertEqual(eager_flops, 24) |
| comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5) |
| self.assertEqual(comp_mem, 4) |
| # We are recomputing x1 @ w1 here! |
| self.assertEqual(comp_flops, eager_flops + 4) |
| |
| @config.patch(activation_memory_budget_runtime_estimator="profile") |
| def test_profile(self): |
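        # Assumption (based on the config name): the "profile" estimator times
        # candidate recomputations instead of using the analytic flop-based
        # estimate used elsewhere in this file. The saved-memory accounting is
        # the same either way, so the expected numbers below match the
        # flop-based case in test_matmul_even_chain.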
| def f(x, ws): |
| x = x.cos() |
| for w in ws: |
| x = torch.mm(x, w).cos() |
| return x.sum() |
| |
| x = torch.randn(512, 512, requires_grad=True) |
| ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)] |
| |
| def call(): |
| return f(x, ws) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| mem, flops = get_mem_and_flops(call, memory_budget=0.2) |
        # With a 0.2 budget we can only save 2 of the 5 matmul outputs; the
        # other 3 matmuls are recomputed.
| self.assertEqual(mem, 2) |
| self.assertEqual(flops, eager_flops + 3) |
| |
| def test_prioritize_cheaper_matmul2(self): |
| def f(xs, ws): |
| xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)] |
| return sum(x.sum() for x in xs) |
| |
| data = [(4, 4), (6, 2), (2, 6)] |
| xs, ws = zip(*[create_pair(a, b) for a, b in data]) |
| |
| def call(): |
| return f(xs, ws) |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| self.assertEqual(eager_mem, 40) |
| self.assertEqual(eager_flops, 320) |
| mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem) |
        # Save the outputs of the first two matmuls (16 + 12 = 28 MiB).
| self.assertEqual(mem, 28) |
        # The third matmul (the cheapest one) is recomputed.
| self.assertEqual(flops - eager_flops, 2 * 2 * 6) |
| mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem) |
        # Save only the second matmul output (12 MiB). Even though saving the
        # first (16 MiB) would get closer to the 16 MiB limit, the second
        # matmul costs more FLOPs to recompute (6 * 6 * 2 = 72 vs. 4 * 4 * 4 = 64).
| self.assertEqual(mem, 12) |
| self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4) |
| |
| def test_attention_vs_linear(self): |
| def f(x, w): |
| orig_shape = x.shape |
| x = x.reshape(1, 1, x.shape[0], x.shape[1]) |
            # Not a realistic attention setup (single head, q = k = v), but
            # only the FLOP accounting matters here.
| x = torch.nn.functional.scaled_dot_product_attention( |
| x, x, x, is_causal=False |
| ).reshape(*orig_shape) |
| x = torch.mm(x, w) |
| x = x.cos() |
| return x.sum() |
| |
| def try_seq_length(S, D, expected_recompute): |
| x = torch.randn(S * 512, D * 512, requires_grad=True) |
| w = torch.randn(D * 512, D * 512, requires_grad=True) |
| |
| def call(): |
| return f(x, w) |
| |
| with FlopCounterMode(display=False) as mode: |
| call() |
| mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm] |
| attn_flops = mode.get_total_flops() - mm_flops |
| mm_flops /= 512**3 * 2 |
| attn_flops /= 512**3 * 2 |
| |
| eager_mem, eager_flops = get_mem_and_flops(call) |
| self.assertEqual(eager_mem, S * D * 2) |
| |
| mem, flops = get_mem_and_flops( |
| call, memory_budget=0.6 |
| ) # Force it to recompute one of mm or attn |
| self.assertEqual(mem, S * D) |
| if expected_recompute == "attn": |
| expected_flops = attn_flops |
| else: |
| expected_flops = mm_flops |
| self.assertEqual(flops - eager_flops, expected_flops) |
| |
        # The general idea behind this test: if 2 * sequence length > D, then
        # attention is more expensive than the linear, so the partitioner
        # should choose to recompute the cheaper of the two.
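        # Rough accounting (with N = S * 512 tokens and hidden dim K = D * 512):
        # the attention forward is ~4 * N^2 * K FLOPs (two matmuls of shapes
        # (N, K) x (K, N) and (N, N) x (N, K)), while the linear is
        # 2 * N * K^2 FLOPs, so attention dominates when 2 * N > K,
        # i.e. 2 * S > D.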
| try_seq_length(1, 1, "mm") |
| try_seq_length(1, 3, "attn") |
| try_seq_length(2, 2, "mm") |
| try_seq_length(2, 1, "mm") |
| try_seq_length(2, 5, "attn") |
| try_seq_length(4, 7, "mm") |
| try_seq_length(4, 9, "attn") |
| |
| |
| if __name__ == "__main__": |
    # These tests use the CUDA caching allocator's stats to verify activation
    # memory, so they require CUDA (and are skipped under ROCm).
| if HAS_CUDA and not TEST_WITH_ROCM: |
| run_tests() |