[Static Runtime] Check unsupported op when enabling static runtime (#61613)

Summary:
Adds a canEnableStaticRuntime() check that runs before a graph is prepared for Static Runtime. It replaces CheckGraphEligibility() and rejects graphs that contain nested sub-blocks, as well as graphs with ops that have neither a JIT operator nor a Static Runtime native implementation. To make that check possible, the native op implementations are moved out of the if/else chain in getNativeOperation() (ops.cpp) into the new native_ops.cpp and registered through SRNativeOperatorRegistry via the REGISTER_NATIVE_OPERATOR_FUNCTOR macro, exposed through nativeOpIsRegistered(). A new CanEnableStaticRuntime test covers a supported script (reshape_inplace_script) and an unsupported one (if_script, which contains an if sub-block).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61613
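
Usage note (illustration only, not part of this diff): the check is intended to gate whether a module's forward graph is handed to Static Runtime at all. A minimal sketch under the assumption that the caller has already frozen the graph; the helper name shouldUseStaticRuntime and the fallback policy are hypothetical:

    #include <torch/csrc/jit/api/module.h>
    #include <torch/csrc/jit/runtime/static/impl.h>

    // Hypothetical gate: returns true only if every node in the forward graph
    // is either a registered JIT operator or one of the native ops registered
    // in native_ops.cpp, and the graph has no nested sub-blocks.
    bool shouldUseStaticRuntime(const torch::jit::Module& module) {
      auto graph = module.get_method("forward").graph();
      // Assumes the graph was frozen beforehand; a leftover prim::CallMethod
      // node would make canEnableStaticRuntime() return false.
      return torch::jit::canEnableStaticRuntime(graph);
    }

When the check returns false, the caller can fall back to the regular JIT interpreter instead of hitting the TORCH_CHECK in PrepareGraphForStaticModule.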

Reviewed By: ajyu, movefast1990

Differential Revision: D29663466

fbshipit-source-id: d819903b7227f534c0a4fffa5eeea2b5c0c04750
diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index f24060a..14bc0e5 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -493,3 +493,17 @@
   def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
    return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
 )JIT";
+
+const auto if_script = R"JIT(
+  def forward(self, a: Tensor, b: Tensor, x: bool):
+    c = (a + b).relu().half().float()
+    d = b * c
+    if x:
+      e = a.flatten().half() * b.flatten().half()
+    else:
+      e = a.flatten().half() + b.flatten().half()
+    f = e.float().relu()
+    g = {"d": d, "b": b}
+    h = {"e": e, "f": f}
+    return [g, h]
+)JIT";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index b4e27e9..4f50060 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -178,6 +178,16 @@
   return nullptr;
 }
 
+bool testCanEnableStaticRuntime(const std::string& jit_script) {
+  script::Module module("module");
+  module.define(jit_script);
+
+  Method method = module.get_method("forward");
+  auto graph = module.get_method("forward").graph();
+
+  // note that we do not freeze the graph here
+  return torch::jit::canEnableStaticRuntime(graph);
+}
 } // namespace
 
 TEST(StaticRuntime, InPlace) {
@@ -186,6 +196,11 @@
   EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script));
 }
 
+TEST(StaticRuntime, CanEnableStaticRuntime) {
+  EXPECT_TRUE(testCanEnableStaticRuntime(reshape_inplace_script));
+  EXPECT_FALSE(testCanEnableStaticRuntime(if_script));
+}
+
 TEST(StaticRuntime, NestedOutput) {
   auto run_test = [](std::vector<int64_t> shapes) {
     auto a = at::randn(shapes);
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index cd7d017..0dcfaf3 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -313,6 +313,7 @@
 core_sources_full = core_sources_full_mobile + [
     "torch/csrc/jit/runtime/static/fusion.cpp",
     "torch/csrc/jit/runtime/static/impl.cpp",
+    "torch/csrc/jit/runtime/static/native_ops.cpp",
     "torch/csrc/jit/runtime/static/ops.cpp",
     "torch/csrc/jit/runtime/static/passes.cpp",
     "torch/csrc/jit/tensorexpr/external_functions.cpp",
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3b1468c..2e329b2 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -21,6 +21,36 @@
 namespace torch {
 namespace jit {
 
+// The graph must be frozen first; otherwise canEnableStaticRuntime returns
+// false if any prim::CallMethod op is left in the graph.
+bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
+  // check for sub-blocks
+  bool can_support = true;
+  bool has_blocks = false;
+  for (auto* node : graph->block()->nodes()) {
+    if (node->blocks().size() > 0) {
+      has_blocks = true;
+      VLOG(1) << "Found nested sub-blocks in graph at node: "
+              << PrintNode(node);
+    }
+    if (node->kind() == prim::Constant) {
+      continue;
+    }
+    // check if can get op from Node
+    const Operator* op = node->maybeOperator();
+    if (!op && !nativeOpIsRegistered(node->kind())) {
+      can_support = false;
+      LOG(WARNING) << "Found unsupported op: " << node->kind().toQualString();
+    }
+  }
+  if (has_blocks) {
+    LOG(WARNING)
+        << "Found nested sub-block in graph. Static Runtime doesn't support nested sub-blocks.";
+    can_support = false;
+  }
+  return can_support;
+}
+
 namespace {
 
 void OptimizeGraph(
@@ -46,20 +76,6 @@
   ConstantPropagation(graph);
 }
 
-bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
-  // check for sub-blocks
-  bool can_support = true;
-  for (auto* node : graph->block()->nodes()) {
-    for (Block* sub_block : node->blocks()) {
-      VLOG(1) << "Found nested sub-blocks in graph at node: "
-              << PrintNode(node);
-      can_support = false;
-    }
-  }
-
-  return can_support;
-}
-
 // remove unused input 0 from graph
 bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
   if (graph->inputs().at(0)->type()->is_module()) {
@@ -464,8 +480,7 @@
 void PrepareGraphForStaticModule(
     std::shared_ptr<torch::jit::Graph> graph,
     const StaticModuleOptions& opts) {
-  // TODO: call CheckGraphEligibility before trying to enable static runtime
-  TORCH_CHECK(CheckGraphEligibility(graph));
+  TORCH_CHECK(canEnableStaticRuntime(graph));
   OptimizeGraph(graph, opts);
 }
 
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index e28dcc3..d4bc5cb 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -12,6 +12,8 @@
 namespace torch {
 namespace jit {
 
+bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph);
+
 struct TORCH_API StaticModuleOptions {
   // to batch allocate (deallocate) tensor storage for all non-escaping
   // temporary tensors
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
new file mode 100644
index 0000000..e38d269
--- /dev/null
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -0,0 +1,340 @@
+#include <torch/csrc/jit/runtime/static/ops.h>
+
+#include <ATen/CPUFunctions.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ScalarOps.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/native/IndexingUtils.h>
+#include <ATen/native/Resize.h>
+#include <ATen/native/TensorAdvancedIndexing.h>
+#include <c10/util/irange.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/vararg_functions.h>
+
+namespace torch {
+namespace jit {
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
+
+bool nativeOpIsRegistered(const c10::Symbol& op_name) {
+  const std::string name(op_name.toQualString());
+  return SRNativeOperatorRegistry()->Has(name);
+}
+
+std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
+  auto op_name = n->kind().toQualString();
+  if (SRNativeOperatorRegistry()->Has(op_name)) {
+    return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
+  }
+  return nullptr;
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::TupleConstruct,
+    prim_TupleConstruct,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        // prepare inputs
+        std::vector<IValue> stack;
+        const size_t size = p_node->inputs().size();
+        stack.reserve(size);
+        for (const auto i : c10::irange(size)) {
+          stack.emplace_back(p_node->Input(i));
+        }
+        // run op
+        auto* node = p_node->node();
+        const auto& type = node->output()->type()->expect<TupleType>();
+        if (type->name().has_value()) {
+          namedTupleConstruct(stack, type, node->inputs().size());
+        } else {
+          tupleConstruct(stack, node->inputs().size());
+        }
+        // put output back
+        p_node->Output(0) = std::move(stack[0]);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::DictConstruct,
+    prim_DictConstruct,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        // prepare inputs
+        std::vector<IValue> stack;
+        const size_t size = p_node->inputs().size();
+        stack.reserve(size);
+        for (const auto i : c10::irange(size)) {
+          stack.emplace_back(p_node->Input(i));
+        }
+        // run op
+        auto* node = p_node->node();
+        dictConstruct(
+            stack,
+            node->output()->type()->expectRef<DictType>(),
+            node->inputs().size());
+        // put output back
+        p_node->Output(0) = std::move(stack[0]);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::__getitem__,
+    aten_getitem,
+    [](Node* n) -> SROperator {
+      if (n->inputs().size() != 2) {
+        return nullptr;
+      }
+      // TODO: make __getitem__ work for other container types
+      if (n->input(0)->type()->castRaw<DictType>() == nullptr) {
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        auto dict = p_node->Input(0).toGenericDict();
+        auto key = p_node->Input(1);
+        auto value = dict.find(key);
+        TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
+        p_node->Output(0) = value->value();
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::ListConstruct,
+    prim_ListConstruct,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        // prepare inputs
+        std::vector<IValue> stack;
+        const size_t size = p_node->inputs().size();
+        stack.reserve(size);
+        for (const auto i : c10::irange(size)) {
+          stack.emplace_back(p_node->Input(i));
+        }
+        // run op
+        listConstruct(
+            stack,
+            p_node->node()->output()->type()->expectRef<ListType>(),
+            p_node->inputs().size());
+        // put output back
+        p_node->Output(0) = std::move(stack[0]);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::ListUnpack,
+    prim_ListUnpack,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        // prepare inputs
+        std::vector<IValue> stack;
+        const size_t size = p_node->inputs().size();
+        stack.reserve(size);
+        for (const auto i : c10::irange(size)) {
+          stack.emplace_back(p_node->Input(i));
+        }
+        // run op
+        size_t num_outputs = p_node->outputs().size();
+        listUnpack(stack, num_outputs);
+        // put output back
+        DCHECK_EQ(stack.size(), num_outputs);
+        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
+        for (auto i = 0; i < num_outputs; i++) {
+          p_node->Output(i) = std::move(stack[i]);
+        }
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::GetAttr,
+    prim_GetAttr,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        auto module = p_node->Input(0).toObject();
+        Node* node = p_node->node();
+        const auto type = node->input()->type()->expect<ClassType>();
+        const auto& field = node->s(attr::name);
+        const auto slot = type->getAttributeSlot(field);
+        p_node->Output(0) = module->getSlot(slot);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    prim::SetAttr,
+    prim_SetAttr,
+    [](Node* n) -> SROperator {
+      return [](ProcessedNode* p_node) {
+        auto module = p_node->Input(0).toObject();
+        Node* node = p_node->node();
+        const auto type = node->inputs()[0]->type()->expect<ClassType>();
+        const auto& field = node->s(attr::name);
+        const auto slot = type->getAttributeSlot(field);
+        module->setSlot(slot, p_node->Input(1));
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::transpose,
+    aten_transpose,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        const auto in1_i = p_node->Input(1).toInt();
+        const auto in2_i = p_node->Input(2).toInt();
+        p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    const auto in1_i = p_node->Input(1).toInt();
+    const auto in2_i = p_node->Input(2).toInt();
+    p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
+  };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::permute,
+    aten_permute,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        const auto in1_iv = p_node->Input(1).toIntVector();
+        p_node->Output(0) = at::native::permute(in0_t, in1_iv);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::reshape,
+    aten_reshape,
+    [](Node* n) -> SROperator {
+      if (!n->matches(torch::schema(
+              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
+        LogAndDumpSchema(n);
+        return nullptr;
+      }
+      return [](ProcessedNode* p_node) {
+        const auto& in0_t = p_node->Input(0).toTensor();
+        const auto in1_iv = p_node->Input(1).toIntVector();
+        p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
+      };
+    });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    const auto in1_i = p_node->Input(1).toInt();
+    const auto in2_i = p_node->Input(2).toInt();
+    const auto in3_i = p_node->Input(3).toInt();
+    const auto in4_i = p_node->Input(4).toInt();
+    p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
+  };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
+      !n->matches(torch::schema(
+          "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& self = p_node->Input(0).toTensor(); // self
+    const auto dim = p_node->Input(1).toInt(); // dim
+    int64_t start = 0;
+    if (p_node->Input(2).isScalar()) {
+      start = p_node->Input(2).toInt();
+    } else {
+      auto& t = p_node->Input(2).toTensor();
+      start = t.item<int64_t>();
+    }
+    const auto length = p_node->Input(3).toInt(); // length
+    TORCH_CHECK(
+        self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+    auto cur_size = self.sizes()[dim];
+    if (start != cur_size && start < 0) { // start being the end is valid, but
+                                          // not a valid dim specification.
+      start = at::maybe_wrap_dim(start, cur_size);
+    }
+    TORCH_CHECK(
+        length >= 0 && start <= cur_size - length,
+        "start (",
+        start,
+        ") + length (",
+        length,
+        ") exceeds dimension size (",
+        cur_size,
+        ").");
+    p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
+  };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
+  if (!n->matches(torch::schema(
+          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
+      !n->matches(torch::schema(
+          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+    LogAndDumpSchema(n);
+    return nullptr;
+  }
+  return [](ProcessedNode* p_node) {
+    const auto& in0_t = p_node->Input(0).toTensor();
+    const auto in2_i = p_node->Input(2).toBool();
+    const auto in3_i = p_node->Input(3).toBool();
+    const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+    if (p_node->Input(1).isTensor()) {
+      // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
+      // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+      const auto in1_t = p_node->Input(1).toTensor();
+      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
+    } else {
+      // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
+      // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+      const auto in1_i = p_node->Input(1).toScalarType();
+      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
+    }
+    // in case that Output(0) is an alias of in0_t, copy the tensor.
+    if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
+        in0_t.unsafeGetTensorImpl()) {
+      p_node->Output(0) = in0_t.clone();
+    }
+  };
+});
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 8cc6cae..236ac43 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1062,246 +1062,6 @@
   };
 });
 
-std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
-  if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) {
-    if (!n->matches(torch::schema(
-            "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_i = p_node->Input(1).toInt();
-      const auto in2_i = p_node->Input(2).toInt();
-      p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) {
-    if (!n->matches(torch::schema(
-            "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_i = p_node->Input(1).toInt();
-      const auto in2_i = p_node->Input(2).toInt();
-      p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
-    };
-  } else if (n->kind() == prim::TupleConstruct) {
-    return [](ProcessedNode* p_node) {
-      // prepare inputs
-      std::vector<IValue> stack;
-      const size_t size = p_node->inputs().size();
-      stack.reserve(size);
-      for (const auto i : c10::irange(size)) {
-        stack.emplace_back(p_node->Input(i));
-      }
-      // run op
-      auto* node = p_node->node();
-      const auto& type = node->output()->type()->expect<TupleType>();
-      if (type->name().has_value()) {
-        namedTupleConstruct(stack, type, node->inputs().size());
-      } else {
-        tupleConstruct(stack, node->inputs().size());
-      }
-      // put output back
-      p_node->Output(0) = std::move(stack[0]);
-    };
-  } else if (n->kind() == prim::DictConstruct) {
-    return [](ProcessedNode* p_node) {
-      // prepare inputs
-      std::vector<IValue> stack;
-      const size_t size = p_node->inputs().size();
-      stack.reserve(size);
-      for (const auto i : c10::irange(size)) {
-        stack.emplace_back(p_node->Input(i));
-      }
-      // run op
-      auto* node = p_node->node();
-      dictConstruct(
-          stack,
-          node->output()->type()->expectRef<DictType>(),
-          node->inputs().size());
-      // put output back
-      p_node->Output(0) = std::move(stack[0]);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::__getitem__")) {
-    if (n->inputs().size() != 2) {
-      return nullptr;
-    }
-    // TODO: make __getitem__ work for other container types
-    if (n->input(0)->type()->castRaw<DictType>() == nullptr) {
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      auto dict = p_node->Input(0).toGenericDict();
-      auto key = p_node->Input(1);
-      auto value = dict.find(key);
-      TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
-      p_node->Output(0) = value->value();
-    };
-  } else if (n->kind() == prim::ListConstruct) {
-    return [](ProcessedNode* p_node) {
-      // prepare inputs
-      std::vector<IValue> stack;
-      const size_t size = p_node->inputs().size();
-      stack.reserve(size);
-      for (const auto i : c10::irange(size)) {
-        stack.emplace_back(p_node->Input(i));
-      }
-      // run op
-      listConstruct(
-          stack,
-          p_node->node()->output()->type()->expectRef<ListType>(),
-          p_node->inputs().size());
-      // put output back
-      p_node->Output(0) = std::move(stack[0]);
-    };
-  } else if (n->kind() == prim::ListUnpack) {
-    return [](ProcessedNode* p_node) {
-      // prepare inputs
-      std::vector<IValue> stack;
-      const size_t size = p_node->inputs().size();
-      stack.reserve(size);
-      for (const auto i : c10::irange(size)) {
-        stack.emplace_back(p_node->Input(i));
-      }
-      // run op
-      size_t num_outputs = p_node->outputs().size();
-      listUnpack(stack, num_outputs);
-      // put output back
-      DCHECK_EQ(stack.size(), num_outputs);
-      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-      for (auto i = 0; i < num_outputs; i++) {
-        p_node->Output(i) = std::move(stack[i]);
-      }
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
-    if (!n->matches(torch::schema(
-            "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_iv = p_node->Input(1).toIntVector();
-      p_node->Output(0) = at::native::permute(in0_t, in1_iv);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
-    if (!n->matches(torch::schema(
-            "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_iv = p_node->Input(1).toIntVector();
-      p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
-    if (!n->matches(torch::schema(
-            "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in1_i = p_node->Input(1).toInt();
-      const auto in2_i = p_node->Input(2).toInt();
-      const auto in3_i = p_node->Input(3).toInt();
-      const auto in4_i = p_node->Input(4).toInt();
-      p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) {
-    if (!n->matches(torch::schema(
-            "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
-        !n->matches(torch::schema(
-            "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& self = p_node->Input(0).toTensor(); // self
-      const auto dim = p_node->Input(1).toInt(); // dim
-      int64_t start = 0;
-      if (p_node->Input(2).isScalar()) {
-        start = p_node->Input(2).toInt();
-      } else {
-        auto& t = p_node->Input(2).toTensor();
-        start = t.item<int64_t>();
-      }
-      const auto length = p_node->Input(3).toInt(); // length
-      TORCH_CHECK(
-          self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
-      auto cur_size = self.sizes()[dim];
-      if (start != cur_size && start < 0) { // start being the end is valid, but
-                                            // not a valid dim specification.
-        start = at::maybe_wrap_dim(start, cur_size);
-      }
-      TORCH_CHECK(
-          length >= 0 && start <= cur_size - length,
-          "start (",
-          start,
-          ") + length (",
-          length,
-          ") exceeds dimension size (",
-          cur_size,
-          ").");
-      p_node->Output(0) =
-          at::native::slice(self, dim, start, start + length, 1);
-    };
-  } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
-    if (!n->matches(torch::schema(
-            "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
-        !n->matches(torch::schema(
-            "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
-      LogAndDumpSchema(n);
-      return nullptr;
-    }
-    return [](ProcessedNode* p_node) {
-      const auto& in0_t = p_node->Input(0).toTensor();
-      const auto in2_i = p_node->Input(2).toBool();
-      const auto in3_i = p_node->Input(3).toBool();
-      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
-      if (p_node->Input(1).isTensor()) {
-        // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
-        // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
-        const auto in1_t = p_node->Input(1).toTensor();
-        p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
-      } else {
-        // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
-        // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
-        const auto in1_i = p_node->Input(1).toScalarType();
-        p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
-      }
-      // in case that Output(0) is an alias of in0_t, copy the tensor.
-      if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
-          in0_t.unsafeGetTensorImpl()) {
-        p_node->Output(0) = in0_t.clone();
-      }
-    };
-  } else if (n->kind() == prim::GetAttr) {
-    return [](ProcessedNode* p_node) {
-      auto module = p_node->Input(0).toObject();
-      Node* node = p_node->node();
-      const auto type = node->input()->type()->expect<ClassType>();
-      const auto& field = node->s(attr::name);
-      const auto slot = type->getAttributeSlot(field);
-      p_node->Output(0) = module->getSlot(slot);
-    };
-  } else if (n->kind() == prim::SetAttr) {
-    return [](ProcessedNode* p_node) {
-      auto module = p_node->Input(0).toObject();
-      Node* node = p_node->node();
-      const auto type = node->inputs()[0]->type()->expect<ClassType>();
-      const auto& field = node->s(attr::name);
-      const auto slot = type->getAttributeSlot(field);
-      module->setSlot(slot, p_node->Input(1));
-    };
-  }
-  return nullptr;
-}
-
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
   // TODO: Support only 9 args once the old signature has been removed.
diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h
index 021cd21..5d6bbf8 100644
--- a/torch/csrc/jit/runtime/static/ops.h
+++ b/torch/csrc/jit/runtime/static/ops.h
@@ -29,8 +29,6 @@
 
 C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
 
-// TODO: reuse_inp reuse_out can be deprecated with further analysis
-// try to avoid this API.
 #define REGISTER_OPERATOR_FUNCTOR(name, id, ...)             \
   struct SROperatorFunctor_##id : public SROperatorFunctor { \
     const SROpFunctor fn = __VA_ARGS__;                      \
@@ -40,6 +38,17 @@
   };                                                         \
   C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
 
+C10_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
+#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...)            \
+  struct SRNativeOperatorFunctor_##id : public SROperatorFunctor { \
+    const SROpFunctor fn = __VA_ARGS__;                            \
+    SROperator Generate(Node* n) override {                        \
+      return fn(n);                                                \
+    }                                                              \
+  };                                                               \
+  C10_REGISTER_CLASS(                                              \
+      SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);
+
 inline at::Tensor create_empty_from(const at::Tensor& t) {
   return at::detail::empty_cpu(
       {0},
@@ -117,13 +126,17 @@
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
 }
 
+// check if an op has an out variant registered in Static Runtime
 bool opIsRegistered(const c10::Symbol& op_name);
+// check if Static Runtime can run an op natively.
+// prim ops that the JIT interpreter implements directly are registered as
+// native ops in Static Runtime
+bool nativeOpIsRegistered(const c10::Symbol& op_name);
 
 bool canReuseInputsOutputs(Node* n);
 bool isOptimizableContainerType(Node* n);
 
 std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
-
 std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
 
 inline std::string PrintNode(const Node* node) {