[AOTI][refactor] Update some test cases (#123093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123093
Approved by: https://github.com/Skylion007, https://github.com/chenyang78
diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py
index 53b8748..0789a02 100644
--- a/test/inductor/test_cpu_cpp_wrapper.py
+++ b/test/inductor/test_cpu_cpp_wrapper.py
@@ -103,7 +103,6 @@
f"{test_name}_dynamic_shapes"
] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False)
skip_list = [
- "test_linear1_cpu", # segfault from double free
"test_multihead_attention_cpu",
]
for test_name in skip_list:
diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py
index 7b730d5..6e80c32 100644
--- a/test/inductor/test_cuda_cpp_wrapper.py
+++ b/test/inductor/test_cuda_cpp_wrapper.py
@@ -109,7 +109,6 @@
] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=False)
skip_list = [
"test_multi_device_cuda",
- "test_linear1_cuda", # segfault from double free
]
for test_name in skip_list:
test_failures_cuda_wrapper[test_name] = test_torchinductor.TestFailure(
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 392fe29..ae1f7f7 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7194,8 +7194,6 @@
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dtype_sympy_expr(self):
- torch._inductor.metrics.disable_cpp_wrapper = 0
-
@torch._dynamo.optimize_assert("inductor")
def fn(a):
y = a[..., :-1, :].contiguous()
@@ -7204,11 +7202,6 @@
result = fn(torch.randn([1, 2, 16, 4]).requires_grad_())
result.sum().backward()
- expected_disable_cpp_wrapper = 0
- self.assertEqual(
- torch._inductor.metrics.disable_cpp_wrapper, expected_disable_cpp_wrapper
- )
-
def test_dropout2(self):
n = 100000
weight = torch.ones(
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index c9a2fcc..ae2c77c 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -39,9 +39,6 @@
# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0
-# counters for tracking cpp_wrapper disabled
-disable_cpp_wrapper = 0
-
# reset all counters
def reset():
@@ -50,7 +47,6 @@
global num_bytes_accessed, nodes_num_elem
global ir_nodes_pre_fusion
global cpp_to_dtype_count
- global disable_cpp_wrapper
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
@@ -59,7 +55,6 @@
node_runtimes.clear()
ir_nodes_pre_fusion = 0
cpp_to_dtype_count = 0
- disable_cpp_wrapper = 0
@dataclass