Revert D27902824: static runtime support for fb::equally_split
Test Plan: revert-hammer
Differential Revision:
D27902824 (https://github.com/pytorch/pytorch/commit/a4e47ea152d9942f85e5c8718ba206da95b53f9e)
Original commit changeset: 7855047c3bd4
fbshipit-source-id: a46834418ce98826871cd604d1a01f0ff8f23d7f
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index c2ff7ef..0ac8fbf 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -38,7 +38,7 @@
#ifdef FBCODE_CAFFE2
if (opts.enable_out_variant) {
ReplaceWithCopy(graph);
- FuseListUnpack(graph);
+ FuseSigridTransformsListUnpack(graph);
}
#endif
ConstantPropagation(graph);
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 4ccdbdc..6f4c204 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -477,15 +477,14 @@
}
}
-void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
+void FuseSigridTransformsListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
auto nodes = graph->nodes();
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
- Node* node = *it;
- const char* node_qual_string = node->kind().toQualString();
- if (strcmp(node_qual_string, "fb::sigrid_transforms") == 0 ||
- strcmp(node_qual_string, "fb::sigrid_transforms_torch_bind") == 0 ||
- strcmp(node_qual_string, "fb::equally_split") == 0) {
- const Value* sigrid_out = node->outputs()[0];
+ Node* sigrid_node = *it;
+ auto kind = sigrid_node->kind();
+ if (strcmp(kind.toQualString(), "fb::sigrid_transforms") == 0 ||
+ strcmp(kind.toQualString(), "fb::sigrid_transforms_torch_bind") == 0) {
+ const Value* sigrid_out = sigrid_node->outputs()[0];
if (sigrid_out->uses().size() > 1) {
continue;
}
@@ -502,7 +501,7 @@
// handle outputs
for (Value* out : list_unpack_outputs) {
- Value* new_out = node->addOutput();
+ Value* new_out = sigrid_node->addOutput();
new_out->copyMetadata(out);
out->replaceAllUsesWith(new_out);
}
@@ -511,9 +510,10 @@
++it_next; // it_next points to list_unpack
it_next.destroyCurrent(); // remove list_unpack
- node->eraseOutput(0);
+ sigrid_node->eraseOutput(0);
}
}
}
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 2becd86..9f4d519 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -5,7 +5,8 @@
TORCH_API void FuseInferenceOpsForSparseNN(
std::shared_ptr<torch::jit::Graph>& graph);
-TORCH_API void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph);
+TORCH_API void FuseSigridTransformsListUnpack(
+ std::shared_ptr<torch::jit::Graph>& graph);
TORCH_API void ReplaceWithCopy(std::shared_ptr<torch::jit::Graph>& graph);