Add Node type to JIT IR
Rewrite Type as a class hierarchy
PR comments + rebase fixes
diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp
index 2588856..562d2ea 100644
--- a/torch/csrc/autograd/functions/basic_ops.cpp
+++ b/torch/csrc/autograd/functions/basic_ops.cpp
@@ -47,10 +47,9 @@
check_input_variables("Mul", inputs, 2);
auto& input1 = inputs[0]->data;
auto& input2 = inputs[1]->data;
- AutoGPU guard(input1->getDevice());
+ AutoGPU guard(input1.type().isCuda() ? input1.get_device() : -1);
- auto output = input1->newTensor();
- output->cmul(*input1, *input2);
+ auto output = input1 * input2;
return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
return std::make_shared<MulBackward>(std::move(f), inputs[0]->save(this), inputs[1]->save(this));
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 192b6d4..59e82ff 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -453,6 +453,7 @@
// output, but Python nodes can't be optimized away, so we simplify the
// code here.
Node* sel = GlobalTracingState.current().appendNewNode<Select>(this_expr, i);
+ sel->inferTypeFrom(output_var->cdata->data);
GlobalTracingState.setValueTrace(output_var->cdata.get(), sel);
}
output_var->cdata->output_nr = i;
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 01fcf05..2987081 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -11,6 +11,9 @@
#include <algorithm>
#include <unordered_set>
#include <list>
+#include <cstdint>
+
+#include <ATen/ATen.h>
#include "torch/csrc/utils/object_ptr.h"
@@ -30,7 +33,109 @@
// and dependencies on a list of values. The "prim-ops", so to speak.
struct Node;
-struct Type {}; // we will need a type, but for now it does nothing...
+#define TH_FORALL_TYPES(_) \
+_(Multi) \
+_(Single)
+
+enum class TypeKind {
+#define DEFINE_TYPE(T) T,
+ TH_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
+struct Type {
+private:
+ TypeKind kind_;
+
+protected:
+ Type(TypeKind kind)
+ : kind_(kind) {}
+
+public:
+ TypeKind kind() const {
+ return kind_;
+ }
+
+ // Dynamically cast this object to the subclass indicated by the
+ // template variable, returning nullptr if the cast is invalid.
+ template<typename T>
+ T* cast() {
+ if (T::Kind == kind())
+ return static_cast<T*>(this);
+ return nullptr;
+ }
+
+ std::unique_ptr<Type> clone();
+
+ static std::unique_ptr<Type> newWithKind(TypeKind kind);
+};
+
+// Type of nodes with a single output
+struct TypeSingle : public Type {
+private:
+ friend struct Type;
+
+ std::vector<std::int64_t> sizes_;
+ std::vector<std::int64_t> strides_;
+
+ TypeSingle()
+ : Type(TypeKind::Single) {}
+
+public:
+ static const TypeKind Kind = TypeKind::Single;
+
+ void inferFrom(const at::Tensor& tensor) {
+ auto ndim = tensor.dim();
+ sizes_.resize(ndim);
+ strides_.resize(ndim);
+ // NOTE: This is not memcpy! These are assignments.
+ std::copy(tensor.sizes().begin(), tensor.sizes().end(), sizes_.begin());
+ std::copy(tensor.strides().begin(), tensor.strides().end(), strides_.begin());
+ }
+
+ const std::vector<std::int64_t>& sizes() {
+ return sizes_;
+ }
+ const std::vector<std::int64_t>& strides() {
+ return strides_;
+ }
+};
+
+// Type of multi-return nodes. Note that this doesn't mean they must
+// always have multiple outputs.
+struct TypeMulti : public Type {
+private:
+ friend struct Type;
+
+ TypeMulti()
+ : Type(TypeKind::Multi) {}
+
+public:
+ static const TypeKind Kind = TypeKind::Multi;
+};
+
+inline std::unique_ptr<Type> Type::newWithKind(TypeKind kind) {
+ switch (kind) {
+#define HANDLE_KIND(KIND) \
+ case TypeKind::KIND: \
+ return std::unique_ptr<Type>(static_cast<Type*>(new Type##KIND()));
+ TH_FORALL_TYPES(HANDLE_KIND)
+ }
+#undef HANDLE_KIND
+ __builtin_unreachable();
+}
+
+inline std::unique_ptr<Type> Type::clone() {
+#define HANDLE_KIND(KIND) \
+ case TypeKind::KIND: \
+ return std::unique_ptr<Type>(static_cast<Type*>( \
+ new Type##KIND(*(static_cast<Type##KIND*>(this)))));
+ switch (kind_) {
+ TH_FORALL_TYPES(HANDLE_KIND)
+ }
+#undef HANDLE_KIND
+ __builtin_unreachable();
+}
// Each use is represented by this type, see Node::uses()
// 'user' is the consumer of the node, offset is the index into
@@ -89,21 +194,26 @@
graph_node_list::iterator prev() { return std::prev(nodes_iter_); }
const NodeKind kind_;
- Type * type_;
std::vector<Node*> inputs_;
use_list uses_;
Graph* graph_;
size_t unique_ = 0; // unique id
size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
protected:
- Node(NodeKind kind_)
- : kind_(kind_), type_(nullptr) {}
+ std::unique_ptr<Type> type_;
+ Node(NodeKind kind_, TypeKind type_kind)
+ : kind_(kind_), type_(Type::newWithKind(type_kind)) {}
public:
NodeKind kind() {
return kind_;
}
- Type * type() {
- return type_;
+ const Type* type() {
+ return type_.get();
+ }
+ void inferTypeFrom(const at::Tensor& output) {
+ auto single_type = type_->cast<TypeSingle>();
+ JIT_ASSERT(single_type);
+ single_type->inferFrom(output);
}
Graph * owningGraph() {
return graph_;
@@ -358,12 +468,12 @@
// using CRTP so that we can alloc new clones and dispatch to custom clone code
// without significant boilerplate
-template<typename Self, NodeKind K>
+template<typename Self, NodeKind K, TypeKind T>
struct NodeWithKind : public Node {
friend struct Graph; /* so it can access allocClone() */
static const NodeKind Kind = K;
NodeWithKind()
- : Node(K) {}
+ : Node(K, T) {}
// virtual so we can easily define a default here
// defined using CRTP so cloneFrom doesn't need casts.
// called from allocClone
@@ -380,7 +490,7 @@
// helper to define simple primitive Ops.
template<typename Self, NodeKind K>
-struct Primitive : public NodeWithKind<Self,K> {
+struct Primitive : public NodeWithKind<Self, K, TypeKind::Single> {
void init() {}
void init(ArrayRef<Node*> inputs) {
for(auto i : inputs)
@@ -388,20 +498,19 @@
}
};
-
// the outputs of the Graph are represented as an Node so that its inputs
// can be tracked as Uses.
-struct Return : public Primitive<Return,NodeKind::Return> {};
+struct Return : public Primitive<Return, NodeKind::Return> {};
// an input tensor to the graph
-struct Param : public NodeWithKind<Param,NodeKind::Param> {
+struct Param : public NodeWithKind<Param, NodeKind::Param, TypeKind::Single> {
void init() {}
};
struct Graph {
TH_DISALLOW_COPY_AND_ASSIGN(Graph);
friend struct Node;
-template<typename Self, NodeKind K>
+template<typename Self, NodeKind K, TypeKind T>
friend struct NodeWithKind;
private:
param_list inputs_;
@@ -551,9 +660,10 @@
graph_->freeNode(this);
}
-template<typename Self, NodeKind K>
-Node * NodeWithKind<Self,K>::allocClone(Graph * in_graph) {
+template<typename Self, NodeKind K, TypeKind T>
+Node * NodeWithKind<Self,K,T>::allocClone(Graph * in_graph) {
auto s = new Self();
+ s->type_ = this->type_->clone();
in_graph->initNewNodeForGraph(s);
// static cast is needed because the compiler doesn't know NodeWithKind is a CRTP.
s->cloneFrom(static_cast<Self*>(this));
@@ -611,7 +721,7 @@
/************* All nodes not required to be defined before Graph **************/
// execute a Python function, used for Ops we can't optimize but that we want to optimize around
-struct PythonOp : public NodeWithKind<PythonOp,NodeKind::PythonOp> {
+struct PythonOp : public NodeWithKind<PythonOp,NodeKind::PythonOp,TypeKind::Multi> {
//TODO: make this non-autograd specific
//remove is_legacy, avoid THPObjectPtr to avoid big PyTorch dependency
@@ -653,7 +763,7 @@
// in this case
// number_of_outputs = op.uses().size()
// this will change if Tuples ever become first class.
-struct Select : public NodeWithKind<Select,NodeKind::Select> {
+struct Select : public NodeWithKind<Select,NodeKind::Select,TypeKind::Single> {
void init(Node * node, size_t offset) {
addInput(node);
this->offset_ = offset;
@@ -682,7 +792,7 @@
struct Sigmoid : public Primitive<Sigmoid,NodeKind::Sigmoid> {};
struct Tanh : public Primitive<Tanh,NodeKind::Tanh> {};
-struct FusionGroup : public NodeWithKind<FusionGroup,NodeKind::FusionGroup> {
+struct FusionGroup : public NodeWithKind<FusionGroup,NodeKind::FusionGroup,TypeKind::Multi> {
void init() {
subgraph_ = std::make_shared<Graph>();
}