blob: a84fb2f1aa3df2c39049720328e651ad6562c742 [file] [log] [blame]
#include "torch/csrc/jit/passes/canonicalize.h"
namespace torch { namespace jit {
// Canonicalize a graph, renumbering it so that all structurally equivalent
// graphs have same numbers.
std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
auto r = std::make_shared<Graph>(graph->current_scope());
std::unordered_map<Value*, Value*> rn_env;
auto rn_fn = [&](Value* v) { return rn_env.at(v); };
for (auto* input : graph->inputs()) {
auto* r_input = r->addInput();
r_input->copyMetadata(input);
rn_env[input] = r_input;
}
for (auto* node : graph->nodes()) {
auto* r_node = r->createClone(node, rn_fn);
r->appendNode(r_node);
auto outputs = node->outputs();
auto r_outputs = r_node->outputs();
for (size_t i = 0; i < outputs.size(); i++) {
rn_env[outputs.at(i)] = r_outputs.at(i);
}
if (node->hasAttribute(attr::Subgraph)) {
r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
}
}
for (auto* output : graph->outputs()) {
r->registerOutput(rn_fn(output));
}
return r;
}
}}