[Static Runtime] test case for StaticRuntime::runAsync() API (#80407)
Summary:
- Adds a Python interface for calling the StaticRuntime::runAsync API (see the sketch below)
- Creates a custom executor that launches async work on the inter-op thread pool
- Adds test cases for different async graph scenarios: multiple forks, nested forks, and exception handling
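A minimal usage sketch of the new Python binding, mirroring the added tests. StaticModule and fork_wait_graph1 are helpers defined in test/test_static_runtime.py; the import below is illustrative only and assumes that file is importable as a module:

    import torch
    # StaticModule is the test-side wrapper around the static runtime binding;
    # fork_wait_graph1 is one of the fork/wait test helpers. Both are assumed
    # importable here for illustration.
    from test_static_runtime import StaticModule, fork_wait_graph1

    scripted = torch.jit.script(fork_wait_graph1)
    sm = StaticModule(scripted)
    fut = sm.runAsync((torch.ones(5, 5), torch.randn(5, 5)), {})  # returns a future
    fut.wait()               # block until the forked subgraphs complete
    result = fut.value()     # graph output once the future is fulfilled
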
Test Plan:
- local tests:
  buck test mode/opt caffe2/test:static_runtime
  buck test mode/opt caffe2/benchmarks/static_runtime/fb:test_fb_operators
  buck test mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest
- OSS CI tests
Differential Revision: D37471859
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80407
Approved by: https://github.com/tenpercent
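Note: the diff below references fork_wait_graph* helpers (fork_wait_graph1 through fork_wait_graph4 and fork_wait_graph_exception) that are defined earlier in test/test_static_runtime.py and do not appear in these hunks. For context, a hypothetical sketch of the simplest such graph, built from torch.jit.fork and torch.jit.wait (not the actual helper from the test file):

    import torch
    from torch import Tensor

    def child_add(a: Tensor, b: Tensor) -> Tensor:
        # work executed asynchronously; with the runAsync binding in this diff it
        # is scheduled on the inter-op thread pool via at::launch
        return a + b

    def fork_wait_sketch(a: Tensor, b: Tensor) -> Tensor:
        fut = torch.jit.fork(child_add, a, b)  # schedule child_add, get a future
        return torch.jit.wait(fut)             # block on the future, return its value

    # scripted and driven through StaticModule.runAsync the same way as the tests below
    scripted = torch.jit.script(fork_wait_sketch)
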
diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py
index a501d4b..b3087ee 100644
--- a/test/test_static_runtime.py
+++ b/test/test_static_runtime.py
@@ -23,6 +23,9 @@
def benchmark(self, args, kwargs, warmup_runs, main_runs):
self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
+ def runAsync(self, args, kwargs):
+ return self.static_module.runAsync(args, kwargs)
+
def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
return self.static_module.benchmark_individual_ops(
args, kwargs, warmup_runs, main_runs
@@ -223,6 +226,20 @@
torch.testing.assert_close(output_test, output_ref)
"""
+ Test Case: To test simple fork/wait operation with
+ StaticRuntime runAsync API returning future
+ """
+ def test_fork_wait_1_async(self):
+ inp1 = torch.ones(5, 5)
+ inp2 = torch.randn(5, 5)
+ torch_graph = torch.jit.script(fork_wait_graph1)
+ output_ref = torch_graph(inp1, inp2)
+ static_runtime_module = StaticModule(torch_graph)
+ output_test = static_runtime_module.runAsync((inp1, inp2), {})
+ output_test.wait()
+ torch.testing.assert_close(output_test.value(), output_ref)
+
+ """
Test Case: To test fork/wait operation in a graph on
a loop subgraph performing mix of operations
"""
@@ -236,6 +253,20 @@
torch.testing.assert_close(output_test, output_ref)
"""
+ Test Case: To test fork/wait operation on a loop
+ subgraph with StaticRuntime runAsync API returning future
+ """
+ def test_fork_wait_2_async(self):
+ inp1 = torch.randn(5, 5)
+ inp2 = torch.randn(5, 5)
+ torch_graph = torch.jit.script(fork_wait_graph2)
+ output_ref = torch_graph(inp1, inp2)
+ static_runtime_module = StaticModule(torch_graph)
+ output_test = static_runtime_module.runAsync((inp1, inp2), {})
+ output_test.wait()
+ torch.testing.assert_close(output_test.value(), output_ref)
+
+ """
Test Case: To test fork/wait operation in a graph on
having multiple fork/wait operations
"""
@@ -247,6 +278,21 @@
static_runtime_module = StaticModule(torch_graph)
output_test = static_runtime_module(input, num_forks)
torch.testing.assert_close(output_test, output_ref)
+
+ """
+ Test Case: To test fork/wait operation in a graph with
+ multiple fork/wait operations on runAsync API returning future
+ """
+ def test_fork_wait_3_async(self):
+ input = torch.ones(3, 3)
+ num_forks = 10
+ torch_graph = torch.jit.script(fork_wait_graph3)
+ output_ref = torch_graph(input, num_forks)
+ static_runtime_module = StaticModule(torch_graph)
+ output_test = static_runtime_module.runAsync((input, num_forks), {})
+ output_test.wait()
+ torch.testing.assert_close(output_test.value(), output_ref)
+
"""
Test Case: To test fork/wait operation in a graph on
multiple nested fork/wait operations
@@ -262,6 +308,22 @@
torch.testing.assert_close(output_test, output_ref)
"""
+ Test Case: To test fork/wait operation in a graph with multiple
+ nested fork/wait operations on runAsync API returning future
+ """
+ def test_fork_wait_4_async(self):
+ input = torch.ones(3, 3)
+ num_forks = 10
+ num_child_forks = 10
+ torch_graph = torch.jit.script(fork_wait_graph4)
+ static_runtime_module = StaticModule(torch_graph)
+ output_ref = torch_graph(input, num_forks, num_child_forks)
+ output_test = static_runtime_module.runAsync(
+ (input, num_forks, num_child_forks), {})
+ output_test.wait()
+ torch.testing.assert_close(output_test.value(), output_ref)
+
+ """
Test Case: To test exception handling in fork/wait
operation. Add.Tensor op is called for tensors with
non-matching dims on the forked subgraph and the
@@ -290,6 +352,36 @@
f"not contain expected substring: \"{expected_error_msg}\""
) from error
+ """
+ Test Case: To test exception handling in fork/wait
+ operation with runAsync API. Add.Tensor op is called for
+ tensors with non-matching dims on the forked subgraph
+ and the exception raised by subgraph is set on future returned
+ by prim::fork to parent graph. Returned exception is
+ checked for substring expected_error_msg as declared below
+ """
+ def test_fork_wait_exception_async(self):
+ # incompatible tensors for add due to shape mismatch
+ input1 = torch.randn(4, 7)
+ input2 = torch.randn(4, 5)
+ torch_graph = torch.jit.script(fork_wait_graph_exception)
+ try:
+ static_runtime_module = StaticModule(torch_graph)
+ output_test = static_runtime_module.runAsync(
+ (input1, input2), {})
+ except Exception as error:
+ expected_error_msg = (
+ "The size of tensor a (7) must match the size "
+ "of tensor b (5) at non-singleton dimension 1"
+ )
+ # test fails if error does not contain expected substr
+ if str(error).find(expected_error_msg) == -1:
+ raise RuntimeError(
+ "Tried execution of add.Tensors with incompatible shape. "
+ "Exception raised by forked runtime execution does "
+ f"not contain expected substring: \"{expected_error_msg}\""
+ ) from error
+
def test_multihead_attention_layer(self):
HID_DIM = 256
QUERY_LEN = 8
diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp
index 778d6e2..36b8ca8 100644
--- a/torch/csrc/jit/runtime/static/init.cpp
+++ b/torch/csrc/jit/runtime/static/init.cpp
@@ -89,6 +89,28 @@
kwargs.begin(), kwargs.end()};
return self.runtime().benchmark_individual_ops(
{arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
+ })
+ .def(
+ "runAsync",
+ [](StaticModule& self,
+ const py::tuple& args,
+ const py::dict& kwargs) {
+ std::vector<c10::IValue> arg_ivalues;
+ for (const auto& elem : args) {
+ arg_ivalues.push_back(
+ torch::jit::toIValue(elem, c10::AnyType::get()));
+ }
+ std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
+ for (const auto& kv : kwargs) {
+ kwarg_ivalues[py::cast<std::string>(kv.first)] =
+ torch::jit::toIValue(kv.second, c10::AnyType::get());
+ }
+ // custom executor for async op execution
+ auto task_launcher = [](const std::function<void()>& f) {
+ at::launch(f);
+ };
+ return toPyObject(self.runtime().runAsync(
+ arg_ivalues, kwarg_ivalues, task_launcher));
});
m.def(
"_jit_to_static_module",