Re-enable 2 tests for the simple executor
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29661

Re-enables test_pow_scalar_backward_cuda and test_mm_batching under the simple executor, updates the skip reason for test_requires_grad_loop, and builds the backward graph in test_mm_batching only under the legacy executor.
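For context, a minimal standalone sketch (not the actual test_jit.py setup; ProfilingMode here is a stand-in enum and GRAPH_EXECUTOR is hard-coded rather than chosen at test start-up) of how these tests are gated: a module-level GRAPH_EXECUTOR flag drives unittest.skipIf decorators, so removing a decorator is what re-enables a test in SIMPLE mode.

    import unittest
    from enum import Enum

    class ProfilingMode(Enum):  # stand-in for PyTorch's ProfilingMode
        SIMPLE = 0
        LEGACY = 1

    GRAPH_EXECUTOR = ProfilingMode.SIMPLE  # hard-coded for illustration

    class TestJitSketch(unittest.TestCase):
        # Still skipped under the simple executor, with the updated reason:
        @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE,
                         "Simple Executor doesn't use requires_grad information")
        def test_requires_grad_loop(self):
            pass

        # Decorator removed by this PR, so the test now runs in SIMPLE mode:
        def test_mm_batching(self):
            pass

    if __name__ == "__main__":
        unittest.main()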
Differential Revision: D18456143
Pulled By: Krovatkin
fbshipit-source-id: 9e4ae3ae681e3c9a81ada1e8b39da1e1342ce394
diff --git a/test/test_jit.py b/test/test_jit.py
index 5e9688e..2b5bdb4 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4841,7 +4841,6 @@
self.checkScript(func5, (x, y))
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
- @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple executor doesn't support backward")
def test_pow_scalar_backward_cuda(self):
# see that scalar exponent works with cuda base (#19253)
with enable_profiling_mode():
@@ -5568,7 +5567,7 @@
m()
- @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "NYI: fuser support for Sandcastle")
+ @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information")
def test_requires_grad_loop(self):
@torch.jit.script
def test(x, y, z):
@@ -11112,8 +11111,6 @@
self.run_pass('erase_number_types', graph)
FileCheck().check_not("int = prim::Constant").check_not("aten::add_").run(str(graph))
-
- @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple executor doesn't support gradients")
def test_mm_batching(self):
with enable_profiling_mode():
@@ -11132,8 +11129,8 @@
slstm(*inputs, profile_and_replay=True).sum().backward()
fw_graph = slstm.graph_for(*inputs)
- bw_graph = backward_graph(slstm, diff_graph_idx=0)
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+ bw_graph = backward_graph(slstm, diff_graph_idx=0)
self.assertTrue('prim::MMBatchSide' in str(fw_graph))
self.assertTrue('prim::MMTreeReduce' in str(bw_graph))