blob: 2ab7303130c16834cf0fc43286facb4da05856ae [file] [log] [blame]
#include "torch/csrc/jit/graph_fuser.h"
#include <unordered_map>
namespace torch { namespace jit {
std::unordered_set<NodeKind> simple_mappable = {
kSigmoid,
kTanh,
kMul,
kAdd,
kNegate
};
bool isSimpleMap(Node *node) {
return simple_mappable.count(node->kind());
}
struct GraphFuser {
std::shared_ptr<Graph>& graph;
// Used to order nodes so we alway consider producer-consumer fusions
// in reverse topological order.
// If topological_index[a] > topological_index[b] then a occurs after b.
// Because nodes can be added to this graph during optimization, this mapping is not bijective.
// Newly generated nodes will copy the location where they are inserted.
std::unordered_map<Node*,size_t> topological_index;
GraphFuser(std::shared_ptr<Graph>& graph)
: graph(graph) {}
bool isFusable(Node * node) {
return isSimpleMap(node) || node->kind() == kFusionGroup;
}
// necessary condition for fusion. If all of the uses of producer are consumer
// then it is safe to merge producer into consumer, because it doesn't have any other uses
// If there are other uses, but they occur _after_ consumer, then we can still merge in producer
// with consumer, by rewriting those later uses to use the version of producer generated by the fused blob
// In this case, producer becomes an output of the fusion group.
bool allUsersAreThisConsumerOrOccurAfterIt(Node * consumer, Node * producer) {
for(auto u : producer->uses()) {
if(u.user != consumer && topological_index[consumer] > topological_index[u.user])
return false;
}
return true;
}
bool allUsersAreThisConsumer(Node * consumer, Node * producer) {
for(auto u : producer->uses()) {
if(u.user != consumer)
return false;
}
return true;
}
bool shouldFuse(Node * consumer, Node * producer) {
// this handles cases where producer can be moved _into_ the fusion group of consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
// if the consumer allInputsAreThisProducer(consumer,producer)
// we can move the consumer up into the producer.
// but this requires better handling of merging fusion groups so it is not done now
return isFusable(producer) && allUsersAreThisConsumerOrOccurAfterIt(consumer, producer);
}
// insert a producer node into a consuming fusion group.
// DOES NOT WORK if n is a consumer of an output of the fusion group
// returns the node _inside_ the group that represents the node
Graph & getSubgraph(Node * n) {
JIT_ASSERT(n->kind() == kFusionGroup);
return *n->g(kSubgraph);
}
Node * mergeNodeIntoGroup(Node* group, Node * n) {
auto & subgraph = getSubgraph(group);
auto & inputs = group->inputs();
// map from nodes in the surrounding graph to parameters in the fusion
// group's subgraph that correspond to them
std::unordered_map<Node*,Node*> inputs_map;
size_t i = 0;
JIT_ASSERT(group->inputs().size() == subgraph.inputs().size());
for(auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
}
// add n's inputs to the fusion group's input list if we don't already have them
for(auto input : n->inputs()) {
if(inputs_map.count(input) == 0) {
auto in_group = subgraph.addInput();
in_group->setType(input->typeOption());
inputs_map[input] = in_group;
group->addInput(input);
}
}
// copy n into the graph, remapping its inputs to internal nodes
Node * in_graph = subgraph.createClone(n,[&](Node * k)-> Node* {
return inputs_map[k];
});
// if n is already an input to the fusion group,
// we need to remove it because n is now inside the fusion group
// remapping nodes that used the input to the newly-merged node
// n is not an input when the fusion group is empty
auto it = std::find(inputs.begin(), inputs.end(), n);
if(it != inputs.end()) {
size_t p = it - inputs.begin();
group->removeInput(p);
subgraph.inputs()[p]->replaceAllUsesWith(in_graph);
subgraph.eraseInput(p);
}
return subgraph.prependNode(in_graph);
}
// turn consumer node n into a fusion group with just n inside
// to prepare for fusion and replace uses of n with the new group
Node * createSingletonFusionGroup(Node * n) {
auto group = graph->createFusionGroup();
// propogate position information for the new node so we can always
// have a valid mapping
topological_index[group] = topological_index[n];
group->insertBefore(n);
Node * mergedNode = mergeNodeIntoGroup(group,n);
getSubgraph(group).registerOutput(mergedNode);
auto sel = graph->createSelect(group,0);
sel->setType(n->typeOption());
sel->insertAfter(group);
n->replaceAllUsesWith(sel);
n->destroy();
return group;
}
void insertAfter(Node * n, Node * after) {
n->insertAfter(after);
topological_index[n] = topological_index[after];
}
Node * fuse(Node * consumer, Node * producer) {
auto group = consumer;
if(group->kind() != kFusionGroup) {
group = createSingletonFusionGroup(consumer);
}
Node * merged = mergeNodeIntoGroup(group, producer);
// remaining uses of this producer can occur because we allow
// fusion in cases where uses remain after the consumer
// if these exist, re-route them to the version of producer
// created in FusionGroup
if(producer->uses().size() != 0) {
size_t offset = getSubgraph(group).registerOutput(merged);
Node * new_producer = graph->createSelect(group,offset);
new_producer->setType(producer->typeOption());
insertAfter(new_producer, group);
producer->replaceAllUsesWith(new_producer);
}
producer->destroy();
return group;
}
// in places where op can be fused into a consumer but chunk is in the way
// distribute chunk to op's operands:
// replace a,b = chunk(op(x,y,z)) with:
// x0,x1 = chunk(x) (x0 has a's type, x1 has b's type)
// y0,y1 = chunk(y) (y0 has a's type, y1 has b's type)
// z0,z1 = chunk(z) (z0 has a's type, z1 has b's type)
// a = op(x0,y0,z0) (a,b have their same size but are now contiguous)
// b = op(x1,y1,x1)
bool tryToMoveChunk(Node * consumer, Node * producer) {
// if we are fusing a select,
if(producer->kind() != kSelect)
return false;
// and the select refers to a chunk,
auto * chunk = producer->input();
if(chunk->kind() != kChunk)
return false;
// and the thing being chunked is fusable into the consumer
Node * producer_for_chunk = chunk->input();
if(!isFusable(producer_for_chunk) || !allUsersAreThisConsumer(chunk,producer_for_chunk))
return false;
// and all uses of the chunk are in this consumer
for(auto s : chunk->uses()) {
for(auto u : s.user->uses()) {
if(u.user != consumer)
return false;
}
}
std::vector<Node*> chunks;
for(auto input : producer_for_chunk->inputs()) {
Node * c = graph->createChunk(input,chunk->i(kNumChunks),chunk->i(kDim));
insertAfter(c,chunk);
chunks.push_back(c);
}
//as we replace/remove the selects the use list changes, so copy it first
use_list copy_uses = chunk->uses();
for(auto s : copy_uses) {
Node* sel = s.user;
size_t i = sel->offset();
size_t j = 0;
Node * new_output = graph->createClone(producer_for_chunk,[&](Node * n) {
auto & c = chunks[j++];
Node * ns = graph->createSelect(c,i);
ns->setType(sel->typeOption());
insertAfter(ns,c);
return ns;
});
if(sel->hasType()) {
new_output->setType(sel->type()->cast<TensorType>()->contiguous());
}
insertAfter(new_output,s.user);
s.user->replaceAllUsesWith(new_output);
s.user->destroy();
}
chunk->destroy();
producer_for_chunk->destroy();
return true;
}
// returns where to continue scanning
graph_node_list::iterator scanNode(Node * consumer) {
if(isFusable(consumer)) {
// handle inputs in reverse topological order as well...
// otherwise in f(a,a+b) it will appear a is used twice if we consider
// the f-a fusion before the f-(a+b) fusion first.
node_list inputs = consumer->inputs();
for(auto i : inputs) {
JIT_ASSERT(topological_index.count(i) > 0);
}
std::sort(inputs.begin(), inputs.end(), [&](Node * a, Node * b) {
return topological_index[a] > topological_index[b];
});
for(auto producer : inputs) {
if(tryToMoveChunk(consumer,producer)) {
// the chunk before this consumer was re-arranged to allow fusion,
// we scan this consumer again to perform the fusion
return consumer->reverseIterator();
}
if(shouldFuse(consumer, producer)) {
auto fusion_group = fuse(consumer,producer);
// after fusion, consumer moves into a FusionGroup, so inputs is no longer valid
// so we rescan the new FusionGroup for more fusions...
return fusion_group->reverseIterator();
}
}
}
return ++consumer->reverseIterator();
}
void run() {
size_t i = 0;
for(auto p : graph->inputs()) {
topological_index[p] = i++;
}
auto nodes = graph->nodes();
for(auto consumer : nodes) {
topological_index[consumer] = i++;
}
topological_index[graph->return_node()] = i++;
for(auto it = nodes.rbegin(); it != nodes.rend();) {
it = scanNode(*it);
}
}
};
void FuseGraph(std::shared_ptr<Graph>& graph) {
GraphFuser(graph).run();
}
}}