|  | #include <torch/csrc/jit/passes/inliner.h> | 
|  |  | 
|  | #include <ATen/core/interned_strings.h> | 
|  | #include <torch/csrc/jit/api/function_impl.h> | 
|  | #include <torch/csrc/jit/api/module.h> | 
|  | #include <torch/csrc/jit/frontend/error_report.h> | 
|  | #include <torch/csrc/jit/jit_log.h> | 
|  |  | 
|  | namespace torch { | 
|  | namespace jit { | 
|  |  | 
|  | namespace prim { | 
|  | using namespace ::c10::prim; | 
|  | } | 
|  |  | 
|  | GraphFunction* tryToGraphFunction(Node* n) { | 
|  | if (n->kind() == prim::CallFunction) { | 
|  | AT_ASSERT(n->input(0)->node()->kind() == prim::Constant); | 
|  | auto function_constant = n->input(0)->node(); | 
|  | auto fun_type = function_constant->output()->type()->expect<FunctionType>(); | 
|  | return tryToGraphFunction(*fun_type->function()); | 
|  | } | 
|  | if (n->kind() == prim::CallMethod) { | 
|  | const std::string& name = n->s(attr::name); | 
|  | if (auto class_type = n->input(0)->type()->cast<ClassType>()) { | 
|  | Function& function = class_type->getMethod(name); | 
|  | return tryToGraphFunction(function); | 
|  | } | 
|  | } | 
|  | return nullptr; | 
|  | } | 
|  |  | 
|  | static void inlineCalls(Block* block) { | 
|  | for (auto it = block->nodes().begin(), end = block->nodes().end(); | 
|  | it != end;) { | 
|  | Node* cur = *it++; | 
|  | switch (cur->kind()) { | 
|  | case prim::CallFunction: { | 
|  | if (auto graphFunction = tryToGraphFunction(cur)) { | 
|  | auto function_constant = cur->input(0)->node(); | 
|  | auto fun_type = | 
|  | function_constant->output()->type()->expect<FunctionType>(); | 
|  |  | 
|  | cur->removeInput(0); | 
|  | GRAPH_UPDATE( | 
|  | "Inlining function '", | 
|  | fun_type->function()->name(), | 
|  | "' to ", | 
|  | *cur); | 
|  |  | 
|  | std::shared_ptr<Graph> g = nullptr; | 
|  | // inline optimized graph for debugging/testing purposes. | 
|  | // we only insert fallback functions in JIT optimized graphs for | 
|  | // execution, not on the Graph that is used for serialization | 
|  | bool fallback = | 
|  | function_constant->hasAttribute(Symbol::attr("fallback")); | 
|  | if (fallback && graphFunction->get_executor().isOptimized()) { | 
|  | auto exec_plans = | 
|  | graphFunction->get_executor().getDebugState().execution_plans; | 
|  | if (!exec_plans.empty()) { | 
|  | g = exec_plans.begin()->second.graph; | 
|  | // optimized_graph() calls Inline, so we only need to explicitly | 
|  | // invoke inlining on the jit optimized graph with recursive | 
|  | // fallback function calls | 
|  | Inline(*g.get()); | 
|  | } | 
|  | } | 
|  | if (g == nullptr) { | 
|  | g = graphFunction->optimized_graph(); | 
|  | } | 
|  |  | 
|  | GRAPH_UPDATE("Function body: ", g); | 
|  | inlineCallTo(cur, graphFunction, g.get()); | 
|  | } | 
|  | } break; | 
|  | case prim::CallMethod: { | 
|  | if (auto graphFunction = tryToGraphFunction(cur)) { | 
|  | GRAPH_UPDATE("Inlining method '", cur->s(attr::name), "' to ", *cur); | 
|  | GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph()); | 
|  | inlineCallTo(cur, graphFunction); | 
|  | } | 
|  | } break; | 
|  | default: { | 
|  | for (auto b : cur->blocks()) { | 
|  | inlineCalls(b); | 
|  | } | 
|  | } break; | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void Inline(Graph& graph) { | 
|  | GRAPH_DUMP("Before Inlining: ", &graph); | 
|  | inlineCalls(graph.block()); | 
|  | GRAPH_DUMP("After Inlining: ", &graph); | 
|  | } | 
|  |  | 
|  | } // namespace jit | 
|  | } // namespace torch |