| # Owner(s): ["module: nvfuser"] | 
 |  | 
 | import unittest | 
 | import warnings | 
 | from functools import partial | 
 |  | 
 | import torch | 
 | import torch._dynamo as torchdynamo | 
 | from torch.testing import make_tensor | 
 | from torch.testing._internal.common_utils import ( | 
 |     IS_WINDOWS, | 
 |     run_tests, | 
 |     skipIfTorchDynamo, | 
 |     TEST_WITH_ROCM, | 
 |     TestCase, | 
 | ) | 
 | from torch.testing._internal.jit_utils import RUN_CUDA | 
 |  | 
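# nvFuser requires CUDA and is not supported on ROCm builds.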
 | RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM | 
 |  | 
 |  | 
 | def is_pre_volta(): | 
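    """Return True if the current CUDA device predates Volta (compute capability < 7)."""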
 |     if not RUN_NVFUSER: | 
 |         return False | 
 |     prop = torch.cuda.get_device_properties(torch.cuda.current_device()) | 
 |     return prop.major < 7 | 
 |  | 
 |  | 
 | def is_networkx_available(): | 
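    """Return True if networkx is importable; the min-cut partitioner in test_min_cut depends on it."""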
 |     try: | 
 |         import networkx  # noqa: F401 | 
 |  | 
 |         return True | 
 |     except ImportError: | 
 |         return False | 
 |  | 
 |  | 
 | @skipIfTorchDynamo("Not a suitable test for TorchDynamo") | 
 | @unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows") | 
 | @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") | 
 | @unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.") | 
 | class TestNvFuserDynamo(TestCase): | 
 |     def test_basic(self): | 
 |         input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32) | 
 |         input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32) | 
 |  | 
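        # Compile with the TorchDynamo "nvprims_nvfuser" backend, which traces
        # the function to nvPrims and executes it through nvFuser.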
 |         @torchdynamo.optimize("nvprims_nvfuser") | 
 |         def func(a, b): | 
 |             return a.sin() + b.cos() | 
 |  | 
 |         # No warnings and no errors | 
 |         with warnings.catch_warnings(record=True) as w: | 
 |             nvfuser_result = func(input1, input2) | 
 |             self.assertEqual(len(w), 0) | 
 |         eager_result = func.__wrapped__(input1, input2) | 
 |         self.assertEqual(eager_result, nvfuser_result) | 
 |  | 
 |     @unittest.skipIf(not is_networkx_available(), "networkx not available") | 
 |     def test_min_cut(self): | 
 |         from functorch.compile import default_partition | 
 |         from torch._dynamo.optimizations.training import nvprims_fw_bw_partition_fn | 
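        # default_partition is expected to keep every forward intermediate for
        # backward; nvprims_fw_bw_partition_fn should use a min-cut (hence the
        # networkx requirement above) to save fewer tensors.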
 |  | 
 |         def get_fw_bw_graph(f, inps, partitioner): | 
 |             from functorch.compile import aot_function | 
 |  | 
 |             # Helper functions are taken from functorch/test_aotdispatch.py | 
 |             def extract_graph(fx_g, _, graph_cell): | 
 |                 graph_cell[0] = fx_g | 
 |                 return fx_g | 
 |  | 
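            # Single-element lists act as cells so the compiler callbacks can
            # hand the captured fx graphs back out of aot_function.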
 |             fw_graph_cell = [None] | 
 |             bw_graph_cell = [None] | 
 |             aot_function( | 
 |                 f, | 
 |                 fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), | 
 |                 bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), | 
 |                 partition_fn=partitioner, | 
 |             )(*inps).sum().backward() | 
 |             return (fw_graph_cell[0], bw_graph_cell[0]) | 
 |  | 
 |         def get_ins_outs(fx_g): | 
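            """Return the placeholder nodes and the flattened output tuple of an fx graph."""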
 |             ins = [] | 
 |             outs = [] | 
 |             for n in fx_g.graph.nodes: | 
 |                 if n.op == "placeholder": | 
 |                     ins.append(n) | 
 |                 elif n.op == "output": | 
 |                     outs = tuple(n.args[0]) | 
 |             return ins, outs | 
 |  | 
 |         def get_num_ins_outs(fx_g): | 
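            """Return (input count, output count) for an fx graph."""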
 |             return tuple(len(i) for i in get_ins_outs(fx_g)) | 
 |  | 
 |         def func(x): | 
 |             return x * x * x | 
 |  | 
 |         input1 = make_tensor( | 
 |             (3,), device="cpu", dtype=torch.float32, requires_grad=True | 
 |         ) | 
 |         fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition) | 
 |         self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) | 
 |         self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) | 
 |  | 
 |         input1 = make_tensor( | 
 |             (3,), device="cpu", dtype=torch.float32, requires_grad=True | 
 |         ) | 
 |         fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn) | 
 |         self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) | 
 |         self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) | 
 |  | 
 |     def test_batch_norm_implicit_dtype_promotion(self): | 
 |         input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32) | 
 |         input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32) | 
        w = make_tensor((3,), device="cuda", dtype=torch.float32)
        b = make_tensor((3,), device="cuda", dtype=torch.float32)
 |  | 
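        # Under autocast the matmul output is fp16 while w and b remain fp32,
        # so the traced batch_norm exercises implicit dtype promotion.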
 |         @torchdynamo.optimize("nvprims_nvfuser") | 
 |         def func(mat1, mat2, w, b): | 
 |             o = torch.matmul(mat1, mat2) | 
 |             return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True) | 
 |  | 
 |         # No warnings and no errors | 
 |         with torch.cuda.amp.autocast(): | 
 |             with warnings.catch_warnings(record=True) as warning: | 
 |                 nvfuser_result = func(input1, input2, w, b) | 
 |                 self.assertEqual(len(warning), 0) | 
 |             eager_result = func.__wrapped__(input1, input2, w, b) | 
 |             self.assertEqual(eager_result, nvfuser_result) | 
 |  | 
 |     def test_dtype_correctness(self): | 
 |         input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16) | 
 |  | 
 |         @torchdynamo.optimize("nvprims_nvfuser") | 
 |         def func(a): | 
 |             tmp = a + 1.0 | 
            # nvFuser promotes the math to fp32 internally; FusionDefinition
            # should cast the output back to the input dtype (fp16 here)
 |             return torch.where(tmp > 0, tmp, 0.0) | 
 |  | 
 |         # No warnings and no errors | 
 |         with warnings.catch_warnings(record=True) as w: | 
 |             nvfuser_result = func(input1) | 
 |             self.assertEqual(len(w), 0) | 
 |         eager_result = func.__wrapped__(input1) | 
 |         self.assertEqual(eager_result, nvfuser_result) | 
 |  | 
 |  | 
 | if __name__ == "__main__": | 
 |     run_tests() |