#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/jit/passes/onnx.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/symbolic.h"
#include "torch/csrc/utils/functional.h"
#include <unordered_map>
#include <sstream>
namespace torch { namespace jit {
namespace {
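// Checks whether the node's last output is a Handle, i.e. the opaque value
// that carries state needed for the op's backward.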
bool hasHandleOutput(Node *node) {
auto last_output = node->outputs().back();
return last_output->typeOption() && last_output->typeOption()->kind() == TypeKind::HandleType;
}
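// A Handle output only matters if something downstream actually uses it.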
bool hasUsedHandle(Node *node) {
if (!hasHandleOutput(node)) return false;
return node->outputs().back()->uses().size() > 0;
}
} // anonymous namespace
// Transform PythonOps and CppOps into Nodes that match ONNX semantics.
// Argument aten indicates whether we should export ops as "ATen" ONNX ops if possible.
void ToONNX(std::shared_ptr<tracer::TracingState>& state, bool aten) {
// Check that the tracing state is still live (it should be, because the
// trace was supposed to be requested with zero derivatives).
if (state->is_expired()) {
throw std::logic_error("ToONNX: tracing state is expired");
}
auto new_graph = std::make_shared<Graph>(state->graph->scope_root());
std::unordered_map<void*, Value*> new_buffer_map;
torch::autograd::SymbolicContext ctx;
ctx.graph = new_graph.get();
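// env maps Values of the original trace graph to their replacements in the
// new graph; a nullptr entry marks an output dropped by the conversion.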
std::unordered_map<Value*, Value*> env;
py::object onnx = py::module::import("torch.onnx");
py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");
// Returns the Value that n maps to in the new graph
auto envFn = [&env](Value * n) -> Value* {
auto it = env.find(n);
JIT_ASSERTM(it != env.end(), "Dangling node reference");
JIT_ASSERTM(it->second, "Unused node was subsequently used");
return it->second;
};
// Initialize context and environment
for (auto input : state->graph->inputs()) {
auto n = ctx.graph->addInput()->copyMetadata(input);
n->setStage(input->stage());
env[input] = n;
}
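// Re-key the buffer map so its entries point at values in the new graph.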
for (auto kv : state->buffer_map) {
new_buffer_map[kv.first] = envFn(kv.second);
}
// Put the new outputs in our environment map, and copy the type from the
// input graph if they were not set by the symbolic. This is called only
// with results of symbolic call (not for nodes that are just cloned).
auto setOutputs = [&](const std::string& op_name, Node * node, const value_list & outputs) {
auto old_outputs = node->outputs();
// Count all outputs, excluding Handles
bool has_handle = hasHandleOutput(node);
auto num_old_outputs = old_outputs.size() - (has_handle ? 1 : 0);
if (outputs.size() != num_old_outputs) {
std::ostringstream ss;
ss << "symbolic for " << op_name << " produced an incorrect number of outputs (expected ";
ss << num_old_outputs << ", but got " << outputs.size() << ")";
throw std::runtime_error(ss.str());
}
for (std::size_t i = 0; i < num_old_outputs; ++i) {
auto old = old_outputs[i];
if (outputs[i]) {
// Allow symbolic() to skip specifying the type of the return node.
// Unfortunately, they are on the hook for all internal nodes
// (though in practice, the types are not computed.)
if (!outputs[i]->hasType()) {
outputs[i]->setType(old->typeOption());
}
// Copy over source location information to all nodes created by
// the symbolic
outputs[i]->node()->setSourceLocation(node->getSourceLocation());
env[old] = outputs[i];
} else {
// Null output means that the ONNX op doesn't have outputs corresponding
// to certain PyTorch outputs
env[old] = nullptr;
if (!old->uses().empty()) {
std::ostringstream ss;
ss << "symbolic for " << op_name << " returned None for the output " << i;
ss << " (indicating conversion for that particular output is not supported), ";
ss << "but the network uses this output later";
// TODO: Say what actually used it
throw std::runtime_error(ss.str());
}
}
}
if (has_handle) {
JIT_ASSERT(old_outputs.back()->uses().empty());
env[old_outputs.back()] = nullptr;
}
};
// Clone the node and add it to the new graph
auto cloneNode = [&](Node * node) {
auto n_ = ctx.graph->appendNode(ctx.graph->createClone(node, envFn));
for(size_t i = 0; i < node->outputs().size(); i++) {
env[node->outputs()[i]] = n_->outputs()[i];
}
};
// Cast output of symbolic() python implementation
auto processSymbolicOutput = [&](const std::string& op_name, Node* n, const py::object& raw_output) {
if (raw_output.ptr() == Py_None) {
cloneNode(n);
return;
}
// Cast the outputs back to C++ and put them in the new graph
std::vector<Value*> outputs;
try {
if (py::isinstance<Value>(raw_output)) {
outputs = value_list{py::cast<Value*>(raw_output)};
} else {
outputs = py::cast<std::vector<Value*>>(raw_output);
}
} catch (const std::exception& ex) {
std::ostringstream ss;
ss << "Error casting results of symbolic for " << op_name
<< ": expected to return list of op nodes, instead received type ''"
<< py::str(raw_output.get_type()) << "': " << py::str(raw_output);
throw std::runtime_error(ss.str());
}
setOutputs(op_name, n, outputs);
};
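// Convert any other node by handing it to torch.onnx._run_symbolic_function,
// which dispatches on the node's kind (and may export the op as an "ATen"
// ONNX op when aten is set).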
auto callPySymbolicFunction = [&](Node* n) {
// The idea is to delegate as much of the actual argument massaging to
// Python as possible
py::tuple py_inputs(n->inputs().size());
Py_ssize_t input_nr = 0;
for (auto* input : n->inputs()) {
py_inputs[input_nr++] = py::cast(envFn(input));
}
auto scope_guard = ctx.graph->set_current_scope_temporary(n->scope());
py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.graph, n, py_inputs, aten);
processSymbolicOutput(symbolToString(n->kind()), n, raw_output);
};
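// Convert a PythonOp by invoking the symbolic() method defined on its
// autograd Function. The symbolic receives the graph followed by the op's
// arguments, so on the Python side it looks roughly like this (an
// illustrative sketch, not a function from this file):
//
//   @staticmethod
//   def symbolic(g, input):
//       return g.op("Tanh", input)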
auto callPySymbolicMethod = [&](PythonOp* op) {
// Test if there is a symbolic function; bail if there is not
auto pyobj = py::handle(op->pyobj.get());
if (!py::hasattr(pyobj, "symbolic")) {
cloneNode(op);
return;
}
// Prepare args for Python. First one is the graph, and is followed
// by regular args, with Variables replaced by corresponding nodes.
Py_ssize_t input_nr = 0;
py::tuple py_symbolic_args(1 + op->cconv.size());
py_symbolic_args[input_nr++] = py::cast(ctx.graph);
auto inputs = op->inputs();
auto node_it = inputs.begin();
auto scalar_it = op->scalar_args.begin();
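// cconv encodes the PythonOp calling convention: 's' consumes the next
// scalar (non-Variable) argument, 't' consumes the next traced input.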
for (auto arg_type : op->cconv) {
py::object obj;
if (arg_type == 's') {
JIT_ASSERTM(scalar_it != op->scalar_args.end(), "not enough scalar args");
obj = py::reinterpret_borrow<py::object>(py::handle((scalar_it++)->get()));
} else if (arg_type == 't') {
JIT_ASSERTM(node_it != inputs.end(), "not enough inputs");
obj = py::cast(envFn(*node_it++));
} else {
throw std::runtime_error("unexpected calling convention");
}
py_symbolic_args[input_nr++] = obj;
}
// Call the symbolic function
// Use a little trampoline function so we can give good error messages
// upon argument mismatch
py::object raw_output = onnx.attr("_run_symbolic_method")(op->name(), pyobj.attr("symbolic"), py_symbolic_args);
processSymbolicOutput(op->name(), op, raw_output);
};
// Finally, visit all nodes in the graph
for (auto node : state->graph->nodes()) {
if (hasUsedHandle(node)) {
// Nothing we can do here. The handle is used, so we'll need to capture the
// original state and can't do anything with this op (we don't know what the
// backward is).
cloneNode(node);
continue;
}
// Needed so that symbolic calls create nodes with correct stages.
auto stage_guard = new_graph->setStageTemporary(node->stage());
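// CppOps are cloned unchanged; PythonOps are converted via their symbolic()
// method; all other nodes go through the per-kind symbolic dispatch above.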
IR_IFM(node, CppOp)
cloneNode(node);
IR_ELSEIFM(PythonOp)
callPySymbolicMethod(value);
IR_ELSE()
if (node->kind() == kUndefined) {
// Undefined nodes get passed into Convolution, but then they are
// removed. We'll test for leftover Undefined in export.cpp
cloneNode(node);
} else {
callPySymbolicFunction(node);
}
IR_END()
}
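// Wire up the graph outputs, translated through the environment map.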
for (auto output : state->graph->outputs()) {
ctx.graph->registerOutput(env.at(output));
}
// Copy stage from original graph
new_graph->setStage(state->graph->stage());
state->graph = std::move(new_graph);
state->buffer_map = std::move(new_buffer_map);
}
}} // namespace torch::jit