Add a tagged union type that replaces tensor in the interpreter. (#9368)

Summary:
IValue is short for "interpreter value". It is used frequently, so a short name
is important. This will allow us to implement more non-tensor types efficiently
and remove many hacks from the compiler.

This PR is intentionally limited in scope: it only introduces IValue and
changes the interpreter to use it.
Follow up PRs will:
* Change the way aten_ops consume non-tensor types so that integer lists
  are no longer represented as Tensors.
* Introduce TensorList as a fundamental type and remove all vararg handling
  in gen_jit_dispatch.
* Change the compiler to implement math on primitive numbers rather than converting to tensors.
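
To make the new interface concrete, here is a minimal sketch of what
interpreter code looks like once stacks hold IValues (illustrative only;
everything used here is defined in this patch, plus the standard ATen factory
at::ones):

    Stack stack;                                // std::vector<IValue>
    stack.push_back(IValue(at::ones({2, 2})));  // tensor payload, refcounted
    stack.push_back(IValue(int64_t(3)));        // integer payload, stored inline
    IValue v = pop(stack);
    if (v.isInt()) {
      int64_t i = v.toInt();                    // tag is checked by JIT_ASSERT
    }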

jamesr66a  apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9368

Reviewed By: ezyang

Differential Revision: D8817598

Pulled By: zdevito

fbshipit-source-id: 29dce80611ce5f6384234de9d12a67861d2b112f
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 6fc454c..d6458f9 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -88,9 +88,10 @@
 # map from aten 'simple_type' to the function that will turn a tensor into
 # that type
 FROM_TENSOR = {
-    'Device': 'tensor_as<IntList>',
+    'Device': 'tensor_as<std::vector<int64_t>>',
     'ScalarType': 'tensor_as<int64_t>',
     'Layout': 'tensor_as<int64_t>',
+    'IntList': 'tensor_as<std::vector<int64_t>>',
 }
 
 
@@ -107,7 +108,7 @@
 """)
 
 POS_ASSIGNMENT = CodeTemplate("""\
-auto ${name} = ${from_tensor}(std::move(peek(stack, ${i}, ${N})));\
+auto ${name} = ${from_tensor}(std::move(peek(stack, ${i}, ${N})).toTensor());\
 """)
 
 CALL_NAMESPACE = CodeTemplate("""\
@@ -261,12 +262,12 @@
                 # NOTE: don't advance real_inputs here. After this we are going
                 # to switch over to indexing from the end as if we only had
                 # the static arguments.
-                arguments.append('peekSlice(stack, {}, varargs_length - {}, varargs_length)'
+                arguments.append('toTensors(peekSlice(stack, {}, varargs_length - {}, varargs_length))'
                                  .format(real_inputs, static_inputs))
             elif arg['simple_type'] in default_only_types:
                 arguments.append(arg['default'])
             elif is_tensor_arg(arg):
-                arguments.append('std::move(peek(stack, {}, {}))'.format(real_inputs, view_length))
+                arguments.append('std::move(peek(stack, {}, {})).toTensor()'.format(real_inputs, view_length))
                 real_inputs += 1
             elif is_positional_arg[i]:
                 template_kwargs = dict(from_tensor=from_tensor(arg),
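
Note: with this change the POS_ASSIGNMENT template above expands to code along
the lines of the following sketch (the argument name `size` and the peek
indices are illustrative, not real generated output):

    // hypothetical expansion for an IntList argument named `size`
    auto size = tensor_as<std::vector<int64_t>>(
        std::move(peek(stack, 0, 2)).toTensor());
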
diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp
index 4cb7fba..2f4d055 100644
--- a/tools/jit/templates/register_aten_ops.cpp
+++ b/tools/jit/templates/register_aten_ops.cpp
@@ -29,7 +29,6 @@
 using autograd::variable_list;
 using at::Scalar;
 using at::Tensor;
-using at::IntList;
 using at::TensorList;
 using at::TensorOptions;
 using at::DeviceGuard;
@@ -39,10 +38,16 @@
 int deviceForInputs(Stack & stack, size_t N) {
   if(N == 0)
     return -1;
-  auto & t = *(stack.end() - N);
+  auto t = (stack.end() - N)->toTensor();
   return t.type().is_cuda() ? (int) t.get_device() : -1;
 }
 
+std::vector<at::Tensor> toTensors(at::ArrayRef<IValue> ivalues) {
+  return fmap(ivalues, [](const IValue& v) {
+    return v.toTensor();
+  });
+}
+
 template<size_t N>
 std::array<bool, N> as_bool_array(const std::vector<int64_t>& vec) {
   std::array<bool, N> res;
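
With the toTensors helper above, a generated call site for a Tensor[] vararg
reads roughly as follows (sketch; the real offsets and lengths are computed by
gen_jit_dispatch):

    // hypothetical generated argument for a Tensor[] parameter with one
    // trailing static argument
    auto tensors = toTensors(
        peekSlice(stack, 0, varargs_length - 1, varargs_length));
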
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index da11677..fdeb0ef 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -17,7 +17,7 @@
 
 bool isDifferentiable(Node * n) {
   static std::unordered_set<Symbol> differentiable_kinds = {
-    aten::add, aten::sub, aten::mul, prim::Constant, prim::ReplaceIfUndef,
+    aten::add, aten::sub, aten::mul, prim::Constant,
     aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg,
     aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as,
     aten::relu, aten::exp, prim::AutogradAdd
@@ -99,8 +99,6 @@
         return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)};
       case prim::Constant:
         return {};
-      case prim::ReplaceIfUndef:
-        return {grads.at(0), grads.at(0)};
       case aten::sigmoid:
         return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))};
       case aten::tanh:
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 65c6a70..5ef60d9 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -21,6 +21,7 @@
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_grad_of.h"
 #include "torch/csrc/jit/symbolic_variable.h"
+#include "torch/csrc/jit/ivalue.h"
 
 #include "torch/csrc/autograd/edge.h"
 #include "torch/csrc/autograd/function.h"
@@ -72,6 +73,16 @@
 };
 
 
+// helper to run the interpreter on variables until we switch
+// everything to IValue
+inline variable_tensor_list runOneStage(const Code & code, variable_tensor_list inputs) {
+  std::vector<IValue> stack(inputs.begin(), inputs.end());
+  InterpreterState(code).runOneStage(stack);
+  return variable_tensor_list(fmap(stack, [](IValue& v) {
+    return std::move(v).toTensor();
+  }));
+}
+
 // an optimized way of executing the subgraph computed directly on
 // tensors rather than Variables.
 // This will unwrap Variables, run the plan, and re-wrap them.
@@ -90,8 +101,7 @@
     if(grad) {
       return runWithGrad(std::move(stack));
     }
-    InterpreterState(f).runOneStage(stack);
-    return stack;
+    return runOneStage(f, std::move(stack));
   }
   std::shared_ptr<Graph> get_graph() const {
     return graph;
@@ -113,14 +123,15 @@
   }
 
 private:
-  // inplace to avoid allocations
-  variable_tensor_list unwrapVariables(variable_tensor_list && list) const {
-    for(auto & v : list) {
-      v = v.defined() ? autograd::as_variable_ref(v).detach() : at::Tensor();
-    }
-    return std::move(list);
+  // note: should be inplace to avoid allocations, but we have to switch from
+  // a list of tensors to a list of IValues
+  std::vector<IValue> unwrapVariables(variable_tensor_list && list) const {
+    return fmap(list, [](const Variable& v) -> IValue {
+      return v.defined() ? autograd::as_variable_ref(v).detach() : at::Tensor();
+    });
   }
-  // inplace to avoid allocations
+  // note: should be inplace to avoid allocations, but we have to switch from
+  // a list of tensors to a list of IValues
   variable_tensor_list wrapTensors(tensor_list && list) const {
     for(auto & v : list) {
       v = autograd::make_variable(v, /*requires_grad=*/false);
@@ -152,7 +163,8 @@
 
     auto stack = unwrapVariables(std::move(inputs));
     InterpreterState(f).runOneStage(stack);
-    variable_tensor_list outputs = std::move(stack);
+    variable_tensor_list outputs(
+        fmap(stack, [](IValue& v) { return std::move(v).toTensor(); }));
 
     // hookup the gradients for the output tensors that require gradients
     // to the inputs to our gradient function df
@@ -311,11 +323,7 @@
 
   variable_tensor_list runFallback(variable_tensor_list inputs) {
     auto & fb = getOrCreateAutogradFallback();
-    InterpreterState state(fb);
-    auto stack = std::move(inputs);
-    state.runOneStage(stack);
-    // note: we never unwrapped inputs, because we want autograd to record the trace
-    return stack;
+    return runOneStage(fb, std::move(inputs));
   }
 
   static bool calcMayIntroduceGradient(Block* b) {
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index a4a73eb..b61a49b 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -35,7 +35,6 @@
 _(prim, Placeholder) /* debug */ \
 _(prim, Print) \
 _(prim, PythonOp) \
-_(prim, ReplaceIfUndef) \
 _(prim, Reverse) \
 _(prim, Return) \
 _(prim, Store) \
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 1fb82c9..1dd6ea6 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -9,6 +9,7 @@
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/tensor_conversions.h"
+#include "torch/csrc/jit/ivalue.h"
 #include "torch/csrc/variable_tensor_functions.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
 
@@ -410,7 +411,7 @@
     JIT_ASSERT(inst.debug_name == prim::Placeholder);
     auto offset = relativeJump(from_inst, to_inst);
     inst.callback = [offset](Stack & stack) {
-      auto t = tensor_as<int64_t>(pop(stack));
+      auto t = tensor_as<int64_t>(pop(stack).toTensor());
       return (t == 0) ? offset : 0;
     };
     inst.debug_name = prim::JumpZ;
@@ -422,7 +423,7 @@
     JIT_ASSERT(inst.debug_name == prim::Placeholder);
     auto offset = relativeJump(from_inst, to_inst);
     inst.callback = [offset](Stack & stack) {
-      auto t = tensor_as<int64_t>(pop(stack));
+      auto t = tensor_as<int64_t>(pop(stack).toTensor());
       return (t != 0) ? offset : 0;
     };
     inst.debug_name = prim::JumpNZ;
@@ -629,7 +630,8 @@
     return [=](Stack& stack) mutable {
       autograd::profiler::RecordFunction record("GraphExecutor");
       auto inputs = last(stack, num_inputs);
-      variable_tensor_list tinputs(inputs.begin(), inputs.end());
+      variable_tensor_list tinputs(
+          fmap(inputs, [](const IValue& v) { return v.toTensor(); }));
       drop(stack, num_inputs);
       //TODO: has graph executor work from a stack as well
       variable_tensor_list toutputs = executor->run(variable_tensor_list(std::move(tinputs)));
@@ -774,7 +776,7 @@
   // in the case where it is true, the interpreter and this array get copied
   // if this ever becomes a bottleneck then we _should_ consider minimizing the
   // total number of registers
-  std::vector<at::Tensor> registers;
+  std::vector<IValue> registers;
 
   // single buffer for input/output calls to ATen functions, so that we do not reallocate
   Stack stack;
@@ -799,7 +801,7 @@
 InterpreterState::~InterpreterState() {}
 
 void InterpreterState::runOneStage(Stack & stack) {
-    return pImpl->runOneStage(stack);
+  return pImpl->runOneStage(stack);
 }
 
 const TensorType & InterpreterState::tensorTypeForInput(size_t i) const {
diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h
index ed086bd..b908552 100644
--- a/torch/csrc/jit/interpreter.h
+++ b/torch/csrc/jit/interpreter.h
@@ -19,6 +19,8 @@
 struct Graph;
 struct Node;
 struct TensorType;
+struct IValue;
+using Stack = std::vector<IValue>;
 
 struct Code {
   Code()
@@ -44,7 +46,7 @@
   // advance the interpreter state by running one stage, returning the
   // outputs for that stage and suspending the computation.
   // Calling this function again continues the computation where it left off.
-  void runOneStage(std::vector<at::Tensor> & stack);
+  void runOneStage(Stack & stack);
   const TensorType & tensorTypeForInput(size_t i) const;
   ~InterpreterState();
   // create a copy of InterpreterState with its current state
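
A caller now drives the interpreter roughly like this (sketch; `graph` and
`inputs` are assumed to be a std::shared_ptr<Graph> and a vector of input
Tensors):

    Code code(graph);
    Stack stack(inputs.begin(), inputs.end());  // Tensors convert to IValues
    InterpreterState(code).runOneStage(stack);
    // stack now holds this stage's outputs as IValues
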
diff --git a/torch/csrc/jit/ivalue.h b/torch/csrc/jit/ivalue.h
new file mode 100644
index 0000000..c31436d
--- /dev/null
+++ b/torch/csrc/jit/ivalue.h
@@ -0,0 +1,278 @@
+#pragma once
+#include <ATen/ATen.h>
+#include "torch/csrc/assertions.h"
+
+namespace torch { namespace jit {
+
+// smart pointer to hold onto at::Retainable objects in a generic way
+// this is close to the implementation of boost's intrusive_ptr
+template<typename PointerType>
+struct Shared {
+  Shared(): Shared(nullptr, false) {}
+  Shared(PointerType * self, bool retain)
+  : pImpl(self) {
+    if(retain && pImpl)
+      pImpl->retain();
+  }
+  Shared(const Shared & rhs)
+  : pImpl(rhs.pImpl) {
+    if (pImpl)
+      pImpl->retain();
+  }
+  Shared(Shared && rhs) noexcept
+  : pImpl(rhs.pImpl) {
+    rhs.pImpl = nullptr;
+  }
+  ~Shared() {
+    if (pImpl)
+      pImpl->release();
+  }
+  Shared & operator=(Shared && rhs) & {
+    rhs.swap(*this);
+    return *this;
+  }
+  Shared & operator=(Shared const & rhs) & {
+      //Shared ctor retains original rhs.pImpl
+      //then rhs.pImpl is swapped with this->pImpl
+      //finally Shared dtor releases rhs.pImpl, which was originally this->pImpl
+      Shared(rhs).swap(*this);
+      return *this;
+  }
+  void reset() {
+    Shared().swap(*this);
+  }
+  void reset(PointerType * rhs) {
+    Shared(rhs, true).swap(*this);
+  }
+  void reset(PointerType * rhs, bool retain) {
+    Shared(rhs, retain).swap(*this);
+  }
+  void swap(Shared & rhs) {
+    PointerType * tmp = pImpl;
+    pImpl = rhs.pImpl;
+    rhs.pImpl = tmp;
+  }
+  PointerType* get() const {
+    return pImpl;
+  }
+  PointerType* detach() {
+    PointerType * ret = pImpl;
+    pImpl = nullptr;
+    return ret;
+  }
+  PointerType& operator*() const {
+    return *get();
+  }
+  PointerType* operator->() const {
+    return get();
+  }
+  operator bool() const {
+    return pImpl != nullptr;
+  }
+private:
+  PointerType * pImpl;
+};
+
+
+template<typename T>
+struct ConstantList;
+struct IValue;
+using Tuple = ConstantList<IValue>;
+using IntList = ConstantList<int64_t>;
+using DoubleList = ConstantList<double>;
+
+// IValue is the generic tagged union used by the interpreter to hold
+// all value types.
+// It is a 16-byte object with an 8-byte payload and an 8-byte tag.
+// The tag is currently 4 bytes to determine the type, and 1 byte
+// to mark whether that type is a subtype of at::Retainable and needs
+// retain/release calls.
+struct IValue {
+  IValue()
+  : payload(0)
+  , tag(Tag::None)
+  , retainable(false) {}
+  IValue(const IValue& rhs)
+      : payload(rhs.payload),
+        tag(rhs.tag),
+        retainable(rhs.retainable) {
+    if (retainable)
+      as_retainable->retain();
+  }
+  IValue(IValue&& rhs) noexcept : IValue() {
+    swap(rhs);
+  }
+  ~IValue() {
+    if (retainable) {
+      as_retainable->release();
+    }
+  }
+  IValue & operator=(IValue && rhs) & {
+    rhs.swap(*this);
+    return *this;
+  }
+  IValue & operator=(IValue const & rhs) & {
+      IValue(rhs).swap(*this);
+      return *this;
+  }
+  void swap(IValue & rhs) {
+    std::swap(payload, rhs.payload);
+    std::swap(retainable, rhs.retainable);
+    std::swap(tag, rhs.tag);
+  }
+  // Accessors for subtypes are arranged together below
+  // While some of these accessors could be generated through templates,
+  // we prefer to write them manually for clarity
+
+  // Tensor
+  IValue(at::Tensor t)
+  : tag(Tag::Tensor), retainable(t.defined())  {
+    // note: the undefined tensor is not refcounted, so while it
+    // is tagged as a tensor, retainable is set to false.
+    as_tensor_impl = t.at::detail::TensorBase::detach();
+  }
+  bool isTensor() const { return Tag::Tensor == tag; }
+  at::Tensor toTensor() && {
+    JIT_ASSERT(isTensor());
+    at::Tensor t(as_tensor_impl, /*retain=*/false);
+    clearToNone();
+    return t;
+  }
+  at::Tensor toTensor() const & {
+    JIT_ASSERT(isTensor());
+    return at::Tensor(as_tensor_impl, /*retain=*/true);
+  }
+
+  // Tuple
+  IValue(Shared<Tuple> v);
+  bool isTuple() const { return Tag::Tuple == tag; }
+  Shared<Tuple> toTuple() && {
+    JIT_ASSERT(isTuple());
+    return moveToRetainable<Tuple>();
+  }
+  Shared<Tuple> toTuple() const & {
+    JIT_ASSERT(isTuple());
+    return toRetainable<Tuple>();
+  }
+
+  // Double
+  IValue(double d)
+  : tag(Tag::Double), retainable(false) {
+    as_double = d;
+  }
+  bool isDouble() const { return Tag::Double == tag; }
+  double toDouble() const {
+    JIT_ASSERT(isDouble());
+    return as_double;
+  }
+
+  // Int
+  IValue(int64_t i)
+  : tag(Tag::Int), retainable(false) {
+    as_int = i;
+  }
+  // allows passing int literals (3, 4) without ambiguity
+  IValue(int32_t i)
+  : IValue(static_cast<int64_t>(i)) {}
+
+  bool isInt() const { return Tag::Int == tag; }
+  int64_t toInt() const {
+    JIT_ASSERT(isInt());
+    return as_int;
+  }
+
+  // IntList
+  IValue(Shared<IntList> v);
+  bool isIntList() const { return Tag::IntList == tag; }
+  Shared<IntList> toIntList() && {
+    JIT_ASSERT(isIntList());
+    return moveToRetainable<IntList>();
+  }
+  Shared<IntList> toIntList() const & {
+    JIT_ASSERT(isIntList());
+    return toRetainable<IntList>();
+  }
+
+  // DoubleList
+  IValue(Shared<DoubleList> v);
+  bool isDoubleList() const { return Tag::DoubleList == tag; }
+  Shared<DoubleList> toDoubleList() && {
+    JIT_ASSERT(isDoubleList());
+    return moveToRetainable<DoubleList>();
+  }
+  Shared<DoubleList> toDoubleList() const & {
+    JIT_ASSERT(isDoubleList());
+    return toRetainable<DoubleList>();
+  }
+
+  bool isNone() const {
+    return Tag::None == tag;
+  }
+
+private:
+  template<typename T>
+  Shared<T> moveToRetainable() {
+    Shared<T> t(static_cast<T*>(as_retainable), false);
+    clearToNone();
+    return t;
+  }
+  template<typename T>
+  Shared<T> toRetainable() const {
+    return Shared<T>(static_cast<T*>(as_retainable), true);
+  }
+  void clearToNone() {
+    payload = 0;
+    tag = Tag::None;
+    retainable = false;
+  }
+  enum class Tag : uint32_t {
+    None, Tensor, Double, Int, Tuple, IntList, DoubleList
+  };
+  union {
+    at::TensorImpl* as_tensor_impl;
+    at::Retainable* as_retainable;
+    double as_double;
+    int64_t as_int;
+    // this type should be as big as all the other types because it will
+    // be used to copy the union's value in certain cases
+    int64_t payload;
+  };
+  Tag tag;
+  bool retainable;
+};
+
+
+// non-mutable list
+template<typename Elem>
+struct ConstantList : at::Retainable {
+ private:
+  ConstantList(std::vector<Elem> elements_)
+  : elements_(std::move(elements_)) {}
+  std::vector<Elem> elements_;
+ public:
+  static Shared<ConstantList<Elem>> create(std::vector<Elem> elements_) {
+    return Shared<ConstantList<Elem>>(
+        new ConstantList<Elem>(std::move(elements_)), false);
+  }
+  at::ArrayRef<Elem> elements() const {
+    return elements_;
+  }
+};
+
+inline IValue::IValue(Shared<Tuple> v)
+: tag(Tag::Tuple), retainable(true) {
+  as_retainable = v.detach();
+}
+
+inline IValue::IValue(Shared<IntList> v)
+: tag(Tag::IntList), retainable(true) {
+  as_retainable = v.detach();
+}
+
+inline IValue::IValue(Shared<DoubleList> v)
+: tag(Tag::DoubleList), retainable(true) {
+  as_retainable = v.detach();
+}
+
+
+}}
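
As a usage sketch for this header (hypothetical operation, not part of the
patch; Stack and pop come from stack.h below):

    // sums an IntList off the stack and pushes the total as an Int
    int sum_int_list(Stack& stack) {
      Shared<IntList> list = pop(stack).toIntList();  // rvalue overload steals the ref
      int64_t total = 0;
      for (int64_t e : list->elements())
        total += e;
      stack.push_back(IValue(total));
      return 0;  // an Operation returns a relative jump offset; 0 falls through
    }
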
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 7e4b45e..f8239c5 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -88,7 +88,7 @@
 
 void PropagateShapeOnNodeByRunningIt(Node* node, const std::vector<TensorType*>& types) {
   auto op = getOperation(node);
-  std::vector<at::Tensor> stack;
+  Stack stack;
 
   for(auto & type : types) {
     stack.push_back(representativeTensor(type));
@@ -102,7 +102,7 @@
 
   JIT_ASSERT(stack.size() == node->outputs().size());
   for(size_t i = 0; i < stack.size(); ++i) {
-    node->outputs()[i]->inferTypeFrom(stack[i]);
+    node->outputs()[i]->inferTypeFrom(stack[i].toTensor());
   }
 }
 
@@ -322,14 +322,6 @@
         node->output()->setType(ten->withSizes(sizes));
       }
     } break;
-    case prim::ReplaceIfUndef: {
-      // If types[0] has a type, then it is not defined, and the type will
-      // get set to types[0] because that will be the value propagated.
-      // If its type is not defined, then unification is an undefined type.
-      SHAPE_ASSERT(types.size() == 1);
-      node->output()->setType(types.at(0)->shared_from_this());
-      handled = true;
-    } break;
     case prim::Constant: {
       node->output()->inferTypeFrom(node->t(attr::value));
       handled = true;
diff --git a/torch/csrc/jit/python_interpreter.cpp b/torch/csrc/jit/python_interpreter.cpp
index c0668b7..5af53c4 100644
--- a/torch/csrc/jit/python_interpreter.cpp
+++ b/torch/csrc/jit/python_interpreter.cpp
@@ -44,7 +44,7 @@
         py_inputs[i] = py::reinterpret_borrow<py::object>(
             op->scalar_args[next_scalar++].get());
       } else if (arg_type == 't') {
-        auto var = peek(stack, next_tensor, num_inputs);
+        auto var = std::move(peek(stack, next_tensor, num_inputs)).toTensor();
         py_inputs[i] =
             py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
         next_tensor++;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 0d084ed..3a2ae20 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -42,7 +42,10 @@
             autograd::profiler::RecordFunction record("FusionGroup");
             std::vector<at::Tensor> toutputs;
             // TODO: have fusion_fn work off of a stack as well
-            fusion_fn->launch(last(stack, num_inputs), toutputs);
+            auto tinputs = fmap(last(stack, num_inputs), [](const IValue& v) {
+              return v.toTensor();
+            });
+            fusion_fn->launch(tinputs, toutputs);
             drop(stack, num_inputs);
             stack.insert(stack.end(), toutputs.begin(), toutputs.end());
             return 0;
@@ -70,27 +73,13 @@
           };
         }),
     Operator(
-        prim::ReplaceIfUndef,
-        [](Node* n) {
-          return [](Stack& stack) {
-            auto alternate = pop(stack);
-            auto result = pop(stack);
-            if (result.defined()) {
-              stack.push_back(std::move(result));
-            } else {
-              stack.push_back(std::move(alternate));
-            }
-            return 0;
-          };
-        }),
-
-    Operator(
         prim::Print,
         [](Node* node) {
           size_t num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             bool first = true;
-            for (at::Tensor i : last(stack, num_inputs)) {
+            for (const IValue& i_ : last(stack, num_inputs)) {
+              auto i = i_.toTensor();
               if (!first)
                 std::cout << " ";
               first = false;
@@ -114,7 +103,7 @@
     // and inst.outputs
     Operator(prim::Load, noop),
     // x, y = Store
     // stores values from stack into registers, the actual callback does
     // nothing since the stack manipulation is already encoded in inst.inputs
     // and inst.outputs
     Operator(prim::Store, noop),
@@ -132,8 +121,8 @@
         onnx::Reshape,
         [](Node* node) {
           return [=](Stack& stack) {
-            auto shape = pop(stack).contiguous();
-            auto input = pop(stack);
+            auto shape = pop(stack).toTensor().contiguous();
+            auto input = pop(stack).toTensor();
             JIT_ASSERT(shape.ndimension() == 1);
             at::IntList shape_list(shape.data<int64_t>(), shape.size(0));
             stack.push_back(input.reshape(shape_list));
@@ -144,7 +133,7 @@
         onnx::Shape,
         [](Node* node) {
           return [=](Stack& stack) {
-            auto t = pop(stack);
+            auto t = pop(stack).toTensor();
             at::IntList sizes = t.sizes();
             auto sizes_tensor = torch::empty(
                 {static_cast<int64_t>(sizes.size())}, at::dtype(at::kLong));
@@ -165,8 +154,8 @@
           auto false_ = at::full({}, 0, at::kLong);
           return [=](Stack& stack) {
             bool result = false;
-            for (const at::Tensor& t : last(stack, num_inputs)) {
-              if (t.defined()) {
+            for (const IValue& t : last(stack, num_inputs)) {
+              if (t.toTensor().defined()) {
                 result = true;
                 break;
               }
@@ -181,8 +170,8 @@
         prim::AutogradAdd,
         [](Node* node) {
           return [=](Stack& stack) {
-            auto a = pop(stack);
-            auto b = pop(stack);
+            auto a = pop(stack).toTensor();
+            auto b = pop(stack).toTensor();
             if (!a.defined())
               stack.push_back(b);
             else if (!b.defined())
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 8d1d3f7..df3ff81 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -297,9 +297,9 @@
 }
 
 at::optional<std::vector<int64_t>> getIntListAttribute(at::optional<int32_t> N, Value* input) {
-  auto list = constant_as<at::IntList>(input);
+  auto list = constant_as<std::vector<int64_t>>(input);
   if(list)
-    return std::vector<int64_t>(*list);
+    return list;
   // broadcast IntList[3] with value 4 -> {4, 4, 4}
   if(!N)
     return at::nullopt;
diff --git a/torch/csrc/jit/stack.h b/torch/csrc/jit/stack.h
index 5037253..e4d1d18 100644
--- a/torch/csrc/jit/stack.h
+++ b/torch/csrc/jit/stack.h
@@ -1,10 +1,11 @@
 #pragma once
 #include "ATen/ATen.h"
 #include "torch/csrc/jit/tensor_conversions.h"
+#include "torch/csrc/jit/ivalue.h"
 
 namespace torch { namespace jit {
 
-using Stack = std::vector<at::Tensor>;
+using Stack = std::vector<IValue>;
 using Operation = std::function<int(Stack&)>;
 
 // An operation with N inputs and M outputs pops the last N inputs off
@@ -21,21 +22,21 @@
 
 // treat the last N elements of the stack as a list, looking up
 // element i
-static inline at::Tensor & peek(Stack & stack, size_t i, size_t N) {
+static inline IValue & peek(Stack & stack, size_t i, size_t N) {
   return *(stack.end() - N + i);
 }
 // treat the last N elements of the stack as a list, looking up the
 // slice starting at index i and having length len
-static inline at::ArrayRef<at::Tensor> peekSlice(Stack & stack, size_t i, size_t len, size_t N) {
-  return at::ArrayRef<at::Tensor>(stack).slice(stack.size() - N + i, len);
+static inline at::ArrayRef<IValue> peekSlice(Stack & stack, size_t i, size_t len, size_t N) {
+  return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
 }
-static inline at::ArrayRef<at::Tensor> last(Stack & stack, size_t N) {
+static inline at::ArrayRef<IValue> last(Stack & stack, size_t N) {
   return peekSlice(stack, 0, N, N);
 }
 static inline void drop(Stack & stack, size_t n) {
   stack.erase(stack.end() - n, stack.end());
 }
-static inline at::Tensor pop(Stack & stack) {
+static inline IValue pop(Stack & stack) {
   auto r = std::move(stack.back());
   stack.pop_back();
   return r;
@@ -47,22 +48,22 @@
 // pack takes the return values of aten functions pushes them onto the stack
 template<typename T>
 inline void pack(Stack & stack, T&& v) {
-  stack.push_back(as_variable(std::move(v)));
+  stack.push_back(IValue(as_variable(std::move(v))));
 }
 template<>
 inline void pack(Stack & stack, at::Tensor&& v) {
-  stack.push_back(std::move(v));
+  stack.push_back(IValue(std::move(v)));
 }
 
 template<>
 inline void pack(Stack & stack, autograd::Variable&& v) {
-  stack.push_back(std::move(v));
+  stack.push_back(IValue(std::move(v)));
 }
 
 template<>
 inline void pack(Stack & stack, std::vector<at::Tensor>&& ts) {
   for(auto& t : ts) {
-    stack.push_back(std::move(t));
+    stack.push_back(IValue(std::move(t)));
   }
 }
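
The peek/drop/pack protocol for an op with N inputs and M outputs is unchanged;
only the element type is now IValue. A sketch (hypothetical op, not in this
patch):

    // an operation with 2 tensor inputs and 1 tensor output
    int add_op(Stack& stack) {
      auto a = peek(stack, 0, 2).toTensor();  // lvalue overload retains a ref
      auto b = peek(stack, 1, 2).toTensor();
      drop(stack, 2);                         // pop both inputs
      pack(stack, a + b);                     // push the single output
      return 0;
    }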
 
diff --git a/torch/csrc/jit/tensor_conversions.h b/torch/csrc/jit/tensor_conversions.h
index 276db96..84162a4 100644
--- a/torch/csrc/jit/tensor_conversions.h
+++ b/torch/csrc/jit/tensor_conversions.h
@@ -57,15 +57,15 @@
 };
 
 template<>
-struct tensor_as_impl<at::IntList> {
-  at::IntList operator()(at::Tensor&& t) {
+struct tensor_as_impl<std::vector<int64_t>> {
+  std::vector<int64_t> operator()(at::Tensor&& t) {
     if (t.type().scalarType() != at::ScalarType::Long)
       throw tensor_conversion_error("Expected a LongTensor");
     if (t.dim() != 1)
       throw tensor_conversion_error("Expected a 1D LongTensor");
     if (!t.is_contiguous())
       throw tensor_conversion_error("Expected a contiguous LongTensor");
-    return at::IntList{t.data<int64_t>(), static_cast<size_t>(t.numel())};
+    return std::vector<int64_t>(t.data<int64_t>(), t.data<int64_t>() + t.numel());
   }
 };
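
The conversion now copies into an owning std::vector instead of returning an
at::IntList view into the tensor's storage, so the result stays valid after
the source tensor goes away. A sketch:

    // convert a 1-D LongTensor into std::vector<int64_t>
    at::Tensor t = at::ones({3}).toType(at::kLong);           // {1, 1, 1}
    auto v = tensor_as<std::vector<int64_t>>(std::move(t));   // copies 3 ints
    // non-Long, non-1-D, or non-contiguous inputs throw tensor_conversion_error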
 
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 54e99f9..7b784f0 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -36,6 +36,8 @@
 #include "torch/csrc/jit/graph_executor.h"
 #include "torch/csrc/jit/script/compiler.h"
 #include "torch/csrc/jit/script/module.h"
+#include "torch/csrc/jit/ivalue.h"
+
 #include "onnx/onnx_pb.h"
 
 
@@ -439,8 +441,12 @@
 }
 
 void runOneStage(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std::vector<at::Tensor> & outputs) {
-  outputs = inputs;
-  interp.runOneStage(outputs);
+  std::vector<IValue> stack(inputs.begin(), inputs.end());
+  interp.runOneStage(stack);
+  outputs.clear();
+  for(auto & ivalue : stack) {
+    outputs.push_back(std::move(ivalue).toTensor());
+  }
 }
 
 void interpTest() {
@@ -878,7 +884,7 @@
 void testControlFlow() {
   script::Module cu;
   script::defineMethodsInModule(cu, cf_examples, torch::jit::script::Resolver(), nullptr);
-  auto run = [&](const std::string & name, std::vector<at::Tensor> stack) {
+  auto run = [&](const std::string & name, std::vector<IValue> stack) {
     auto graph = cu.get_method(name).graph();
     Code code(graph);
     InterpreterState interp(code);
@@ -886,8 +892,8 @@
     return stack;
   };
 
-  auto L = [](int64_t l) { return autograd::make_variable(at::Scalar(l).toTensor()); };
-  auto V = [](at::Tensor t) { return at::Scalar(t).toLong(); };
+  auto L = [](int64_t l) { return IValue(autograd::make_variable(at::Scalar(l).toTensor())); };
+  auto V = [](IValue t) { return at::Scalar(std::move(t).toTensor()).toLong(); };
   auto run_binary = [&](const std::string & name, int64_t a, int64_t b) {
     return V(run(name, {L(a), L(b)})[0]);
   };
@@ -898,6 +904,50 @@
   REQUIRE(256 == run_binary("while_test",2,0));
 }
 
+void testIValue() {
+  Shared<IntList> foo = IntList::create({3, 4, 5});
+  JIT_ASSERT(foo->use_count() == 1);
+  IValue bar(foo);
+  JIT_ASSERT(foo->use_count() == 2);
+  auto baz = bar;
+  JIT_ASSERT(foo->use_count() == 3);
+  auto foo2 = std::move(bar);
+  JIT_ASSERT(foo->use_count() == 3);
+  JIT_ASSERT(foo2.isIntList());
+  JIT_ASSERT(bar.isNone());
+  foo2 = IValue(4.0);
+  JIT_ASSERT(foo2.isDouble());
+  JIT_ASSERT(foo2.toDouble() == 4.0);
+  JIT_ASSERT(foo->use_count() == 2);
+  JIT_ASSERT(baz.toIntList()->elements().equals({3,4,5}));
+
+  auto move_it = std::move(baz).toIntList();
+  JIT_ASSERT(foo->use_count() == 2);
+  JIT_ASSERT(baz.isNone());
+  IValue i(4);
+  JIT_ASSERT(i.isInt() && i.toInt() == 4);
+  IValue dlist(DoubleList::create({3.5}));
+  JIT_ASSERT(
+      dlist.isDoubleList() &&
+      std::move(dlist).toDoubleList()->elements().equals({3.5}));
+  JIT_ASSERT(dlist.isNone());
+  dlist = IValue(DoubleList::create({3.4}));
+  JIT_ASSERT(dlist.toDoubleList()->elements().equals({3.4}));
+  IValue the_list(Tuple::create({IValue(3.4), IValue(4), IValue(foo)}));
+  JIT_ASSERT(foo->use_count() == 3);
+  JIT_ASSERT(the_list.isTuple());
+  auto second = std::move(the_list).toTuple()->elements().at(1);
+  JIT_ASSERT(second.toInt() == 4);
+  at::Tensor tv = at::rand({3,4});
+  IValue ten(tv);
+  JIT_ASSERT(tv.get()->use_count() == 2);
+  auto ten2 = ten;
+  JIT_ASSERT(tv.get()->use_count() == 3);
+  JIT_ASSERT(ten2.toTensor().equal(ten.toTensor()));
+  std::move(ten2).toTensor();
+  JIT_ASSERT(tv.get()->use_count() == 2);
+}
+
 void testProto() {
   ::ONNX_NAMESPACE::ModelProto proto;
   proto.set_producer_name("foo");
@@ -905,6 +955,7 @@
 
 std::string runJITCPPTests() {
   std::stringstream out;
+  testIValue();
   testControlFlow();
   testGraphExecutor();
   testBlocks(out);
diff --git a/torch/csrc/utils/functional.h b/torch/csrc/utils/functional.h
index 3f81228..af5099e 100644
--- a/torch/csrc/utils/functional.h
+++ b/torch/csrc/utils/functional.h
@@ -23,6 +23,15 @@
   return r;
 }
 
+template<typename F, typename T>
+inline auto fmap(T& inputs, const F& fn) -> std::vector<decltype(fn(*inputs.begin()))> {
+  std::vector<decltype(fn(*inputs.begin()))> r;
+  r.reserve(inputs.size());
+  for(auto & input : inputs)
+    r.push_back(fn(input));
+  return r;
+}
+
 // C++ forbids taking an address of a constructor, so here's a workaround...
 // Overload for constructor (R) application
 template<typename R, typename T>
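
This overload lets the callback take a mutable reference, which the interpreter
changes above rely on to move tensors out of a stack. A sketch:

    // moving elements out of the input while mapping over it
    std::vector<IValue> stack;
    stack.emplace_back(at::ones({2}));
    auto tensors = fmap(stack, [](IValue& v) {
      return std::move(v).toTensor();  // rvalue toTensor steals the reference
    });
    // stack[0] is now None; tensors[0] owns the tensor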