# Owner(s): ["oncall: pt2"]
import random
import unittest
from math import prod
import torch
import torch._functorch.config as config
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
if has_triton():
# note: if we only import triton in the test, the test fails:
# def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):
# NameError('tl is not defined')
import triton
import triton.language as tl
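# Note: the memory budget itself is applied via
# config.patch(activation_memory_budget=...) inside get_mem_and_flops, so the
# memory_budget argument of this wrapper is currently unused.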
def compile_with_ac(f, memory_budget):
return torch.compile(f, backend="aot_eager_decomp_partition")
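# Measures the activation memory (in MiB) held live between forward and
# backward by diffing the allocator's "requested_bytes" counter. The first
# forward/backward pair is a warmup so any compilation happens before we
# measure.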
def get_act_mem(f):
out = f()
out.backward()
start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
out = f()
cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
act_mem = (cur_mem - start_mem) / (1024 * 1024)
out.backward()
return act_mem
def get_bw_flops(f):
# Normalized so that a 512 square matmul returns 1
f().backward()
out = f()
with FlopCounterMode(display=False) as mode:
out.backward()
return mode.get_total_flops() / (512**3 * 2)
def create_pair(B_I, O):
# The saved mm output is B_I * O MiB and the forward matmul costs
# B_I * B_I * O normalized FLOPs, i.e. an arithmetic intensity of B_I.
x = torch.randn(B_I * 512, B_I * 512, requires_grad=True)
w = torch.randn(B_I * 512, O * 512, requires_grad=True)
return x, w
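# e.g. create_pair(2, 3): x is [1024, 1024] and w is [1024, 1536], so the mm
# output is 2 * 3 = 6 MiB and the forward matmul is 2 * 2 * 3 = 12 normalized
# FLOPs.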
def get_mem_and_flops(f, memory_budget=None):
# Returns megabytes rounded to 1 decimal point and FLOPs
# Note that each value of size (512, 512, torch.float32) is 1 MiB
torch._dynamo.reset()
with config.patch(activation_memory_budget=memory_budget):
if memory_budget is not None:
f = torch.compile(f, backend="aot_eager_decomp_partition")
# We round this to nearest 10th of a megabyte.
return round(get_act_mem(f), 1), get_bw_flops(f)
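# activation_memory_budget is (roughly) the fraction of the activation memory
# the default partitioner would save: 1.0 keeps the default behavior, while
# smaller values ask the partitioner to recompute more, down to 0.0 which
# saves nothing beyond the inputs. This is the behavior these tests rely on;
# see torch._functorch.config for the authoritative description.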
class MemoryBudgetTest(TestCase):
def setUp(self):
super().setUp()
torch.set_default_device("cuda")
def test_rematerializes_cheap(self):
def f(x, w):
x = x.cos()
x = torch.mm(x, w)
return x.sum()
x = torch.randn(512, 512, requires_grad=True)
w = torch.randn(512, 512, requires_grad=True)
def call():
return f(x, w)
eager_mem, eager_flops = get_mem_and_flops(call)
self.assertEqual(eager_mem, 1.0)
mem_10, flops_10 = get_mem_and_flops(call, memory_budget=1.0)
# With the full budget nothing is recomputed, since recomputing `.cos()` is not free.
self.assertEqual(mem_10, 1.0)
self.assertEqual(eager_flops, flops_10)
mem_5, flops_5 = get_mem_and_flops(call, memory_budget=0.5)
# We can just recompute `x.cos()` here to only depend on the inputs
self.assertEqual(mem_5, 0.0)
self.assertEqual(flops_5, eager_flops)
def test_matmul_even_chain(self):
def f(x, ws):
x = x.cos()
for w in ws:
x = torch.mm(x, w).cos()
return x.sum()
x = torch.randn(512, 512, requires_grad=True)
ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]
def call():
return f(x, ws)
eager_mem, eager_flops = get_mem_and_flops(call)
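# For budgets below 1.0 the pointwise cos outputs are recomputed (they add no
# matmul FLOPs); each saved matmul output is 1 MiB and costs 1 normalized FLOP
# to recompute, so below 0.6 every extra 0.1 of budget trades one recomputed
# matmul for 1 MiB of saved memory.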
for budget in range(0, 11):
mem, flops = get_mem_and_flops(call, memory_budget=budget / 10)
if budget <= 5:
# Save `budget` of the matmul outputs and recompute the remaining (5 - budget)
self.assertEqual(mem, budget)
self.assertEqual(flops, eager_flops + (5 - budget))
elif budget < 10:
# We're only recomputing the `cos` operations
self.assertEqual(mem, 5.0)
self.assertEqual(flops, eager_flops)
elif budget == 10:
self.assertEqual(mem, 10.0)
self.assertEqual(flops, eager_flops)
def test_matmul_uneven_chain(self):
# This function is constructed so that for each w we save one activation of
# size [512, dim * 512] (where dim is that weight's width multiplier).
# In addition, every matmul has the same ratio of compute to "memory
# saved", so this test essentially exercises our knapsack solver.
def f(x, ws):
xs = [torch.mm(x, w).cos() for w in ws]
return sum(x.sum() for x in xs)
x = torch.randn(512, 512, requires_grad=True)
def make_weights(w_shapes):
ws = []
for dim in w_shapes:
ws.append(torch.randn(512, dim * 512, requires_grad=True))
return ws
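# Alternative helper that chains the weights so each matmul feeds the next;
# it is not used by the configs below.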
def make_weights_chain(w_shapes):
ws = []
for idx, _ in enumerate(w_shapes):
old_dim = 512 if idx == 0 else w_shapes[idx - 1] * 512
new_dim = w_shapes[idx] * 512
ws.append(torch.randn(old_dim, new_dim, requires_grad=True))
return ws
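# Each config is (weight widths, exactly achievable memory totals): widths are
# multiples of 512, and each listed total (in MiB) corresponds to saving some
# subset of the [512, width * 512] matmul outputs, so the knapsack solver is
# expected to hit it exactly.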
weight_configs = [
(
[11, 3, 4, 2],
[
18, # 11 + 4 + 3
17, # 11 + 4 + 2
16, # 11 + 3 + 2
15, # 11 + 4
14, # 11 + 3
13, # 11 + 2
11, # 11
7, # 4 + 3
6, # 4 + 2
5, # 3 + 2
],
),
(
[3, 5, 11, 17, 14],
[
42, # 17 + 14 + 11
30, # 11 + 14 + 5
19, # 11 + 5 + 3
8, # 5 + 3
3, # 3
],
),
]
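# Also test a randomized config: every prefix sum of the shuffled array is a
# subset sum of the weights, so each should be exactly achievable as well.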
random.seed(0)
random_arr = [random.randint(0, 50) for _ in range(10)]
exact_sums = []
for i in range(10):
random.shuffle(random_arr)
exact_sums.append(sum(random_arr[:i]))
weight_configs.append((random_arr, exact_sums))
for weight_shapes, exact_solves in weight_configs:
ws = make_weights(weight_shapes)
def call():
return f(x, ws)
eager_mem, eager_flops = get_mem_and_flops(call)
total_mem = sum(weight_shapes)
self.assertEqual(eager_mem, sum(weight_shapes))
for mem_achieved in exact_solves:
mem, _ = get_mem_and_flops(call, memory_budget=mem_achieved / total_mem)
self.assertEqual(mem, mem_achieved)
# Needs Triton; CUDA is already required by every test in this file.
@unittest.skipIf(not has_triton(), "test needs triton")
def test_custom_triton_kernel(self):
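# A minimal elementwise ReLU kernel; each program instance handles one
# contiguous BLOCK_SIZE chunk of the input.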
@triton.jit
def relu_kernel_(inp_ptr, out_ptr, sz, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
block = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
msk = block < sz
inp = tl.load(inp_ptr + block, mask=msk)
relu = tl.where(inp < 0, 0, inp)
tl.store(out_ptr + block, relu, mask=msk)
@torch._library.triton_op("testac::triton_relu", mutates_args=())
def triton_relu(x: torch.Tensor) -> torch.Tensor:
y = torch.empty_like(x)
sz = y.numel()
BLOCK_SIZE = 256
grid = (triton.cdiv(sz, BLOCK_SIZE),)
torch._library.capture_triton(relu_kernel_)[grid](x, y, sz, BLOCK_SIZE)
return y
@torch._library.triton_op("testac::triton_relu_backward", mutates_args=())
def triton_relu_backward(grad_out: torch.Tensor) -> torch.Tensor:
grad_x = torch.empty_like(grad_out)
sz = grad_out.numel()
BLOCK_SIZE = 256
grid = (triton.cdiv(sz, BLOCK_SIZE),)
# Intentionally not the real ReLU backward (it just reuses the forward
# kernel); this test only checks the forward output.
torch._library.capture_triton(relu_kernel_)[grid](
grad_out, grad_x, sz, BLOCK_SIZE
)
return grad_x
def _triton_relu_backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
return triton_relu_backward(grad_out)
def _triton_relu_setup_context(ctx, inputs, output):
pass
triton_relu.register_autograd(
_triton_relu_backward,
setup_context=_triton_relu_setup_context,
)
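# Register a FLOP formula (one FLOP per element) for the custom ops so that
# FlopCounterMode, and presumably the partitioner's flop-based runtime
# estimate, assigns them a nonzero cost.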
@register_flop_formula(
[torch.ops.testac.triton_relu, torch.ops.testac.triton_relu_backward]
)
def triton_relu_flops(inp_shape, *args, **kwargs):
return prod(inp_shape)
def f(x, ws):
x = torch.ops.testac.triton_relu(x)
for w in ws:
x = torch.ops.testac.triton_relu(torch.mm(x, w))
return x.sum()
x = torch.randn(512, 512, requires_grad=True, device="cuda")
ws = [
torch.randn(512, 512, requires_grad=True, device="cuda") for _ in range(5)
]
def call():
return f(x, ws)
expected = call()
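# The output should match eager for every budget; only the recompute
# decisions are allowed to change.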
for budget in range(0, 11):
memory_budget = budget / 10
torch._dynamo.reset()
with config.patch(activation_memory_budget=memory_budget):
f_compile = torch.compile(call, backend="aot_eager_decomp_partition")
self.assertEqual(expected, f_compile())
def test_prioritize_cheaper_matmul(self):
def f(xs, ws):
xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
return sum(x.sum() for x in xs)
x1, w1 = create_pair(1, 4)
x2, w2 = create_pair(2, 2)
def call():
return f([x1, x2], [w1, w2])
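# Pair 1 saves a 1 * 4 = 4 MiB matmul output whose backward costs
# 2 * 1 * 1 * 4 = 8 FLOPs; pair 2 saves 2 * 2 = 4 MiB with a backward of
# 2 * 2 * 2 * 2 = 16 FLOPs, which gives the totals below.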
eager_mem, eager_flops = get_mem_and_flops(call)
self.assertEqual(eager_mem, 8)
self.assertEqual(eager_flops, 24)
comp_mem, comp_flops = get_mem_and_flops(call, memory_budget=0.5)
self.assertEqual(comp_mem, 4)
# We are recomputing x1 @ w1 here!
self.assertEqual(comp_flops, eager_flops + 4)
@config.patch(activation_memory_budget_runtime_estimator="profile")
def test_profile(self):
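# Same matmul chain as test_matmul_even_chain, but with runtimes estimated by
# profiling rather than the default flop-based estimator; the expectations
# below assume both estimators rank these ops the same way.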
def f(x, ws):
x = x.cos()
for w in ws:
x = torch.mm(x, w).cos()
return x.sum()
x = torch.randn(512, 512, requires_grad=True)
ws = [torch.randn(512, 512, requires_grad=True) for _ in range(5)]
def call():
return f(x, ws)
eager_mem, eager_flops = get_mem_and_flops(call)
mem, flops = get_mem_and_flops(call, memory_budget=0.2)
# Saves 2 of the 5 matmul outputs and recomputes the other 3
self.assertEqual(mem, 2)
self.assertEqual(flops, eager_flops + 3)
def test_prioritize_cheaper_matmul2(self):
def f(xs, ws):
xs = [torch.mm(x, w).cos() for x, w in zip(xs, ws)]
return sum(x.sum() for x in xs)
data = [(4, 4), (6, 2), (2, 6)]
xs, ws = zip(*[create_pair(a, b) for a, b in data])
def call():
return f(xs, ws)
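# Saved matmul outputs: 4 * 4 = 16, 6 * 2 = 12, 2 * 6 = 12 MiB (40 MiB total);
# forward recompute costs: 4 * 4 * 4 = 64, 6 * 6 * 2 = 72, 2 * 2 * 6 = 24
# normalized FLOPs.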
eager_mem, eager_flops = get_mem_and_flops(call)
self.assertEqual(eager_mem, 40)
self.assertEqual(eager_flops, 320)
mem, flops = get_mem_and_flops(call, memory_budget=28 / eager_mem)
# Save w1 and w2
self.assertEqual(mem, 28)
# We're recomputing w3 (the cheap one!)
self.assertEqual(flops - eager_flops, 2 * 2 * 6)
mem, flops = get_mem_and_flops(call, memory_budget=16 / eager_mem)
# Save w2. Note that even though saving w1 gets us closer to our memory
# limit, w2 is actually *more* FLOPs than w1!
self.assertEqual(mem, 12)
self.assertEqual(flops - eager_flops, 2 * 2 * 6 + 4 * 4 * 4)
def test_attention_vs_linear(self):
def f(x, w):
orig_shape = x.shape
x = x.reshape(1, 1, x.shape[0], x.shape[1])
# Not a realistic attention call (q = k = v = x), but it is enough to
# exercise SDPA here.
x = torch.nn.functional.scaled_dot_product_attention(
x, x, x, is_causal=False
).reshape(*orig_shape)
x = torch.mm(x, w)
x = x.cos()
return x.sum()
def try_seq_length(S, D, expected_recompute):
x = torch.randn(S * 512, D * 512, requires_grad=True)
w = torch.randn(D * 512, D * 512, requires_grad=True)
def call():
return f(x, w)
with FlopCounterMode(display=False) as mode:
call()
mm_flops = mode.get_flop_counts()["Global"][torch.ops.aten.mm]
attn_flops = mode.get_total_flops() - mm_flops
mm_flops /= 512**3 * 2
attn_flops /= 512**3 * 2
eager_mem, eager_flops = get_mem_and_flops(call)
self.assertEqual(eager_mem, S * D * 2)
mem, flops = get_mem_and_flops(
call, memory_budget=0.6
) # Force it to recompute one of mm or attn
self.assertEqual(mem, S * D)
if expected_recompute == "attn":
expected_flops = attn_flops
else:
expected_flops = mm_flops
self.assertEqual(flops - eager_flops, expected_flops)
# The general idea behind this test: if sequence length * 2 > D, then
# attention is more expensive than the linear, so the linear is the one
# that gets recomputed.
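# Rough normalized forward costs: linear = S * D^2 and attention ~= 2 * S^2 * D
# (two S x S x D matmuls), so the linear is the cheaper recompute exactly when
# 2 * S > D.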
try_seq_length(1, 1, "mm")
try_seq_length(1, 3, "attn")
try_seq_length(2, 2, "mm")
try_seq_length(2, 1, "mm")
try_seq_length(2, 5, "attn")
try_seq_length(4, 7, "mm")
try_seq_length(4, 9, "attn")
if __name__ == "__main__":
# These tests use the CUDA caching allocator's counters to verify memory
# usage, so they require CUDA (and are skipped under ROCm).
if HAS_CUDA and not TEST_WITH_ROCM:
run_tests()