// blob: 03463af842a4dc6cc85fc8d0d0010d3714b39e63 [file] [log] [blame]
#include "torch/csrc/jit/init_pass.h"
#include <unordered_map>
namespace torch { namespace jit {
namespace {
// Signature of a replacement-node builder: it receives the Graph being
// rewritten and the PythonOp that is about to be replaced, and returns the
// Node that takes the PythonOp's place.
using constructor_type = std::function<Node*(Graph*, PythonOp*)>;
// Builder for ops that need no extra arguments: asks the graph to create a
// node from the symbol matching the PythonOp's name, passing the PythonOp's
// inputs through unchanged.
Node * trivial_ctor(Graph *graph, PythonOp *p) {
  const auto symbol = stringToSymbol(p->name());
  return graph->create(symbol, p->inputs());
}
// Builder used for "Chunk": reads the PythonOp's first two scalar arguments
// as Python integers (number of chunks, then dimension) and asks the graph
// to create a chunk node over the PythonOp's input.
// NOTE(review): PyLong_AsLong reports failure by returning -1 with a Python
// error set, and scalar_args is indexed without a size check; neither is
// verified here -- confirm callers guarantee two integer scalar arguments.
Node * chunk_ctor(Graph * graph, PythonOp * p) {
  const long chunk_count = PyLong_AsLong(p->scalar_args[0]);
  const long chunk_dim = PyLong_AsLong(p->scalar_args[1]);
  return graph->createChunk(p->input(), chunk_count, chunk_dim);
}
// Lookup table from a PythonOp name to the function that builds its
// replacement Node. MatchJITOps skips any PythonOp whose name has no entry
// here. The table is only ever read after initialization, so it is const to
// rule out accidental mutation of this file-local global.
const std::unordered_map<std::string, constructor_type> constructors = {
  {"Add", trivial_ctor},
  {"Mul", trivial_ctor},
  {"Sigmoid", trivial_ctor},
  {"Tanh", trivial_ctor},
  {"Chunk", chunk_ctor},
  {"Negate", trivial_ctor},
};
} // anonymous namespace
// Rewrites `graph` in place. Every PythonOp whose name has an entry in
// `constructors` is replaced by the Node that entry builds, the Select nodes
// hanging off the PythonOp are patched up or removed, and the PythonOp itself
// is destroyed. All other nodes are left untouched.
void MatchJITOps(std::shared_ptr<Graph>& graph) {
  auto node_list = graph->nodes();
  for (auto cursor = node_list.begin(); cursor != node_list.end(); ++cursor) {
    PythonOp *p = (*cursor)->cast<PythonOp>();
    if (p == nullptr) {
      continue;
    }
    auto entry = constructors.find(p->name());
    if (entry == constructors.end()) {
      continue;
    }

    // Build the node that will stand in for p and place it right after p.
    auto replacement = entry->second(graph.get(), p);
    replacement->insertAfter(p);

    if (!replacement->hasMultipleOutputs()) {
      // A PythonOp is always multi-return, so p is expected to have exactly
      // two users: a Select for the tensor and a Select for the handle. The
      // tensor Select is replaced by the new node; the handle Select is
      // simply removed.
      JIT_ASSERT(p->uses().size() == 2);
      const auto& selects = p->uses();
      auto tensor_select = selects[0].user;
      auto handle_select = selects[1].user;
      JIT_ASSERT(tensor_select->type()->kind() == TypeKind::TensorType);
      JIT_ASSERT(handle_select->type()->kind() == TypeKind::HandleType);
      JIT_ASSERT(tensor_select->kind() == kSelect);
      JIT_ASSERT(handle_select->kind() == kSelect);
      JIT_ASSERT(handle_select->uses().size() == 0);
      replacement->setType(tensor_select->type());
      tensor_select->replaceAllUsesWith(replacement);
      tensor_select->destroy();
      handle_select->destroy();
    } else {
      // Invariant: replacement nodes never produce opaque handles, so every
      // handle-typed user of p is dropped before the swap. E.g. an op that
      // used to return %1 : Tensor, %2 : Handle returns only %1 : Tensor
      // afterwards.
      // NB: this implies that init_pass doesn't work with backwards.
      // The loop walks a copy of the use list, as the code always has --
      // presumably because destroying a user edits p's own list (confirm).
      const auto use_snapshot = p->uses();
      for (const auto& use : use_snapshot) {
        if (use.user->type()->kind() == TypeKind::HandleType) {
          JIT_ASSERT(use.user->uses().size() == 0);
          use.user->destroy();
        }
      }
      p->replaceAllUsesWith(replacement);
    }

    // Remove p itself through the iterator.
    // NOTE(review): relies on destroyCurrent() leaving the iterator in a
    // state where ++ continues the traversal -- confirm against its definition.
    cursor.destroyCurrent();
  }
}
}}