[Static Runtime] test case for StaticRuntime::runAsync() API (#80407)

Summary:
- Add a Python interface to call the StaticRuntime::runAsync API
- Create a custom executor that launches tasks on the inter-op thread pool
- Add test cases for different async graph scenarios: multiple forks, nested forks, and exception handling

Test Plan:
- local tests
    buck test mode/opt caffe2/test:static_runtime
    buck test mode/opt caffe2/benchmarks/static_runtime/fb:test_fb_operators
    buck test mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest
- OSS CI tests

Differential Revision: D37471859

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80407
Approved by: https://github.com/tenpercent
diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py
index a501d4b..b3087ee 100644
--- a/test/test_static_runtime.py
+++ b/test/test_static_runtime.py
@@ -23,6 +23,9 @@
     def benchmark(self, args, kwargs, warmup_runs, main_runs):
         self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
 
+    def runAsync(self, args, kwargs):
+        return self.static_module.runAsync(args, kwargs)
+
     def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
         return self.static_module.benchmark_individual_ops(
             args, kwargs, warmup_runs, main_runs
@@ -223,6 +226,20 @@
         torch.testing.assert_close(output_test, output_ref)
 
     """
+    Test Case: To test simple fork/wait operation with
+    StaticRuntime runAsync API returning future
+    """
+    def test_fork_wait_1_async(self):
+        inp1 = torch.ones(5, 5)
+        inp2 = torch.randn(5, 5)
+        torch_graph = torch.jit.script(fork_wait_graph1)
+        output_ref = torch_graph(inp1, inp2)
+        static_runtime_module = StaticModule(torch_graph)
+        output_test = static_runtime_module.runAsync((inp1, inp2), {})
+        output_test.wait()
+        torch.testing.assert_close(output_test.value(), output_ref)
+
+    """
     Test Case: To test fork/wait operation in a graph on
     a loop subgraph performing mix of operations
     """
@@ -236,6 +253,20 @@
         torch.testing.assert_close(output_test, output_ref)
 
     """
+    Test Case: To test fork/wait operation on a loop
+    subgraph with StaticRuntime runAsync API returning future
+    """
+    def test_fork_wait_2_async(self):
+        inp1 = torch.randn(5, 5)
+        inp2 = torch.randn(5, 5)
+        torch_graph = torch.jit.script(fork_wait_graph2)
+        output_ref = torch_graph(inp1, inp2)
+        static_runtime_module = StaticModule(torch_graph)
+        output_test = static_runtime_module.runAsync((inp1, inp2), {})
+        output_test.wait()
+        torch.testing.assert_close(output_test.value(), output_ref)
+
+    """
     Test Case: To test fork/wait operation in a graph on
     having multiple fork/wait operations
     """
@@ -247,6 +278,21 @@
         static_runtime_module = StaticModule(torch_graph)
         output_test = static_runtime_module(input, num_forks)
         torch.testing.assert_close(output_test, output_ref)
+
+    """
+    Test Case: To test fork/wait operation in a graph with
+    multiple fork/wait operations on runAsync API returning future
+    """
+    def test_fork_wait_3_async(self):
+        input = torch.ones(3, 3)
+        num_forks = 10
+        torch_graph = torch.jit.script(fork_wait_graph3)
+        output_ref = torch_graph(input, num_forks)
+        static_runtime_module = StaticModule(torch_graph)
+        output_test = static_runtime_module.runAsync((input, num_forks), {})
+        output_test.wait()
+        torch.testing.assert_close(output_test.value(), output_ref)
+
     """
     Test Case: To test fork/wait operation in a graph on
     multiple nested fork/wait operations
@@ -262,6 +308,22 @@
         torch.testing.assert_close(output_test, output_ref)
 
     """
+    Test Case: To test fork/wait operation in a graph with multiple
+    nested fork/wait operations on runAsync API returning future
+    """
+    def test_fork_wait_4_async(self):
+        input = torch.ones(3, 3)
+        num_forks = 10
+        num_child_forks = 10
+        torch_graph = torch.jit.script(fork_wait_graph4)
+        static_runtime_module = StaticModule(torch_graph)
+        output_ref = torch_graph(input, num_forks, num_child_forks)
+        output_test = static_runtime_module.runAsync(
+            (input, num_forks, num_child_forks), {})
+        output_test.wait()
+        torch.testing.assert_close(output_test.value(), output_ref)
+
+    """
     Test Case: To test exception handling in fork/wait
     operation. Add.Tensor op is called for tensors with
     non-matching dims on the forked subgraph and the
@@ -290,6 +352,36 @@
                     f"not contain expected substring: \"{expected_error_msg}\""
                 ) from error
 
+    """
+    Test Case: To test exception handling in fork/wait
+    operation with runAsync API. Add.Tensor op is called for
+    tensors with non-matching dims on the forked subgraph
+    and the exception raised by subgraph is set on future returned
+    by prim::fork to parent graph. Returned exception is
+    checked for substring expected_error_msg as declared below
+    """
+    def test_fork_wait_exception_async(self):
+        # incompatible tensors for add due to shape mismatch
+        input1 = torch.randn(4, 7)
+        input2 = torch.randn(4, 5)
+        torch_graph = torch.jit.script(fork_wait_graph_exception)
+        try:
+            static_runtime_module = StaticModule(torch_graph)
+            output_test = static_runtime_module.runAsync(
+                (input1, input2), {})
+        except Exception as error:
+            expected_error_msg = (
+                "The size of tensor a (7) must match the size "
+                "of tensor b (5) at non-singleton dimension 1"
+            )
+            # test fails if error does not contain expected substr
+            if str(error).find(expected_error_msg) == -1:
+                raise RuntimeError(
+                    "Tried execution of add.Tensors with incompatible shape. "
+                    "Exception raised by forked runtime execution does "
+                    f"not contain expected substring: \"{expected_error_msg}\""
+                ) from error
+
     def test_multihead_attention_layer(self):
         HID_DIM = 256
         QUERY_LEN = 8
diff --git a/torch/csrc/jit/runtime/static/init.cpp b/torch/csrc/jit/runtime/static/init.cpp
index 778d6e2..36b8ca8 100644
--- a/torch/csrc/jit/runtime/static/init.cpp
+++ b/torch/csrc/jit/runtime/static/init.cpp
@@ -89,6 +89,28 @@
                 kwargs.begin(), kwargs.end()};
             return self.runtime().benchmark_individual_ops(
                 {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
+          })
+      .def(
+          "runAsync",
+          [](StaticModule& self,
+             const py::tuple& args,
+             const py::dict& kwargs) {
+            std::vector<c10::IValue> arg_ivalues;
+            for (const auto& elem : args) {
+              arg_ivalues.push_back(
+                  torch::jit::toIValue(elem, c10::AnyType::get()));
+            }
+            std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
+            for (const auto& kv : kwargs) {
+              kwarg_ivalues[py::cast<std::string>(kv.first)] =
+                  torch::jit::toIValue(kv.second, c10::AnyType::get());
+            }
+            // custom executor for async op execution
+            auto task_launcher = [](const std::function<void()>& f) {
+              at::launch(f);
+            };
+            return toPyObject(self.runtime().runAsync(
+                arg_ivalues, kwarg_ivalues, task_launcher));
           });
   m.def(
        "_jit_to_static_module",