# Owner(s): ["module: inductor"]
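"""
Tests for torch._inductor's CUDA graph trees: recording/warmup/execution state
transitions, liveness tracking of graph outputs, checkpointing of the private
memory pool, and interaction with dynamo graph breaks and autograd.
"""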
import contextlib
import functools
import gc
import importlib
import sys
import unittest
import warnings
import torch
import torch._dynamo
import torch.nn as nn
from torch._inductor import config
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
IS_CI,
IS_LINUX,
IS_WINDOWS,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TestCase as TorchTestCase,
)
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
requires_multigpu = functools.partial(
unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
)
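# Identity of a tensor's underlying storage; used throughout to check whether
# two tensors alias the same allocation.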
def cdata(t):
return t.untyped_storage()._cdata
class TestCase(TorchTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"debug": True,
"cpp.min_chunk_size": 1,
"triton.autotune_pointwise": False, # too slow
"implicit_fallbacks": False,
}
)
)
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
def setUp(self):
torch._dynamo.reset()
super().setUp()
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
if HAS_CUDA and not TEST_WITH_ASAN:
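    # Segments owned by a cudagraph private pool carry a non-default
    # segment_pool_id, which distinguishes them in the memory snapshot.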
def get_all_cudagraph_segments():
segments = torch.cuda.memory_snapshot()
        return [
            segment for segment in segments if segment["segment_pool_id"] != (0, 0)
        ]
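    # Walk each cudagraph segment and collect the start address of every block
    # that is currently active, so tests can assert on live allocation counts.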
def all_live_blocks():
blocks_addrs = []
for segment in get_all_cudagraph_segments():
addr = segment["address"]
for block in segment["blocks"]:
if block["state"] == "active_allocated":
blocks_addrs.append(addr)
addr += block["size"]
return blocks_addrs
def all_live_block_count():
return len(all_live_blocks())
class CudaGraphTreeTests(TestCase):
def setUp(self):
super().setUp()
self.graph_stack = contextlib.ExitStack()
self.graph_stack.enter_context(
config.patch(
{
"triton.cudagraphs": True,
"triton.cudagraph_trees": True,
"triton.fast_path_cudagraph_asserts": True, # too slow
"triton.slow_path_cudagraph_asserts": True,
}
)
)
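            # Allocate a zero-element tensor purely to discover the active CUDA
            # device index for this test process.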
self.device_idx = torch.rand([0], device="cuda").device.index
warnings.filterwarnings("ignore")
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
gc.collect()
torch.cuda.empty_cache()
self.graph_stack.close()
self.assertIsNone(self.get_manager())
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(len(get_all_cudagraph_segments()), 0)
warnings.resetwarnings()
        def get_manager(self, device_index=None):
            return torch._inductor.cudagraph_trees.get_container(
                self.device_idx if device_index is None else device_index
            ).tree_manager
def get_roots(self):
return self.get_manager().get_roots()
def curr_node(self):
return self.get_manager().current_node
def get_root_children(self):
return [root.num_descendants() for root in self.get_roots()]
def cudagraphify_impl(
self, *args, is_inference=True, is_backward=False, **kwargs
):
return tree_cudagraphify_impl(
*args,
**kwargs,
device_index=self.device_idx,
is_inference=is_inference,
is_backward=is_backward,
)
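        # "Run twice": the first call warms up / records, the second replays;
        # return the second result so tests check the replayed outputs.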
@staticmethod
def run_twc(fn, *args, **kwargs):
fn(*args, **kwargs)
return fn(*args, **kwargs)
def num_checkpoints(self):
return self.get_manager().debug_checkpointing_counter
def test_run_simple(self):
def foo(x):
return x * x * x
foo_opt = torch._dynamo.optimize()(foo)
ones = torch.ones([4, 4], device="cuda")
zeros = torch.zeros([5, 5], device="cuda")
self.run_twc(foo_opt, ones)
self.run_twc(foo_opt, zeros)
self.assertEqual(self.get_root_children(), [0, 0])
def check_rng(self):
@torch.compile(mode="reduce-overhead")
def foo():
return torch.rand([20])
torch.manual_seed(0)
out = foo()
out2 = foo()
out3 = foo()
torch.manual_seed(0)
self.assertEqual(out, foo())
self.assertEqual(out2, foo())
self.assertEqual(out3, foo())
@torch._inductor.config.patch("fallback_random", True)
def test_rng_trees(self):
self.check_rng()
@torch._inductor.config.patch("triton.cudagraph_trees", False)
@torch._inductor.config.patch("fallback_random", True)
def test_rng_non_trees(self):
self.check_rng()
def test_function_compiled_multiple_times(self):
def foo(x):
y = foo2(x)
y2 = foo2(y)
return y + y2
def foo2(x):
torch._dynamo.graph_break()
return x * x * x
foo_opt = torch._dynamo.optimize()(foo)
ones = torch.ones([4, 4], device="cuda")
foo(ones)
foo_opt(ones)
foo_opt(ones)
self.assertEqual(foo_opt(ones), foo(ones))
# paths
children = self.get_root_children()
# one root with two children
self.assertEqual(children, [2])
def test_end_recording_early(self):
def foo(x):
y = x * x * x
torch._dynamo.graph_break()
z = x + y
return z
@torch._dynamo.optimize()
def foo2(x):
return x + 4
foo_opt = torch._dynamo.optimize()(foo)
for _ in range(3):
out = foo_opt(torch.ones([4, 4], device="cuda"))
del out
            # when I tried inducing separate recordings via a graph break,
            # the frame kept interfering by keeping outputs alive.
            # this isn't great, but it simulates the logic.
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out = foo2(torch.ones([4, 4], device="cuda"))
del out
foo_opt(torch.ones([4, 4], device="cuda"))
            # Two separate traces - one has a child, one doesn't
self.assertEqual(self.get_root_children(), [1, 0])
def test_execution_into_recording(self):
def foo(x):
y = x + x
if y.sum() > 0:
return y + 10
else:
return y - 10
foo_opt = torch._dynamo.optimize()(foo)
inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
self.assertEqual(foo_opt(inp), foo(inp))
self.assertEqual(foo_opt(inp), foo(inp))
inp.add_(1)
out_eager = foo(inp)
out_warmup = foo_opt(inp)
self.assertEqual(out_warmup, out_eager)
            # the warmup output should have its storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
out_live = foo_opt(inp)
self.assertEqual(out_live, out_eager)
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
# warmup should have been freed
del out_warmup
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
del out_live
self.assertEqual(all_live_block_count(), 0)
out = foo_opt(inp)
self.assertEqual(foo(inp), out)
# should be in execution mode
self.assertEqual(all_live_block_count(), 0)
def test_forward_with_skipped_cudagraphed_backward(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x * x * x
for _ in range(3):
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
def complex_memory_overlap_new(t):
return True
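                # Patch the memory-overlap check so every tensor is reported as
                # overlapping; cudagraphs should then skip recording the backward.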
                prev = torch._inductor.compile_fx.complex_memory_overlap
                try:
                    torch._inductor.compile_fx.complex_memory_overlap = (
                        complex_memory_overlap_new
                    )
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
out.backward(back_inp)
finally:
torch._inductor.compile_fx.complex_memory_overlap = prev
# we should not have cudagraph'd the backwards
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 1)
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
def test_forward_backward_not_called(self):
@torch.compile(mode="reduce-overhead")
def foo(x, y):
x_out = x * x * x
torch._dynamo.graph_break()
y_out = y * y * y
return x_out, y_out
for _ in range(3):
inps = [
torch.rand([20, 20], requires_grad=True, device="cuda")
for _ in range(2)
]
x_out, y_out = foo(inps[0], inps[1])
x_out.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
# we should not have cudagraph'd the y backward
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 3)
def test_accumulate_multiple_recordings(self):
def foo(x):
y = x + x + x
torch._dynamo.graph_break()
if y.sum() <= 0:
return y
else:
return y * 10
foo_opt = torch._dynamo.optimize()(foo)
# two separate compilations & recordings
out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
# out1 gets manually freed
out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
self.assertEqual(all_live_block_count(), 1)
del out1, out2
self.assertEqual(all_live_block_count(), 1)
del out3
gc.collect()
self.assertEqual(all_live_block_count(), 0)
def test_live_outputs_multiple_graphs(self):
def foo(x):
x = x + x + x
y = x + 1
torch._dynamo.graph_break()
z = x * x
if z.sum() > 0:
return y + 1
else:
return y
foo_opt = torch._dynamo.optimize()(foo)
self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
self.assertEqual(self.num_checkpoints(), 0)
out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
del out
self.assertEqual(all_live_block_count(), 0)
            # we need to checkpoint the allocator state to warm up y + 1,
            # and then again to record it
self.assertEqual(self.num_checkpoints(), 2)
def test_expanded_inputs(self):
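            # expand() yields a stride-0 view, so the graph sees a
            # non-contiguous input that shares storage across rows.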
x = torch.rand(1, 512, device="cuda").expand(4, 512)
def foo(x):
return x + 4 + torch.ones([4, 512], device="cuda")
foo_opt = torch.compile()(foo)
for _ in range(3):
self.assertEqual(foo_opt(x), foo(x))
self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_dies_between_checkpoint(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
foo_cg(inp_list)
foo_cg([inp])
out1, out2 = foo_cg([inp])
inp = [out1]
del out1, out2
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
self.assertEqual(self.num_checkpoints(), 0)
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
x = foo2_cg(inp)[0]
self.assertEqual(self.num_checkpoints(), 1)
            # out2 died between the previous recording and the new one, so it
            # needs to be manually deallocated after the checkpoint
self.assertEqual(all_live_block_count(), 1)
del x
self.assertEqual(all_live_block_count(), 0)
def test_aliased_storage_single_weakref(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
x = x * 20
x_alias = x[0]
y = x * 10
y_alias = y[0]
torch._dynamo.graph_break()
ind = torch.tensor(4, device="cuda")
x_alias2 = x[ind:]
y_alias2 = y[ind:]
return x, x_alias, x_alias2, y_alias, y_alias2
for _ in range(4):
outs = foo(torch.rand([20, 20], device="cuda"))
ptr_to_ref = {
out.untyped_storage().data_ptr(): out.untyped_storage()._cdata
for out in outs
}
self.assertEqual(len(ptr_to_ref), 2)
for out in outs:
self.assertEqual(
ptr_to_ref[out.untyped_storage().data_ptr()],
out.untyped_storage()._cdata,
)
del outs
del out
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_aliasing_static_ref(self):
class Mod(torch.nn.Linear):
def forward(self, x):
return self.weight.T @ x, self.weight.T, self.weight[0:4]
m = Mod(10, 10).cuda()
@torch.compile(mode="reduce-overhead")
def foo(mod, x):
return mod(x)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[2:]
x = torch.rand([10, 10], device="cuda", requires_grad=True)
param_c = cdata(m.weight)
for _ in range(3):
out1, alias_1, alias_2 = foo(m, x)
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
out2 = foo2(out1)
out2.sum().backward()
self.assertEqual(cdata(out1), cdata(out2))
node = self.curr_node()
first_node = next(node._path_from_root)
self.assertFalse(first_node.unaliased_in_all_paths[0])
self.assertTrue(first_node.cached_tensor_outputs[0] is None)
def test_checkpointing_resets_persistent_refs(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x @ x
def inp():
return torch.rand([20, 20], device="cuda", requires_grad=False)
for _ in range(3):
foo(inp())
self.assertEqual(self.num_checkpoints(), 0)
out = foo(inp())
out_id = id(out)
del out
self.assertEqual(id(foo(inp())), out_id)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[0], x @ x
for i in range(2):
out = foo(inp())
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out_alias, out2 = foo2(out)
del out_alias
self.assertEqual(all_live_block_count(), 2)
del out
self.assertEqual(all_live_block_count(), 1)
del out2
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(self.num_checkpoints(), i + 1)
new_out = foo(inp())
curr_node = self.curr_node()
self.assertFalse(curr_node.unaliased_in_all_paths[0])
self.assertFalse(out_id == id(new_out))
def test_aliased_static_parameter(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
return (x[0],)
foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
for _ in range(3):
out = foo_cg([inp])[0]
self.assertEqual(cdata(inp), cdata(out))
node = self.curr_node()
self.assertEqual(node.cached_tensor_outputs, [None])
self.assertEqual(node.unaliased_in_all_paths, [False])
def test_unaligned_static_parameter(self):
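            # Slicing off the first element gives the input a storage offset
            # that breaks alignment, so it cannot be treated as a static pointer.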
def gen_inp():
inp = torch.ones([20], device="cuda")
return [inp[1:]]
def foo(args):
x = args[0]
args.clear()
return (x + x,)
foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,))
for _ in range(3):
out = foo_cg(gen_inp())
self.assertEqual(out, foo(gen_inp()))
del out
node = self.curr_node()
self.assertEqual(node.static_input_data_ptrs, [None])
def test_amp_cache_disabled(self):
@torch.compile()
def foo(x):
return x + x
for _ in range(3):
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
# amp cache for cudagraph outputs should be disabled
t2 = torch.rand([4, 4], device="cuda")
with torch.cuda.amp.autocast():
run_once = out @ t2
out.detach().zero_()
run_twice = out @ t2
self.assertNotEqual(run_once, run_twice)
def test_multiple_insert_removal_caching(self):
torch._C._set_cached_tensors_enabled(True)
try:
x = torch.rand([4], device="cuda")
torch._C._add_cached_tensor(x)
self.assertTrue(torch._C._is_cached_tensor(x))
torch._C._add_cached_tensor(x)
torch._C._remove_cached_tensor(x)
self.assertFalse(torch._C._is_cached_tensor(x))
finally:
torch._C._set_cached_tensors_enabled(False)
def test_accumulate_grad(self):
            # cudagraph trees shouldn't interfere with accumulation logic
def compute_grad(grad_output, create_graph):
x = torch.randn(5, 5, requires_grad=True, device="cuda")
@torch.compile()
def foo(x):
return x + 2
y = foo(x)
y.backward(grad_output, retain_graph=True)
x_grad = x.grad
x_grad_clone = x.grad.clone()
y.backward(grad_output, create_graph=create_graph)
return x_grad, x_grad_clone
for _ in range(3):
grad_output = torch.ones(5, 5, device="cuda")
# Accumulate in-place when create_graph is False
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=False)
self.assertEqual(x_grad, x_grad_clone * 2)
                # Accumulate out-of-place when create_graph is True
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=True)
self.assertEqual(x_grad, x_grad_clone)
def test_frozen_fn(self):
@torch.compile()
def foo(x):
return x @ x
for _ in range(3):
out = foo(torch.rand([10, 10], device="cuda"))
self.assertTrue(self.get_manager().new_graph_id().id == 1)
frozen = torch._dynamo.run(foo)
for _ in range(3):
out = frozen(torch.rand([10, 10], device="cuda"))
            # didn't do additional recordings
self.assertTrue(self.get_manager().new_graph_id().id == 2)
def test_output_alias(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
out = x + x
return (x, x[0])
foo_cg = self.cudagraphify_impl(foo, [inp], ())
for _ in range(3):
out_1, out_2 = foo_cg([inp])
self.assertEqual(cdata(out_1), cdata(out_2))
del out_1, out_2
self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
x = args[0]
args.clear()
y = x + 2
return x + 1, y, y[0]
inp = torch.rand([4, 4], device="cuda")
foo_cg = self.cudagraphify_impl(foo, [inp], ())
foo_cg([inp])
foo_cg([inp])
out1, out2, out3 = foo_cg([inp])
inp = [out1]
del out1, out2, out3
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
self.assertEqual(self.num_checkpoints(), 0)
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
x = foo2_cg(inp)[0]
self.assertEqual(self.num_checkpoints(), 1)
            # out2 and out3 died between the previous recording and the new
            # one, so they need to be manually deallocated after the checkpoint
self.assertEqual(all_live_block_count(), 1)
del x
self.assertEqual(all_live_block_count(), 0)
@unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
@torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
def test_workspace_allocation_error(self):
torch._C._cuda_clearCublasWorkspaces()
prev = torch._inductor.cudagraph_trees.clear_cublas_manager
try:
torch._inductor.cudagraph_trees.clear_cublas_manager = (
contextlib.nullcontext
)
@torch.compile()
def foo(x, y):
return x @ x
inps = [torch.rand([400, 400], device="cuda") for _ in range(2)]
thrown = False
try:
foo(*inps)
except Exception as e:
thrown = True
FileCheck().check("at::cuda::getNewWorkspace").check(
"at::cuda::blas::gemm<float>"
).run(str(e))
self.assertTrue(thrown)
finally:
torch._C._cuda_clearCublasWorkspaces()
torch._inductor.cudagraph_trees.clear_cublas_manager = prev
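                # the failed recording leaves the tree manager in a broken
                # state; drop it so tearDown's liveness checks still pass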
torch._inductor.cudagraph_trees.get_container(
self.device_idx
).tree_manager = None
        def test_persisted_output_liveness(self):
@torch.compile
def foo(x):
return x + x
for _ in range(3):
foo(torch.rand([2, 2], device="cuda"))
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
out = foo(torch.rand([2, 2], device="cuda"))
self.assertTrue(out is node.cached_tensor_outputs[0])
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
out_ref = out[0:]
del out
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
del out_ref
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_no_longer_in_pool(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
x1, x2 = foo_cg(inp_list)
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
inp_list = [x1]
foo2_cg = self.cudagraphify_impl(foo2, inp_list, ())
foo2_cg(inp_list)
del x1, x2
# TODO make configurable
x1, x2 = foo_cg([inp])
self.assertEqual(self.num_checkpoints(), 0)
# input location has changed, should force recompile and checkpointing
foo2_cg([torch.zeros_like(x1)])
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(self.get_root_children(), [2])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_checkpoint_shared_output_storage_deallocation(self):
def foo(args):
x = args[0]
args.clear()
x_tmp = x + 1
return x[0], x[1]
inp = torch.rand([2, 2], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
foo_cg(inp_list)
foo_cg([inp])
x1, x2 = foo_cg([inp])
inp = [x1]
def foo2(args):
x = args[0]
args.clear()
y = x * x
return y[0], y[1]
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
foo2_cg(inp)
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(
x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
)
self.assertEqual(all_live_block_count(), 1)
del x1
self.assertEqual(all_live_block_count(), 1)
del x2
self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_cleanup(self):
def test_closure():
@torch._dynamo.optimize()
def foo(x):
return x + 1 + 2, x * 10
foo(torch.rand([4], device="cuda"))
return foo(torch.rand([4], device="cuda"))
out1, out2 = test_closure()
torch._dynamo.reset()
# TODO - deallocate on tensor deallocation
# self.assertTrue(self.get_manager() is not None)
# del out1
# self.assertTrue(self.get_manager() is not None)
# del out2
self.assertTrue(self.get_manager() is None)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_forward_backward(self):
@torch._dynamo.optimize()
def foo(x):
y = x * 2
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
out = foo(inp)
out.sum().backward()
self.assertEqual(self.get_root_children(), [1])
            # the three saved tensors should die in the backward;
            # we kept the output alive
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
self.assertEqual(
self.curr_node().expected_dead_indices_after_graph,
[(0, 1), (0, 2), (0, 3)],
)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_separate_recordings(self):
def foo_unopt(x, y):
return (x + 1) @ y
foo = torch._dynamo.optimize()(foo_unopt)
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps = [
torch.ones([20, 20], device="cuda", requires_grad=False)
for _ in range(2)
]
out = foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps2 = [
torch.rand([40, 40], device="cuda", requires_grad=False)
for _ in range(2)
]
foo(*inps2)
foo(*inps2)
foo(*inps2)
# two separate roots
self.assertEqual(self.get_root_children(), [0, 0])
def test_alias_of_parameter(self):
class AliasMod(nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
def forward(self, x):
return self.param[0], self.param, self.param + x
@torch.compile(mode="reduce-overhead")
def foo(mod, inp):
return mod(inp)
inp = torch.rand([20, 20], device="cuda")
mod = AliasMod()
storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
mod.param.untyped_storage()
)
for _ in range(3):
outs = foo(mod, inp)
self.assertEqual(mod(inp), outs)
self.assertFalse(storage_ref.expired())
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@requires_multigpu()
def test_manager_per_device(self):
def test():
def foo(args):
x = args[0]
args.clear()
return (x + 3,)
inp = torch.rand([20, 20], device="cuda:1")
inp_list = [inp]
foo_cg = tree_cudagraphify_impl(
foo,
inp_list,
(),
device_index=1,
is_backward=False,
is_inference=True,
)
for _ in range(3):
self.assertEqual(foo_cg([inp]), foo([inp]))
self.assertTrue(self.get_manager(device_index=0) is None)
self.assertFalse(self.get_manager(device_index=1) is None)
test()
self.assertTrue(self.get_manager(device_index=1) is None)
def test_warnings_on_dealloc(self):
@torch.compile()
def foo(x):
return x * x * x
inp = torch.rand([4], device="cuda")
out = foo(inp)
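            # calling foo again while `out` is still live should emit exactly
            # one warning; the message quotes the compiled source ("x * x * x")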
warnings.resetwarnings()
with warnings.catch_warnings(record=True) as w:
foo(inp)
self.assertTrue(len(w) == 1)
self.assertTrue("x * x * x" in str(w[0]))
def test_single_stream_use(self):
@torch.compile()
def foo(x):
return (x * x * x).relu()
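            # all recordings in the tree should share one capture stream, so
            # the new pool segments report exactly one distinct stream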
inp = torch.rand([4], device="cuda", requires_grad=True)
streams = set()
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
for _ in range(4):
foo(inp).sum().backward()
streams = {
seg["stream"] for seg in get_all_cudagraph_segments()
} - streams_init
self.assertEqual(len(streams), 1)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_forward_generation(self):
def foo(x):
return x * x * x
def foo2(x):
return x * 12
foo_opt = torch._dynamo.optimize()(foo)
foo2_opt = torch._dynamo.optimize()(foo2)
ones = torch.ones([4, 4], device="cuda", requires_grad=True)
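            # chain two compiled forwards; until a backward runs, the manager
            # stays in the "pending backwards" state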
out = foo_opt(ones)
out2 = foo2_opt(out)
self.assertEqual(all_live_block_count(), 2)
self.assertTrue(self.get_manager().running_forwards_with_pending_backwards)
out2.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
del out
del out2
foo2_opt(foo_opt(ones)).sum().backward()
out = foo_opt(ones.detach())
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
run_tests(needs="filelock")