Revert D26333953: [StaticRuntime] Clean up output references and remove dead code
Test Plan: revert-hammer
Differential Revision: D26333953 (https://github.com/pytorch/pytorch/commit/0c9d72b5e11e9d3e706eaf1c07406a0cebfe27c7)
Original commit changeset: cadc0595ad6a
fbshipit-source-id: 75d0b33099342653cd8867b129139325789aee6c
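
This revert restores the register-based value storage in StaticRuntime (reg_, value_to_reg, input_regs/output_regs, AssignRegisters, DeduceInternalValues) together with the output-aliasing checks in test_static_runtime.cc. As a reading aid only (not part of the patch), the sketch below isolates the greedy, liveness-driven register-reuse rule that the restored AssignRegisters() applies per value; the function name and the use of plain int value IDs instead of torch::jit::Value* are illustrative.

    #include <cstddef>
    #include <map>
    #include <set>
    #include <unordered_set>

    // Sketch: pick a register for value `v`. Reuse any register whose current
    // occupants are all dead at v's definition; otherwise allocate a new one.
    size_t assignRegisterGreedy(
        int v,
        const std::unordered_set<int>& live_with_v,   // values live alongside v
        std::map<size_t, std::set<int>>& reg_to_vals, // register -> values placed in it
        size_t& num_regs) {
      for (auto& entry : reg_to_vals) {
        bool conflict = false;
        for (int occupant : entry.second) {
          if (live_with_v.count(occupant)) {
            conflict = true;
            break;
          }
        }
        if (!conflict) {             // no occupant is still live: safe to reuse
          entry.second.insert(v);
          return entry.first;
        }
      }
      const size_t reg = num_regs++; // nothing reusable: take a fresh register
      reg_to_vals[reg].insert(v);
      return reg;
    }

The real AssignRegisters() additionally skips reuse for values that GetOptimizableValues() rejects or that LivenessMap cannot track, restricts reuse to float tensors, and counts reused registers in InferenceModule::reused_regs.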
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index a220a17..c8f2d5a 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -220,16 +220,8 @@
at::Tensor output_1 = getTensor(module.forward(inputs));
// run static runtime
- c10::IValue output_ivalue = runtime.run(inputs, {});
-
- at::Tensor output_2 = getTensor(output_ivalue);
+ at::Tensor output_2 = getTensor(runtime.run(inputs, {}));
EXPECT_TRUE(output_1.equal(output_2));
-
- // check for output aliasing
- EXPECT_EQ(output_ivalue.use_count(), 1);
- output_ivalue = IValue();
-
- EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
}
// check for input aliasing (deep & wide does not have ops
@@ -264,16 +256,8 @@
{"wide", wide}});
// run static runtime
- c10::IValue output_ivalue = runtime.run({}, kwargs);
-
- at::Tensor output_2 = getTensor(output_ivalue);
+ at::Tensor output_2 = getTensor(runtime.run({}, kwargs));
EXPECT_TRUE(output_1.equal(output_2));
-
- // check for output aliasing
- EXPECT_EQ(output_ivalue.use_count(), 1);
- output_ivalue = IValue();
-
- EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
}
EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp
index 6134753..b4ef21e 100644
--- a/torch/csrc/jit/runtime/static/fusion.cpp
+++ b/torch/csrc/jit/runtime/static/fusion.cpp
@@ -23,7 +23,7 @@
Operation createStaticSubgraphRuntime(const Node* node) {
auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph));
auto runtime = std::make_shared<torch::jit::StaticRuntime>(g);
- auto num_inputs = runtime->num_inputs();
+ auto num_inputs = runtime->get_inference_module()->input_regs.size();
return [runtime, num_inputs](Stack* stack) {
RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
auto inps = torch::jit::last(stack, num_inputs);
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 87526d4..56c9fa0 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -198,12 +198,143 @@
}
return can_reuse;
}
+
+size_t AssignRegisters(
+ const std::shared_ptr<torch::jit::Graph>& graph,
+ std::unordered_map<Value*, size_t>& value_to_reg,
+ std::vector<Value*>& values,
+ std::vector<size_t>& input_regs,
+ std::vector<size_t>& output_regs,
+ bool optimize_memory) {
+ auto lm = LivenessMap(graph);
+ auto optimizable_values = GetOptimizableValues(graph);
+
+ size_t num_regs = 0;
+ size_t reused_regs = 0;
+ std::unordered_map<size_t, std::set<Value*>> reg_to_val;
+ auto getReg = [&](Value* v) -> size_t {
+ if (!optimize_memory) {
+ return num_regs++;
+ }
+ TORCH_CHECK(!value_to_reg.count(v));
+ auto iter = lm.first.find(v);
+ if (iter == lm.first.end()) {
+ return num_regs++;
+ }
+ if (!optimizable_values.count(v)) {
+ return num_regs++;
+ }
+ if (lm.second.count(v)) {
+ return num_regs++;
+ }
+ const auto& live_values = iter->second;
+ // iterate through all the allocated registers
+ // and check for potential re-use, greedily
+ for (const auto& v2r : value_to_reg) {
+ auto candidate_v = v2r.first;
+
+ if (!optimizable_values.count(candidate_v)) {
+ continue;
+ }
+ if (lm.second.count(candidate_v)) {
+ continue;
+ }
+
+ // Only re-use float* tensors
+ auto t = candidate_v->type()->cast<TensorType>();
+ if (!t) {
+ continue;
+ }
+ // TODO audit this assumption (passes tests, but is scary)
+ if (t->scalarType() && *(t->scalarType()) != at::kFloat) {
+ continue;
+ }
+ // TODO
+ // if (*(t->scalarType()) != at::kFloat) {
+ // continue;
+ //}
+ if (!live_values.count(candidate_v)) {
+ bool already_used = false;
+ for (auto use : reg_to_val.at(v2r.second)) {
+ if (live_values.count(use)) {
+ already_used = true;
+ }
+ }
+ if (already_used) {
+ continue;
+ }
+ reused_regs++;
+ return v2r.second;
+ }
+ }
+ return num_regs++;
+ };
+
+ // assign register to Value*
+ for (Value* input : graph->inputs()) {
+ TORCH_CHECK(value_to_reg.count(input) == 0);
+ auto reg = getReg(input);
+ value_to_reg[input] = reg;
+ reg_to_val[reg].insert(input);
+ input_regs.push_back(reg);
+ }
+ for (Node* node : graph->nodes()) {
+ for (Value* input : node->inputs()) {
+ TORCH_CHECK(value_to_reg.count(input) > 0);
+ }
+ for (Value* output : node->outputs()) {
+ TORCH_CHECK(
+ value_to_reg.count(output) == 0, "the graph needs to be in SSA form");
+ auto reg = getReg(output);
+ value_to_reg[output] = reg;
+ reg_to_val[reg].insert(output);
+ }
+ }
+ TORCH_CHECK(graph->outputs().size() > 0);
+ for (Value* output : graph->outputs()) {
+ TORCH_CHECK(value_to_reg.count(output) > 0);
+ output_regs.push_back(value_to_reg[output]);
+ }
+
+ values.resize(value_to_reg.size());
+ for (const auto& p : value_to_reg) {
+ values[p.second] = p.first;
+ }
+ return reused_regs;
+}
+
+// Internal values are discarded after run if
+// opts_.cleanup_activations is true.
+void DeduceInternalValues(
+ const std::shared_ptr<torch::jit::Graph>& graph,
+ const std::unordered_map<Value*, size_t>& value_to_reg,
+ std::vector<size_t>& internals) {
+ std::unordered_set<Value*> outputs{
+ graph->outputs().begin(), graph->outputs().end()};
+ for (Node* node : graph->nodes()) {
+ if (node->kind() != prim::Constant) {
+ for (Value* output : node->outputs()) {
+ if (outputs.count(output) == 0) {
+ internals.push_back(value_to_reg.at(output));
+ }
+ }
+ }
+ }
+}
} // namespace
void InferenceModule::init() {
OptimizeGraph(graph);
CheckGraphEligibility(graph);
RemoveSelfFromGraphInput(graph);
+ reused_regs = AssignRegisters(
+ graph,
+ value_to_reg,
+ values,
+ input_regs,
+ output_regs,
+ opts.optimize_memory);
+ DeduceInternalValues(graph, value_to_reg, internals);
}
InferenceModule::InferenceModule(
@@ -298,6 +429,10 @@
}
}
+size_t StaticRuntime::num_outputs() const {
+ return module_->output_regs.size();
+}
+
std::vector<at::Tensor> StaticRuntime::run(
const std::vector<at::Tensor>& inps) {
std::vector<c10::IValue> stack;
@@ -375,13 +510,11 @@
std::vector<c10::IValue> outputs;
outputs.reserve(num_outputs());
for (auto i = 0; i < num_outputs(); ++i) {
- // use move here. Otherwise, clean up outputs_[i] explicitly
- outputs.emplace_back(std::move(*outputs_[i]));
+ outputs.emplace_back(Output(i));
}
- return c10::ivalue::Tuple::create(std::move(outputs));
+ return c10::ivalue::Tuple::create(outputs);
}
- // use move here. Otherwise, clean up outputs_[0] explicitly
- return std::move(*outputs_[0]);
+ return Output(0);
}
void StaticRuntime::benchmark(
@@ -428,8 +561,8 @@
<< " bytes" << std::endl;
}
if (module_->opts.optimize_memory) {
- // std::cout << "Total number of reused registers: " << module_->reused_regs
- // << std::endl;
+ std::cout << "Total number of reused registers: " << module_->reused_regs
+ << std::endl;
}
}
@@ -537,7 +670,6 @@
// collect register indices of outputs of ops with out variant
std::unordered_set<Value*> managed_values;
std::unordered_set<IValue*> unmanaged_value_set;
- std::unordered_map<Value*, IValue*> values_map;
for (ProcessedNode& pnode : runtime->get_nodes()) {
bool should_manage = pnode.has_out_variant();
if (should_manage && isViewOp(pnode.get_node())) {
@@ -553,23 +685,31 @@
}
if (should_manage) {
// Types are stored in the underlying TorchScript IR
- for (size_t i = 0; i < pnode.outputs().size(); i++) {
- Value* out = pnode.get_node()->output(i);
+ for (Value* out : pnode.get_node()->outputs()) {
if (out->type()->cast<TensorType>()) {
managed_values.insert(out);
- values_map[out] = &pnode.Output(i);
}
}
} else {
for (auto i = 0; i < pnode.outputs().size(); ++i) {
unmanaged_value_set.insert(&pnode.Output(i));
- values_map[pnode.get_node()->output(i)] = &pnode.Output(i);
}
}
}
const InferenceModule* module = runtime->get_inference_module();
+ // remove model outputs from managed_values
+ for (Value* output : module->graph->outputs()) {
+ managed_values.erase(output);
+ }
+ for (IValue* output : runtime->outputs()) {
+ unmanaged_value_set.erase(output);
+ }
+ for (IValue* out : unmanaged_value_set) {
+ unmanaged_values_.emplace_back(out);
+ }
+
// remove tensors in output List/Tuple from managed_values
for (Value* output : module->graph->outputs()) {
Node* output_node = output->node();
@@ -577,28 +717,10 @@
output_node->kind() == prim::ListConstruct) {
for (Value* input : output_node->inputs()) {
managed_values.erase(input);
- // Elements in Tuples and Lists are refcounted. MemoryPlanner should not
- // hold refs of elements in output Tuples/Lists
- if (graph_input_values.count(input) == 0) {
- unmanaged_value_set.insert(values_map[input]);
- }
}
}
}
- // remove model outputs from managed_values and unmanaged_value_set
- for (Value* output : module->graph->outputs()) {
- managed_values.erase(output);
- }
- for (IValue* output : runtime->outputs()) {
- unmanaged_value_set.erase(output);
- }
-
- // unmanaged_value_set => unmanaged_values_
- for (IValue* out : unmanaged_value_set) {
- unmanaged_values_.emplace_back(out);
- }
-
// some Values should share storage, this map will
// keep track of the index into managed_storage_
std::unordered_map<Value*, size_t> shared;
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 4385c3f..a8e2cb3 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -71,6 +71,12 @@
std::shared_ptr<torch::jit::Graph> graph;
std::unique_ptr<c10::FunctionSchema> schema;
+ std::unordered_map<Value*, size_t> value_to_reg;
+ std::vector<Value*> values; // useful for debugging
+ std::vector<size_t> input_regs; // inputs to the graph
+ std::vector<size_t> output_regs; // outputs of the graph
+ std::vector<size_t> internals;
+ size_t reused_regs = 0;
InferenceModuleOptions opts;
private:
@@ -153,13 +159,11 @@
return nodes_;
}
- [[nodiscard]] size_t num_inputs() const {
- return inputs_.size();
+ const std::vector<IValue>& get_registers() {
+ return reg_;
}
- [[nodiscard]] size_t num_outputs() const {
- return outputs_.size();
- }
+ size_t num_outputs() const;
inline const std::vector<IValue*>& outputs() const {
return outputs_;
@@ -170,6 +174,7 @@
std::shared_ptr<InferenceModule> module_;
StaticRuntimeOptions opts_;
// IValue table (including inputs, outputs, intermediates, and weights)
+ std::vector<IValue> reg_;
std::vector<IValue> constants_;
std::vector<IValue> inputs_;
std::vector<IValue*> outputs_;