python clone, more asserts, better names.
diff --git a/torch/csrc/jit/graph_fuser.cpp b/torch/csrc/jit/graph_fuser.cpp
index 8a60131..77cd09c 100644
--- a/torch/csrc/jit/graph_fuser.cpp
+++ b/torch/csrc/jit/graph_fuser.cpp
@@ -15,10 +15,13 @@
std::unique_ptr<Graph> graph;
// used to order nodes so we alway consider producer-consumer fusions
// in reverse topological order
- std::unordered_map<Node*,size_t> original_topological_position;
+ std::unordered_map<Node*,size_t> topological_index;
GraphFuser(std::unique_ptr<Graph> graph)
: graph(std::move(graph)) {}
+
+ // the tracer should handle this conversion, and then this code should
+ // be deleted
void replacePythonOps() {
auto nodes = graph->nodes();
for(auto it = nodes.begin(), end = nodes.end(); it != end; ++it) {
@@ -30,6 +33,7 @@
new_op->insertAfter(p);
JIT_ASSERT(1 == p->uses().size());
auto single_select = p->uses()[0].user;
+ JIT_ASSERT(single_select->kind() == NodeKind::Select);
single_select->replaceAllUsesWith(new_op);
single_select->eraseFromParent();
//erasing p directly would invalidate iterator
@@ -51,7 +55,7 @@
// In this case, producer becomes an output of the fusion group.
bool allUsersAreThisConsumerOrOccurAfterIt(Node * consumer, Node * producer) {
for(auto u : producer->uses()) {
- if(u.user != consumer && original_topological_position[consumer] > original_topological_position[u.user])
+ if(u.user != consumer && topological_index[consumer] > topological_index[u.user])
return false;
}
return true;
@@ -75,6 +79,7 @@
// group's subgraph that correspond to them
std::unordered_map<Node*,Node*> inputs_map;
size_t i = 0;
+ JIT_ASSERT(group->inputs().size() == subgraph.inputs().size());
for(auto input : group->inputs()) {
inputs_map[input] = subgraph.inputs()[i++];
}
@@ -108,7 +113,7 @@
auto group = graph->create<FusionGroup>();
// propogate position information for the new node so we can always
// have a valid mapping
- original_topological_position[group] = original_topological_position[n];
+ topological_index[group] = topological_index[n];
group->insertBefore(n);
Node * mergedNode = mergeNodeIntoGroup(group,n);
group->subgraph().registerOutput(mergedNode);
@@ -148,10 +153,10 @@
// the f-a fusion before the f-(a+b) fusion first.
node_list inputs = consumer->inputs();
for(auto i : inputs) {
- JIT_ASSERT(original_topological_position.count(i) > 0);
+ JIT_ASSERT(topological_index.count(i) > 0);
}
std::sort(inputs.begin(),inputs.end(),[&](Node * a, Node * b) {
- return original_topological_position[a] > original_topological_position[b];
+ return topological_index[a] > topological_index[b];
});
for(auto producer : inputs) {
if(shouldFuse(consumer, producer)) {
@@ -169,10 +174,10 @@
replacePythonOps();
size_t i = 0;
for(auto p : graph->inputs()) {
- original_topological_position[p] = i++;
+ topological_index[p] = i++;
}
for(auto consumer : graph->nodes()) {
- original_topological_position[consumer] = i++;
+ topological_index[consumer] = i++;
}
auto reversed = graph->nodes().reverse();
for(auto it = reversed.begin(), end = reversed.end(); it != end;) {
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 187e320..ae9c244 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -143,7 +143,7 @@
// for example, if you are adding nodes to the end of the topsort, it's
// impossible for them to refer to inputs that are not in the topsort.
// If it is not obvious, please comment accordingly.
-
+
// Add 'node' as an input to 'this' at the end of existing
// arguments. Returns the added node for ease of chaining.
//
@@ -191,7 +191,7 @@
const use_list & uses() {
return uses_;
}
-
+
// Replaces all uses of this node with 'newValue'.
//
// Given: %3 = f(%1, %2)
@@ -220,10 +220,10 @@
// %5 = h(%1)
// %4 = g(%3)
void insertBefore(Node * n) {
- JIT_ASSERT(n->inGraphList());
+ JIT_ASSERT(n->inGraphList()&& !this->inGraphList());
insertAfter(n->prev());
}
-
+
// Insert unattached 'this' node after 'n' in the topological order.
//
// Given: %3 = f(%1, %2)
@@ -249,13 +249,13 @@
// Execute: %2.moveAfter(%3)
// Result: %3 = g(%1)
// %2 = f(%1)
- //
+ //
void moveAfter(Node * n) {
JIT_ASSERT(inGraphList());
removeFromList();
insertAfter(n);
}
-
+
// Move a node 'n' (already in the graph) before 'this' in the topological order.
//
// Given: %2 = f(%1)
@@ -268,7 +268,7 @@
removeFromList();
insertBefore(n);
}
-
+
// Remove the input at 'i' from this node.
//
// WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
@@ -287,7 +287,7 @@
}
inputs_.erase(inputs_.begin() + i);
}
-
+
// Remove all inputs from a node.
//
// Given: %3 = f(%1, %2)
@@ -351,7 +351,8 @@
return old_node;
}
bool inGraphList() {
- return next() && prev();
+ JIT_ASSERT(next() != nullptr || prev() == nullptr);
+ return next() != nullptr;
}
void removeFromList() {
JIT_ASSERT(inGraphList());
@@ -693,7 +694,14 @@
this->is_legacy = is_legacy;
}
virtual void cloneFrom(PythonOp * other) override {
- throw std::runtime_error("cannot clone PythonOp because of THPObjectPtr");
+ this->cconv = cconv;
+ this->is_legacy = is_legacy;
+ Py_INCREF(other->pyobj.get());
+ this->pyobj = THPObjectPtr(other->pyobj.get());
+ for(auto & sa : other->scalar_args) {
+ Py_INCREF(sa.get());
+ this->scalar_args.emplace_back(sa.get());
+ }
}
};