[Static Runtime] Check unsupported op when enabling static runtime (#61613)
Summary:
Add a canEnableStaticRuntime() check that scans the graph for unsupported ops and nested sub-blocks before Static Runtime is enabled, and move the native-op implementations out of getNativeOperation() in ops.cpp into a new native_ops.cpp behind an SRNativeOperatorRegistry.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61613
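
For context, a minimal usage sketch modeled on the new testCanEnableStaticRuntime helper in this diff; the helper name, includes, and jit_script string here are illustrative assumptions, not part of the change:

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/runtime/static/impl.h>

// Returns true if Static Runtime can run the module's forward() graph.
// As in the test helper, the graph is not frozen here, so a graph that still
// contains prim::CallMethod nodes (or other unsupported ops / sub-blocks)
// reports false.
bool canUseStaticRuntime(const std::string& jit_script) {
  torch::jit::script::Module module("module");
  module.define(jit_script);
  auto graph = module.get_method("forward").graph();
  return torch::jit::canEnableStaticRuntime(graph);
}
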
Reviewed By: ajyu, movefast1990
Differential Revision: D29663466
fbshipit-source-id: d819903b7227f534c0a4fffa5eeea2b5c0c04750
diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h
index f24060a..14bc0e5 100644
--- a/benchmarks/static_runtime/test_scripts.h
+++ b/benchmarks/static_runtime/test_scripts.h
@@ -493,3 +493,17 @@
def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
)JIT";
+
+const auto if_script = R"JIT(
+ def forward(self, a: Tensor, b: Tensor, x: bool):
+ c = (a + b).relu().half().float()
+ d = b * c
+ if x:
+ e = a.flatten().half() * b.flatten().half()
+ else:
+ e = a.flatten().half() + b.flatten().half()
+ f = e.float().relu()
+ g = {"d": d, "b": b}
+ h = {"e": e, "f": f}
+ return [g, h]
+)JIT";
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index b4e27e9..4f50060 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -178,6 +178,16 @@
return nullptr;
}
+bool testCanEnableStaticRuntime(const std::string& jit_script) {
+ script::Module module("module");
+ module.define(jit_script);
+
+  Method method = module.get_method("forward");
+  auto graph = method.graph();
+
+  // here we do not freeze the graph
+ return torch::jit::canEnableStaticRuntime(graph);
+}
} // namespace
TEST(StaticRuntime, InPlace) {
@@ -186,6 +196,11 @@
EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script));
}
+TEST(StaticRuntime, CanEnableStaticRuntime) {
+ EXPECT_TRUE(testCanEnableStaticRuntime(reshape_inplace_script));
+ EXPECT_FALSE(testCanEnableStaticRuntime(if_script));
+}
+
TEST(StaticRuntime, NestedOutput) {
auto run_test = [](std::vector<int64_t> shapes) {
auto a = at::randn(shapes);
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index cd7d017..0dcfaf3 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -313,6 +313,7 @@
core_sources_full = core_sources_full_mobile + [
"torch/csrc/jit/runtime/static/fusion.cpp",
"torch/csrc/jit/runtime/static/impl.cpp",
+ "torch/csrc/jit/runtime/static/native_ops.cpp",
"torch/csrc/jit/runtime/static/ops.cpp",
"torch/csrc/jit/runtime/static/passes.cpp",
"torch/csrc/jit/tensorexpr/external_functions.cpp",
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 3b1468c..2e329b2 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -21,6 +21,36 @@
namespace torch {
namespace jit {
+// The graph must be frozen; otherwise canEnableStaticRuntime returns false if
+// there is any prim::CallMethod op left in the graph
+bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph) {
+ // check for sub-blocks
+ bool can_support = true;
+ bool has_blocks = false;
+ for (auto* node : graph->block()->nodes()) {
+ if (node->blocks().size() > 0) {
+ has_blocks = true;
+ VLOG(1) << "Found nested sub-blocks in graph at node: "
+ << PrintNode(node);
+ }
+ if (node->kind() == prim::Constant) {
+ continue;
+ }
+    // check if we can get an op from the Node
+ const Operator* op = node->maybeOperator();
+ if (!op && !nativeOpIsRegistered(node->kind())) {
+ can_support = false;
+ LOG(WARNING) << "Found unsupported op: " << node->kind().toQualString();
+ }
+ }
+ if (has_blocks) {
+ LOG(WARNING)
+ << "Found nested sub-block in graph. Static Runtime doesn't support nested sub-blocks.";
+ can_support = false;
+ }
+ return can_support;
+}
+
namespace {
void OptimizeGraph(
@@ -46,20 +76,6 @@
ConstantPropagation(graph);
}
-bool CheckGraphEligibility(const std::shared_ptr<torch::jit::Graph>& graph) {
- // check for sub-blocks
- bool can_support = true;
- for (auto* node : graph->block()->nodes()) {
- for (Block* sub_block : node->blocks()) {
- VLOG(1) << "Found nested sub-blocks in graph at node: "
- << PrintNode(node);
- can_support = false;
- }
- }
-
- return can_support;
-}
-
// remove unused input 0 from graph
bool RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
if (graph->inputs().at(0)->type()->is_module()) {
@@ -464,8 +480,7 @@
void PrepareGraphForStaticModule(
std::shared_ptr<torch::jit::Graph> graph,
const StaticModuleOptions& opts) {
- // TODO: call CheckGraphEligibility before trying to enable static runtime
- TORCH_CHECK(CheckGraphEligibility(graph));
+ TORCH_CHECK(canEnableStaticRuntime(graph));
OptimizeGraph(graph, opts);
}
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index e28dcc3..d4bc5cb 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -12,6 +12,8 @@
namespace torch {
namespace jit {
+bool canEnableStaticRuntime(const std::shared_ptr<torch::jit::Graph>& graph);
+
struct TORCH_API StaticModuleOptions {
// to batch allocate (deallocate) tensor storage for all non-escaping
// temporary tensors
diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp
new file mode 100644
index 0000000..e38d269
--- /dev/null
+++ b/torch/csrc/jit/runtime/static/native_ops.cpp
@@ -0,0 +1,340 @@
+#include <torch/csrc/jit/runtime/static/ops.h>
+
+#include <ATen/CPUFunctions.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ScalarOps.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/native/IndexingUtils.h>
+#include <ATen/native/Resize.h>
+#include <ATen/native/TensorAdvancedIndexing.h>
+#include <c10/util/irange.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/vararg_functions.h>
+
+namespace torch {
+namespace jit {
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
+
+bool nativeOpIsRegistered(const c10::Symbol& op_name) {
+ const std::string name(op_name.toQualString());
+ return SRNativeOperatorRegistry()->Has(name);
+}
+
+std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
+ auto op_name = n->kind().toQualString();
+ if (SRNativeOperatorRegistry()->Has(op_name)) {
+ return SRNativeOperatorRegistry()->Create(op_name)->Generate(n);
+ }
+ return nullptr;
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::TupleConstruct,
+ prim_TupleConstruct,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ // prepare inputs
+ std::vector<IValue> stack;
+ const size_t size = p_node->inputs().size();
+ stack.reserve(size);
+ for (const auto i : c10::irange(size)) {
+ stack.emplace_back(p_node->Input(i));
+ }
+ // run op
+ auto* node = p_node->node();
+ const auto& type = node->output()->type()->expect<TupleType>();
+ if (type->name().has_value()) {
+ namedTupleConstruct(stack, type, node->inputs().size());
+ } else {
+ tupleConstruct(stack, node->inputs().size());
+ }
+ // put output back
+ p_node->Output(0) = std::move(stack[0]);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::DictConstruct,
+ prim_DictConstruct,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ // prepare inputs
+ std::vector<IValue> stack;
+ const size_t size = p_node->inputs().size();
+ stack.reserve(size);
+ for (const auto i : c10::irange(size)) {
+ stack.emplace_back(p_node->Input(i));
+ }
+ // run op
+ auto* node = p_node->node();
+ dictConstruct(
+ stack,
+ node->output()->type()->expectRef<DictType>(),
+ node->inputs().size());
+ // put output back
+ p_node->Output(0) = std::move(stack[0]);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ aten::__getitem__,
+ aten_getitem,
+ [](Node* n) -> SROperator {
+ if (n->inputs().size() != 2) {
+ return nullptr;
+ }
+ // TODO: make __getitem__ work for other container types
+ if (n->input(0)->type()->castRaw<DictType>() == nullptr) {
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ auto dict = p_node->Input(0).toGenericDict();
+ auto key = p_node->Input(1);
+ auto value = dict.find(key);
+ TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
+ p_node->Output(0) = value->value();
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::ListConstruct,
+ prim_ListConstruct,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ // prepare inputs
+ std::vector<IValue> stack;
+ const size_t size = p_node->inputs().size();
+ stack.reserve(size);
+ for (const auto i : c10::irange(size)) {
+ stack.emplace_back(p_node->Input(i));
+ }
+ // run op
+ listConstruct(
+ stack,
+ p_node->node()->output()->type()->expectRef<ListType>(),
+ p_node->inputs().size());
+ // put output back
+ p_node->Output(0) = std::move(stack[0]);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::ListUnpack,
+ prim_ListUnpack,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ // prepare inputs
+ std::vector<IValue> stack;
+ const size_t size = p_node->inputs().size();
+ stack.reserve(size);
+ for (const auto i : c10::irange(size)) {
+ stack.emplace_back(p_node->Input(i));
+ }
+ // run op
+ size_t num_outputs = p_node->outputs().size();
+ listUnpack(stack, num_outputs);
+ // put output back
+ DCHECK_EQ(stack.size(), num_outputs);
+ // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
+ for (auto i = 0; i < num_outputs; i++) {
+ p_node->Output(i) = std::move(stack[i]);
+ }
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::GetAttr,
+ prim_GetAttr,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ auto module = p_node->Input(0).toObject();
+ Node* node = p_node->node();
+ const auto type = node->input()->type()->expect<ClassType>();
+ const auto& field = node->s(attr::name);
+ const auto slot = type->getAttributeSlot(field);
+ p_node->Output(0) = module->getSlot(slot);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ prim::SetAttr,
+ prim_SetAttr,
+ [](Node* n) -> SROperator {
+ return [](ProcessedNode* p_node) {
+ auto module = p_node->Input(0).toObject();
+ Node* node = p_node->node();
+ const auto type = node->inputs()[0]->type()->expect<ClassType>();
+ const auto& field = node->s(attr::name);
+ const auto slot = type->getAttributeSlot(field);
+ module->setSlot(slot, p_node->Input(1));
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ aten::transpose,
+ aten_transpose,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_i = p_node->Input(1).toInt();
+ const auto in2_i = p_node->Input(2).toInt();
+ p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::flatten, aten_flatten, [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_i = p_node->Input(1).toInt();
+ const auto in2_i = p_node->Input(2).toInt();
+ p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
+ };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ aten::permute,
+ aten_permute,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_iv = p_node->Input(1).toIntVector();
+ p_node->Output(0) = at::native::permute(in0_t, in1_iv);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+ aten::reshape,
+ aten_reshape,
+ [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_iv = p_node->Input(1).toIntVector();
+ p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
+ };
+ });
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::slice, aten_slice, [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in1_i = p_node->Input(1).toInt();
+ const auto in2_i = p_node->Input(2).toInt();
+ const auto in3_i = p_node->Input(3).toInt();
+ const auto in4_i = p_node->Input(4).toInt();
+ p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
+ };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
+ !n->matches(torch::schema(
+ "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& self = p_node->Input(0).toTensor(); // self
+ const auto dim = p_node->Input(1).toInt(); // dim
+ int64_t start = 0;
+ if (p_node->Input(2).isScalar()) {
+ start = p_node->Input(2).toInt();
+ } else {
+ auto& t = p_node->Input(2).toTensor();
+ start = t.item<int64_t>();
+ }
+ const auto length = p_node->Input(3).toInt(); // length
+ TORCH_CHECK(
+ self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
+ auto cur_size = self.sizes()[dim];
+ if (start != cur_size && start < 0) { // start being the end is valid, but
+ // not a valid dim specification.
+ start = at::maybe_wrap_dim(start, cur_size);
+ }
+ TORCH_CHECK(
+ length >= 0 && start <= cur_size - length,
+ "start (",
+ start,
+ ") + length (",
+ length,
+ ") exceeds dimension size (",
+ cur_size,
+ ").");
+ p_node->Output(0) = at::native::slice(self, dim, start, start + length, 1);
+ };
+});
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
+ if (!n->matches(torch::schema(
+ "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
+ !n->matches(torch::schema(
+ "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+ LogAndDumpSchema(n);
+ return nullptr;
+ }
+ return [](ProcessedNode* p_node) {
+ const auto& in0_t = p_node->Input(0).toTensor();
+ const auto in2_i = p_node->Input(2).toBool();
+ const auto in3_i = p_node->Input(3).toBool();
+ const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+ if (p_node->Input(1).isTensor()) {
+ // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
+ // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+ const auto in1_t = p_node->Input(1).toTensor();
+ p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
+ } else {
+ // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
+ // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+ const auto in1_i = p_node->Input(1).toScalarType();
+ p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
+ }
+    // if Output(0) is an alias of in0_t, copy the tensor.
+ if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
+ in0_t.unsafeGetTensorImpl()) {
+ p_node->Output(0) = in0_t.clone();
+ }
+ };
+});
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index 8cc6cae..236ac43 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -1062,246 +1062,6 @@
};
});
-std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
- if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) {
- if (!n->matches(torch::schema(
- "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_i = p_node->Input(1).toInt();
- const auto in2_i = p_node->Input(2).toInt();
- p_node->Output(0) = at::native::transpose(in0_t, in1_i, in2_i);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) {
- if (!n->matches(torch::schema(
- "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_i = p_node->Input(1).toInt();
- const auto in2_i = p_node->Input(2).toInt();
- p_node->Output(0) = at::native::flatten(in0_t, in1_i, in2_i);
- };
- } else if (n->kind() == prim::TupleConstruct) {
- return [](ProcessedNode* p_node) {
- // prepare inputs
- std::vector<IValue> stack;
- const size_t size = p_node->inputs().size();
- stack.reserve(size);
- for (const auto i : c10::irange(size)) {
- stack.emplace_back(p_node->Input(i));
- }
- // run op
- auto* node = p_node->node();
- const auto& type = node->output()->type()->expect<TupleType>();
- if (type->name().has_value()) {
- namedTupleConstruct(stack, type, node->inputs().size());
- } else {
- tupleConstruct(stack, node->inputs().size());
- }
- // put output back
- p_node->Output(0) = std::move(stack[0]);
- };
- } else if (n->kind() == prim::DictConstruct) {
- return [](ProcessedNode* p_node) {
- // prepare inputs
- std::vector<IValue> stack;
- const size_t size = p_node->inputs().size();
- stack.reserve(size);
- for (const auto i : c10::irange(size)) {
- stack.emplace_back(p_node->Input(i));
- }
- // run op
- auto* node = p_node->node();
- dictConstruct(
- stack,
- node->output()->type()->expectRef<DictType>(),
- node->inputs().size());
- // put output back
- p_node->Output(0) = std::move(stack[0]);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::__getitem__")) {
- if (n->inputs().size() != 2) {
- return nullptr;
- }
- // TODO: make __getitem__ work for other container types
- if (n->input(0)->type()->castRaw<DictType>() == nullptr) {
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- auto dict = p_node->Input(0).toGenericDict();
- auto key = p_node->Input(1);
- auto value = dict.find(key);
- TORCH_CHECK(value != dict.end(), "Key not in dict: ", key);
- p_node->Output(0) = value->value();
- };
- } else if (n->kind() == prim::ListConstruct) {
- return [](ProcessedNode* p_node) {
- // prepare inputs
- std::vector<IValue> stack;
- const size_t size = p_node->inputs().size();
- stack.reserve(size);
- for (const auto i : c10::irange(size)) {
- stack.emplace_back(p_node->Input(i));
- }
- // run op
- listConstruct(
- stack,
- p_node->node()->output()->type()->expectRef<ListType>(),
- p_node->inputs().size());
- // put output back
- p_node->Output(0) = std::move(stack[0]);
- };
- } else if (n->kind() == prim::ListUnpack) {
- return [](ProcessedNode* p_node) {
- // prepare inputs
- std::vector<IValue> stack;
- const size_t size = p_node->inputs().size();
- stack.reserve(size);
- for (const auto i : c10::irange(size)) {
- stack.emplace_back(p_node->Input(i));
- }
- // run op
- size_t num_outputs = p_node->outputs().size();
- listUnpack(stack, num_outputs);
- // put output back
- DCHECK_EQ(stack.size(), num_outputs);
- // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
- for (auto i = 0; i < num_outputs; i++) {
- p_node->Output(i) = std::move(stack[i]);
- }
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) {
- if (!n->matches(torch::schema(
- "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_iv = p_node->Input(1).toIntVector();
- p_node->Output(0) = at::native::permute(in0_t, in1_iv);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) {
- if (!n->matches(torch::schema(
- "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_iv = p_node->Input(1).toIntVector();
- p_node->Output(0) = at::native::reshape(in0_t, in1_iv);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) {
- if (!n->matches(torch::schema(
- "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in1_i = p_node->Input(1).toInt();
- const auto in2_i = p_node->Input(2).toInt();
- const auto in3_i = p_node->Input(3).toInt();
- const auto in4_i = p_node->Input(4).toInt();
- p_node->Output(0) = at::native::slice(in0_t, in1_i, in2_i, in3_i, in4_i);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) {
- if (!n->matches(torch::schema(
- "aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)")) &&
- !n->matches(torch::schema(
- "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& self = p_node->Input(0).toTensor(); // self
- const auto dim = p_node->Input(1).toInt(); // dim
- int64_t start = 0;
- if (p_node->Input(2).isScalar()) {
- start = p_node->Input(2).toInt();
- } else {
- auto& t = p_node->Input(2).toTensor();
- start = t.item<int64_t>();
- }
- const auto length = p_node->Input(3).toInt(); // length
- TORCH_CHECK(
- self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
- auto cur_size = self.sizes()[dim];
- if (start != cur_size && start < 0) { // start being the end is valid, but
- // not a valid dim specification.
- start = at::maybe_wrap_dim(start, cur_size);
- }
- TORCH_CHECK(
- length >= 0 && start <= cur_size - length,
- "start (",
- start,
- ") + length (",
- length,
- ") exceeds dimension size (",
- cur_size,
- ").");
- p_node->Output(0) =
- at::native::slice(self, dim, start, start + length, 1);
- };
- } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
- if (!n->matches(torch::schema(
- "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
- !n->matches(torch::schema(
- "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
- LogAndDumpSchema(n);
- return nullptr;
- }
- return [](ProcessedNode* p_node) {
- const auto& in0_t = p_node->Input(0).toTensor();
- const auto in2_i = p_node->Input(2).toBool();
- const auto in3_i = p_node->Input(3).toBool();
- const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
- if (p_node->Input(1).isTensor()) {
- // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
- // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
- const auto in1_t = p_node->Input(1).toTensor();
- p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
- } else {
- // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
- // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
- const auto in1_i = p_node->Input(1).toScalarType();
- p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
- }
- // in case that Output(0) is an alias of in0_t, copy the tensor.
- if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
- in0_t.unsafeGetTensorImpl()) {
- p_node->Output(0) = in0_t.clone();
- }
- };
- } else if (n->kind() == prim::GetAttr) {
- return [](ProcessedNode* p_node) {
- auto module = p_node->Input(0).toObject();
- Node* node = p_node->node();
- const auto type = node->input()->type()->expect<ClassType>();
- const auto& field = node->s(attr::name);
- const auto slot = type->getAttributeSlot(field);
- p_node->Output(0) = module->getSlot(slot);
- };
- } else if (n->kind() == prim::SetAttr) {
- return [](ProcessedNode* p_node) {
- auto module = p_node->Input(0).toObject();
- Node* node = p_node->node();
- const auto type = node->inputs()[0]->type()->expect<ClassType>();
- const auto& field = node->s(attr::name);
- const auto slot = type->getAttributeSlot(field);
- module->setSlot(slot, p_node->Input(1));
- };
- }
- return nullptr;
-}
-
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
// TODO: Support only 9 args once the old signature has been removed.
diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h
index 021cd21..5d6bbf8 100644
--- a/torch/csrc/jit/runtime/static/ops.h
+++ b/torch/csrc/jit/runtime/static/ops.h
@@ -29,8 +29,6 @@
C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
-// TODO: reuse_inp reuse_out can be deprecated with further analysis
-// try to avoid this API.
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
struct SROperatorFunctor_##id : public SROperatorFunctor { \
const SROpFunctor fn = __VA_ARGS__; \
@@ -40,6 +38,17 @@
}; \
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
+C10_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
+#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...) \
+ struct SRNativeOperatorFunctor_##id : public SROperatorFunctor { \
+ const SROpFunctor fn = __VA_ARGS__; \
+ SROperator Generate(Node* n) override { \
+ return fn(n); \
+ } \
+ }; \
+ C10_REGISTER_CLASS( \
+ SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);
+
inline at::Tensor create_empty_from(const at::Tensor& t) {
return at::detail::empty_cpu(
{0},
@@ -117,13 +126,17 @@
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}
+// check if an op has an out variant registered in Static Runtime
bool opIsRegistered(const c10::Symbol& op_name);
+// check if Static Runtime can run an op natively.
+// prim ops that are implemented directly in the jit interpreter are registered
+// as native ops in Static Runtime
+bool nativeOpIsRegistered(const c10::Symbol& op_name);
bool canReuseInputsOutputs(Node* n);
bool isOptimizableContainerType(Node* n);
std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n);
-
std::function<void(ProcessedNode*)> getNativeOperation(Node* n);
inline std::string PrintNode(const Node* node) {