Export modules in IR with Google protobuf

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9746
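
Adds serialization and deserialization of script::Modules (parameters,
methods, and submodules) using the ONNX protobuf schema as a container
format. export.cpp is refactored so the old free functions become an
EncoderBase class with two subclasses: GraphEncoder (the existing ONNX
graph export) and ModuleEncoder (whole-module export, which additionally
records value types, storage sharing, and per-parameter metadata).
import.cpp gains the mirrored DecoderBase / GraphDecoder / ModuleDecoder
hierarchy, and ScriptModule gets an `export` binding.

From Python, the round trip looks like this (usage taken from the new
tests; `m_orig` is any ScriptModule):

    m_import = torch.jit.ScriptModule()
    m_export, storage_map = m_orig.export()
    torch._C._jit_import_module(m_import, m_export, storage_map)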

Differential Revision: D9110006

Pulled By: li-roy

fbshipit-source-id: 8b9744c042f822fdfe959a7a7fef3d0baff4f639
diff --git a/test/test_jit.py b/test/test_jit.py
index c05ed16..79a0a68 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3603,6 +3603,139 @@
         self.assertEqual(1, foo3(a))
         self.assertEqual(2, foo3(b))
 
+    def test_script_module_export_submodule(self):
+        class M1(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M1, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
+
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
+
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(False)
+                # test submodule
+                self.sub = M1()
+                self.weight = nn.Parameter(torch.randn(2, 3))
+                self.bias = nn.Parameter(torch.randn(2))
+                self.define("""
+                    def hi(self, a):
+                        return self.weight.mm(a)
+                """)
+
+            @torch.jit.script_method
+            def doit(self, input):
+                return self.weight.mm(input)
+
+            @torch.jit.script_method
+            def doit2(self, input):
+                return self.weight.mm(input)
+
+            @torch.jit.script_method
+            def doit3(self, input):
+                return input + torch.ones([1], dtype=torch.double)
+
+            @torch.jit.script_method
+            def forward(self, input):
+                a = self.doit(input)
+                b = self.doit2(input)
+                c = self.hi(input)
+                return a + b + self.bias + c
+
+        m_orig = M2()
+        m_import = torch.jit.ScriptModule()
+        m_export, storage_map = m_orig.export()
+        torch._C._jit_import_module(m_import, m_export, storage_map)
+
+        input = torch.randn(3, 2)
+        self.assertEqual(m_orig.doit(input), m_import.doit(input))
+        self.assertEqual(m_orig.hi(input), m_import.hi(input))
+        self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
+        self.assertEqual(m_orig.forward(input), m_import.forward(input))
+
+    @skipIfNoTorchVision
+    def test_script_module_export_resnet18(self):
+        x = torch.ones(1, 3, 224, 224)
+        m_orig = torch.jit.trace(x)(torchvision.models.resnet18())
+        m_import = torch.jit.ScriptModule()
+        m_export, storage_map = m_orig.export()
+        torch._C._jit_import_module(m_import, m_export, storage_map)
+
+        input = torch.randn(1, 3, 224, 224, requires_grad=True)
+        output_orig = m_orig(input)
+        output_orig.sum().backward()
+        grad_orig = input.grad.clone()
+        input.grad.zero_()
+
+        output_import = m_import(input)
+        output_import.sum().backward()
+        grad_import = input.grad.clone()
+
+        self.assertEqual(output_orig, output_import)
+        self.assertEqual(grad_orig, grad_import)
+
+    def test_script_module_export_tensor_type(self):
+        class M(torch.jit.ScriptModule):
+
+            def __init__(self, type):
+                super(M, self).__init__(False)
+                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
+
+            @torch.jit.script_method
+            def foo(self):
+                return self.param
+
+        for type in [torch.float, torch.double]:
+            m_orig = M(type)
+            m_import = torch.jit.ScriptModule()
+            m_export, storage_map = m_orig.export()
+            torch._C._jit_import_module(m_import, m_export, storage_map)
+            self.assertEqual(m_orig.foo(), m_import.foo())
+            self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+
+    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
+    def test_script_module_export_tensor_cuda(self):
+        class M(torch.jit.ScriptModule):
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda').random_())
+
+            @torch.jit.script_method
+            def foo(self):
+                return self.param
+
+        m_orig = M()
+        m_import = torch.jit.ScriptModule()
+        m_export, storage_map = m_orig.export()
+        torch._C._jit_import_module(m_import, m_export, storage_map)
+        self.assertTrue(m_import.foo().device == torch.device('cpu'))
+        self.assertEqual(m_orig.foo(), m_import.foo())
+        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+
+    def test_script_module_export_shared_storage(self):
+        class M(torch.jit.ScriptModule):
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
+                self.param2 = torch.nn.Parameter(self.param1[3])
+                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
+
+            @torch.jit.script_method
+            def foo(self):
+                return self.param1 + self.param2 + self.param3
+
+        m_orig = M()
+        m_import = torch.jit.ScriptModule()
+        m_export, storage_map = m_orig.export()
+        torch._C._jit_import_module(m_import, m_export, storage_map)
+        self.assertEqual(m_orig.foo(), m_import.foo())
+        self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
+        self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+
     def test_onnx_export_script_module(self):
         class ModuleToExport(torch.jit.ScriptModule):
             def __init__(self):
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 20208af..7174fd0 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -22,317 +22,7 @@
 namespace onnx_torch = ::torch::onnx;
 namespace onnx = ::ONNX_NAMESPACE;
 
-std::string value_name(Value* n) {
-  return n->uniqueName();
-}
-
-struct ExportContext {
-  size_t num_blocks = 0;
-  onnx_torch::OperatorExportTypes operator_export_type;
-};
-
-void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph> & g,
-                 const std::vector<at::Tensor> & initializers,
-                 ExportContext *ctx, RawDataExportMap* raw_data_export_map=nullptr);
-
-void encodeBlock(onnx::GraphProto * p_g, Block *b,
-                const std::vector<at::Tensor> & initializers,
-                ExportContext *ctx, RawDataExportMap* raw_data_export_map);
-
-void encodeTensor(onnx::TensorProto * p, const at::Tensor & tensor,
-                  at::optional<std::string> external_ref={},
-                  RawDataExportMap* raw_data_export_map = nullptr) {
-  for(auto d : tensor.sizes()) {
-    p->add_dims(d);
-  }
-  onnx::TensorProto_DataType onnx_type;
-  // Most integral types and float16 need to be serialized as int32
-  at::ScalarType cast_type = tensor.type().scalarType();
-  switch(tensor.type().scalarType()) {
-    case at::kDouble:
-      onnx_type = onnx::TensorProto_DataType_DOUBLE;
-      break;
-    case at::kFloat:
-      onnx_type = onnx::TensorProto_DataType_FLOAT;
-      break;
-    case at::kHalf:
-      onnx_type = onnx::TensorProto_DataType_FLOAT16;
-      cast_type = at::kInt;
-      break;
-    case at::kByte:
-      onnx_type = onnx::TensorProto_DataType_UINT8;
-      cast_type = at::kInt;
-      break;
-    case at::kChar:
-      onnx_type = onnx::TensorProto_DataType_INT8;
-      cast_type = at::kInt;
-      break;
-    case at::kShort:
-      onnx_type = onnx::TensorProto_DataType_INT16;
-      cast_type = at::kInt;
-      break;
-    case at::kInt:
-      onnx_type = onnx::TensorProto_DataType_INT32;
-      break;
-    case at::kLong:
-      onnx_type = onnx::TensorProto_DataType_INT64;
-      break;
-    default:
-      AT_ERROR("unexpected tensor scalar type");
-      break;
-  }
-  p->set_data_type(onnx_type);
-  // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous()
-  auto t = tensor.contiguous().toBackend(at::kCPU).toType(cast_type);
-  // Add a buffer to the raw_data_export_map for the caller to dump into an
-  // external data store. If external_ref is not specified, we instead dump
-  // the contiguous data into the protobuf itself
-  if (external_ref) {
-    // For now, we use the name of the tensor as the external lookup name to
-    // avoid ONNX protobuf changes.
-    JIT_ASSERT(external_ref.value() == p->name());
-    JIT_ASSERT(raw_data_export_map != nullptr);
-    JIT_ASSERT(raw_data_export_map->count(external_ref.value()) == 0);
-    (*raw_data_export_map)[external_ref.value()] = t;
-    p->set_raw_data("__EXTERNAL");
-  } else {
-    JIT_ASSERT(t.is_contiguous());
-    p->set_raw_data(std::string(static_cast<char*>(t.data_ptr()),  t.type().elementSizeInBytes() * t.numel()));
-  }
-}
-
-void addAttribute(onnx::NodeProto * n_p, jit::Node * n, jit::Symbol name, ExportContext *ctx) {
-  auto attr = n_p->add_attribute();
-  JIT_ASSERT(name.is_attr());
-  attr->set_name(name.toUnqualString());
-  switch(n->kindOf(name)) {
-    case AttributeKind::f:
-      attr->set_f(n->f(name));
-      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
-      break;
-    case AttributeKind::fs:
-      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
-      for(auto & v : n->fs(name))
-        attr->add_floats(v);
-      break;
-    case AttributeKind::i:
-      attr->set_type(onnx::AttributeProto_AttributeType_INT);
-      attr->set_i(n->i(name));
-      break;
-    case AttributeKind::is:
-      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
-      for(auto & v : n->is(name))
-        attr->add_ints(v);
-      break;
-    case AttributeKind::s:
-      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
-      attr->set_s(n->s(name));
-      break;
-    case AttributeKind::ss:
-      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
-      for(auto & v : n->ss(name))
-        attr->add_strings(v);
-      break;
-    case AttributeKind::t: {
-      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-      auto t = attr->mutable_t();
-      encodeTensor(t, n->t(name));
-    } break;
-    case AttributeKind::ts:
-      attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
-      for(auto & v : n->ts(name)) {
-        auto t = attr->add_tensors();
-        encodeTensor(t, v);
-      }
-      break;
-    case AttributeKind::g: {
-      attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-      auto g = attr->mutable_g();
-      encodeGraph(g, n->g(name), {}, ctx, nullptr);
-    } break;
-    case AttributeKind::gs:
-      attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
-      for(auto & v : n->gs(name)) {
-        auto g = attr->add_graphs();
-        encodeGraph(g, v, {}, ctx, nullptr);
-      }
-      break;
-  }
-}
-
-void encodeTypeProtoTensorType(onnx::TypeProto_Tensor* tensor_type, Value* n) {
-  onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
-  if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
-    const std::vector<std::int64_t>& sizes = node_type->sizes();
-    for (size_t i = 0; i < sizes.size(); i++) {
-      shape->add_dim();
-      shape->mutable_dim(i)->set_dim_value(sizes[i]);
-    }
-    onnx::TensorProto_DataType onnx_type;
-    switch(node_type->scalarType()) {
-      case at::kDouble:
-        onnx_type = onnx::TensorProto_DataType_DOUBLE;
-        break;
-      case at::kFloat:
-        onnx_type = onnx::TensorProto_DataType_FLOAT;
-        break;
-      case at::kHalf:
-        onnx_type = onnx::TensorProto_DataType_FLOAT16;
-        break;
-      case at::kByte:
-        onnx_type = onnx::TensorProto_DataType_UINT8;
-        break;
-      case at::kChar:
-        onnx_type = onnx::TensorProto_DataType_INT8;
-        break;
-      case at::kShort:
-        onnx_type = onnx::TensorProto_DataType_INT16;
-        break;
-      case at::kInt:
-        onnx_type = onnx::TensorProto_DataType_INT32;
-        break;
-      case at::kLong:
-        onnx_type = onnx::TensorProto_DataType_INT64;
-        break;
-      default:
-        AT_ERROR("unexpected tensor scalar type");
-        break;
-    }
-    tensor_type->set_elem_type(onnx_type);
-  }
-}
-
-void encodeValueInfo(onnx::ValueInfoProto* v, Value* n) {
-  v->set_name(value_name(n));
-  onnx::TypeProto* t = v->mutable_type();
-  onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
-  encodeTypeProtoTensorType(tensor_type, n);
-}
-
-void encodeGraph(onnx::GraphProto * p_g, const std::shared_ptr<Graph>& g,
-                 const std::vector<at::Tensor> & initializers,
-                 ExportContext *ctx, RawDataExportMap* raw_data_export_map) {
-  encodeBlock(p_g, g->block(), initializers, ctx, raw_data_export_map);
-}
-
-void encodeBlock(onnx::GraphProto * p_g, Block *b,
-                 const std::vector<at::Tensor> & initializers,
-                 ExportContext *ctx, RawDataExportMap* raw_data_export_map) {
-  JIT_ASSERT(p_g != nullptr);
-  std::string block_name = "torch-jit-export";
-  if (ctx->num_blocks) {
-    block_name += std::to_string(ctx->num_blocks);
-  }
-  ctx->num_blocks++;
-  p_g->set_name(block_name);
-
-  for (auto input : b->inputs()) {
-    onnx::ValueInfoProto* v = p_g->add_input();
-    encodeValueInfo(v, input);
-  }
-  for (auto output : b->outputs()) {
-    onnx::ValueInfoProto* v = p_g->add_output();
-    encodeValueInfo(v, output);
-  }
-  for (auto node : b->nodes()) {
-    bool is_raw_export = ctx->operator_export_type == onnx_torch::OperatorExportTypes::RAW;
-    if (node->kind() == prim::Undefined && !is_raw_export) {
-      // Undefined nodes are used to implement optional inputs. One
-      // way to "not provide" an optional input is to create an
-      // Undefined node, and pass its output as that input.
-      continue;
-    }
-    auto p_n = p_g->add_node();
-    if (node->getSourceLocation()) {
-      std::stringstream ss;
-      node->getSourceLocation()->highlight(ss);
-      p_n->set_doc_string(ss.str());
-    }
-    for(auto input : node->inputs()) {
-      if (input->node()->kind() == prim::Undefined && !is_raw_export) {
-        p_n->add_input("");
-      } else {
-        p_n->add_input(value_name(input));
-      }
-    }
-    for(auto output : node->outputs()) {
-      p_n->add_output(value_name(output));
-    }
-    if (is_raw_export) {
-      JIT_ASSERT(!node->kind().is_onnx());
-      p_n->set_domain(node->kind().domainString());
-    }
-    else if (ctx->operator_export_type != onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
-      JIT_ASSERT(node->kind().is_onnx());
-    }
-    p_n->set_op_type(node->kind().toUnqualString());
-    for(auto attr_name : node->attributeNames()) {
-      addAttribute(p_n, node, attr_name, ctx);
-    }
-    if (is_raw_export && node->blocks().size() > 0) {
-      auto blocks = p_n->add_attribute();
-      blocks->set_name("_blocks");
-      blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
-      for (auto block : node->blocks()) {
-        auto graph = blocks->add_graphs();
-        encodeBlock(graph, block, initializers, ctx, raw_data_export_map);
-      }
-    }
-    if (node->kind() == torch::jit::onnx::Loop) {
-      JIT_ASSERT(node->blocks().size() == 1);
-
-      auto body = p_n->add_attribute();
-      body->set_name("body");
-      body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-      auto g = body->mutable_g();
-      encodeBlock(g, node->blocks()[0], {}, ctx, raw_data_export_map);
-    }
-    if (node->kind() == torch::jit::onnx::If) {
-      JIT_ASSERT(node->blocks().size() == 2);
-
-      auto true_branch = p_n->add_attribute();
-      true_branch->set_name("then_branch");
-      true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-      auto true_g = true_branch->mutable_g();
-      encodeBlock(true_g, node->blocks()[0], {}, ctx, raw_data_export_map);
-
-      auto false_branch = p_n->add_attribute();
-      false_branch->set_name("else_branch");
-      false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-      auto false_g = false_branch->mutable_g();
-      encodeBlock(false_g, node->blocks()[1], {}, ctx, raw_data_export_map);
-    }
-  }
-  auto num_initializers = initializers.size();
-  JIT_ASSERT(b->inputs().size() >= num_initializers);
-  size_t inputs_count = b->inputs().size() - num_initializers;
-  for (auto & tensor : initializers) {
-    // TODO: stop using positions to determine which initializers
-    // match to which inputs
-    std::string name = p_g->input(inputs_count++).name();
-    auto p = p_g->add_initializer();
-    p->set_name(name);
-    if (raw_data_export_map) {
-      encodeTensor(p, tensor, name, raw_data_export_map);
-    } else {
-      encodeTensor(p, tensor, {});
-    }
-  }
-}
-
-void encodeModel(onnx::ModelProto* p_m, const std::shared_ptr<Graph>& g,
-                 const std::vector<at::Tensor>& initializers,
-                 RawDataExportMap* raw_data_export_map = nullptr,
-                 onnx_torch::OperatorExportTypes operator_export_type
-                   = onnx_torch::OperatorExportTypes::ONNX) {
-  onnx::GraphProto* p_g = p_m->mutable_graph();
-  ExportContext ctx;
-  ctx.operator_export_type = operator_export_type;
-  encodeGraph(p_g, g, initializers, &ctx, raw_data_export_map);
-}
-
-namespace {
-std::string getNodeStackTraceString(Node* n) {
+std::string getNodeStackTraceString(const Node* n) {
   std::stringstream ss;
   if (n->getSourceLocation()) {
     n->getSourceLocation()->highlight(ss);
@@ -341,7 +31,6 @@
   }
   return ss.str();
 }
-} // namespace
 
 void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
   for (auto node : graph->nodes()) {
@@ -376,6 +65,591 @@
   }
 }
 
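+// Base class for encoding JIT IR into an onnx::ModelProto. Subclasses
+// customize how tensors and value types are recorded; see GraphEncoder
+// (standard ONNX export) and ModuleEncoder (script module export) below.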
+class EncoderBase {
+ public:
+  EncoderBase(onnx::ModelProto *model_proto,
+             onnx_torch::OperatorExportTypes operator_export_type,
+             bool defer_weight_export = false);
+
+  RawDataExportMap get_raw_data_export_map() {
+    return raw_data_export_map_;
+  }
+
+ protected:
+  void EncodeGraph(onnx::GraphProto *graph_proto,
+                   const std::shared_ptr<Graph> &graph,
+                   const std::vector<at::Tensor> &initializers = {});
+
+  void EncodeBlock(onnx::GraphProto *graph_proto,
+                   const Block *block,
+                   const std::vector<at::Tensor> &initializers = {});
+
+  virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
+                            const at::Tensor &tensor,
+                            const at::optional<std::string> external_ref = {});
+
+  virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
+                                           const Value* n) {}
+
+  virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
+                               onnx::ValueInfoProto* v,
+                               const Value* n);
+
+  void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name);
+
+  size_t num_blocks_;
+  bool defer_weight_export_;
+  onnx_torch::OperatorExportTypes operator_export_type_;
+  RawDataExportMap raw_data_export_map_;
+};
+
+onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
+  switch(at_type) {
+    case at::kDouble:
+      return onnx::TensorProto_DataType_DOUBLE;
+    case at::kFloat:
+      return onnx::TensorProto_DataType_FLOAT;
+    case at::kHalf:
+      return onnx::TensorProto_DataType_FLOAT16;
+    case at::kByte:
+      return onnx::TensorProto_DataType_UINT8;
+    case at::kChar:
+      return onnx::TensorProto_DataType_INT8;
+    case at::kShort:
+      return onnx::TensorProto_DataType_INT16;
+    case at::kInt:
+      return onnx::TensorProto_DataType_INT32;
+    case at::kLong:
+      return onnx::TensorProto_DataType_INT64;
+    default:
+      AT_ERROR("unexpected tensor scalar type");
+  }
+}
+
+EncoderBase::EncoderBase(
+    onnx::ModelProto *model_proto,
+    onnx_torch::OperatorExportTypes operator_export_type,
+    bool defer_weight_export)
+    : num_blocks_(0),
+      defer_weight_export_(defer_weight_export),
+      operator_export_type_(operator_export_type) {
+  model_proto->set_producer_name("pytorch");
+  model_proto->set_ir_version(onnx::IR_VERSION);
+  model_proto->set_producer_version("0.3");
+}
+
+void EncoderBase::EncodeValueInfo(
+    onnx::GraphProto *graph_proto,
+    onnx::ValueInfoProto* v,
+    const Value* n) {
+  v->set_name(n->uniqueName());
+  onnx::TypeProto* t = v->mutable_type();
+  onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
+
+  onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
+  if (TensorTypePtr node_type = n->type()->cast<TensorType>()) {
+    const std::vector<std::int64_t>& sizes = node_type->sizes();
+    for (size_t i = 0; i < sizes.size(); i++) {
+      shape->add_dim();
+      shape->mutable_dim(i)->set_dim_value(sizes[i]);
+    }
+    tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
+  }
+}
+
+void EncoderBase::EncodeGraph(
+    onnx::GraphProto *graph_proto,
+    const std::shared_ptr<Graph> &graph,
+    const std::vector<at::Tensor> &initializers) {
+  EncodeBlock(graph_proto, graph->block(), initializers);
+}
+
+void EncoderBase::EncodeBlock(
+    onnx::GraphProto *graph_proto, const Block *block,
+    const std::vector<at::Tensor> &initializers) {
+  JIT_ASSERT(graph_proto != nullptr);
+  std::string block_name = "torch-jit-export";
+  if (num_blocks_) {
+    block_name += std::to_string(num_blocks_);
+  }
+  num_blocks_++;
+  graph_proto->set_name(block_name);
+
+  for (auto input : block->inputs()) {
+    onnx::ValueInfoProto* v = graph_proto->add_input();
+    EncodeValueInfo(graph_proto, v, input);
+  }
+  for (auto output : block->outputs()) {
+    onnx::ValueInfoProto* v = graph_proto->add_output();
+    EncodeValueInfo(graph_proto, v, output);
+  }
+  for (auto node : block->nodes()) {
+    bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
+    if (node->kind() == prim::Undefined && !is_raw_export) {
+      // Undefined nodes are used to implement optional inputs. One
+      // way to "not provide" an optional input is to create an
+      // Undefined node, and pass its output as that input.
+      continue;
+    }
+    auto p_n = graph_proto->add_node();
+    if (node->getSourceLocation()) {
+      std::stringstream ss;
+      node->getSourceLocation()->highlight(ss);
+      p_n->set_doc_string(ss.str());
+    }
+    for(auto input : node->inputs()) {
+      if (input->node()->kind() == prim::Undefined && !is_raw_export) {
+        p_n->add_input("");
+      } else {
+        p_n->add_input(input->uniqueName());
+      }
+    }
+    for(auto output : node->outputs()) {
+      p_n->add_output(output->uniqueName());
+      EncodeIntermediateValueInfo(graph_proto, output);
+    }
+    if (is_raw_export) {
+      JIT_ASSERT(!node->kind().is_onnx());
+      p_n->set_domain(node->kind().domainString());
+    }
+    else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
+      JIT_ASSERT(node->kind().is_onnx());
+    }
+    p_n->set_op_type(node->kind().toUnqualString());
+    for(auto attr_name : node->attributeNames()) {
+      AddAttribute(p_n, node, attr_name);
+    }
+    if (is_raw_export && node->blocks().size() > 0) {
+      auto blocks = p_n->add_attribute();
+      blocks->set_name("_blocks");
+      blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
+      for (auto block : node->blocks()) {
+        auto graph = blocks->add_graphs();
+        EncodeBlock(graph, block, initializers);
+      }
+    }
+    if (node->kind() == torch::jit::onnx::Loop) {
+      JIT_ASSERT(node->blocks().size() == 1);
+
+      auto body = p_n->add_attribute();
+      body->set_name("body");
+      body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+      auto g = body->mutable_g();
+      EncodeBlock(g, node->blocks()[0]);
+    }
+    if (node->kind() == torch::jit::onnx::If) {
+      JIT_ASSERT(node->blocks().size() == 2);
+
+      auto true_branch = p_n->add_attribute();
+      true_branch->set_name("then_branch");
+      true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+      auto true_g = true_branch->mutable_g();
+      EncodeBlock(true_g, node->blocks()[0]);
+
+      auto false_branch = p_n->add_attribute();
+      false_branch->set_name("else_branch");
+      false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+      auto false_g = false_branch->mutable_g();
+      EncodeBlock(false_g, node->blocks()[1]);
+    }
+  }
+  auto num_initializers = initializers.size();
+  JIT_ASSERT(block->inputs().size() >= num_initializers);
+  size_t inputs_count = block->inputs().size() - num_initializers;
+  for (auto & tensor : initializers) {
+    // TODO: stop using positions to determine which initializers
+    // match to which inputs
+    std::string name = graph_proto->input(inputs_count++).name();
+    auto p = graph_proto->add_initializer();
+    p->set_name(name);
+    EncodeTensor(p, tensor, name);
+  }
+}
+
+void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
+  auto attr = node_proto->add_attribute();
+  JIT_ASSERT(name.is_attr());
+  attr->set_name(name.toUnqualString());
+  switch(node->kindOf(name)) {
+    case AttributeKind::f:
+      attr->set_f(node->f(name));
+      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
+      break;
+    case AttributeKind::fs:
+      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
+      for(auto & v : node->fs(name))
+        attr->add_floats(v);
+      break;
+    case AttributeKind::i:
+      attr->set_type(onnx::AttributeProto_AttributeType_INT);
+      attr->set_i(node->i(name));
+      break;
+    case AttributeKind::is:
+      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
+      for(auto & v : node->is(name))
+        attr->add_ints(v);
+      break;
+    case AttributeKind::s:
+      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
+      attr->set_s(node->s(name));
+      break;
+    case AttributeKind::ss:
+      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
+      for(auto & v : node->ss(name))
+        attr->add_strings(v);
+      break;
+    case AttributeKind::t: {
+      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
+      auto t = attr->mutable_t();
+      EncodeTensor(t, node->t(name));
+    } break;
+    case AttributeKind::ts:
+      attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
+      for(auto & v : node->ts(name)) {
+        auto t = attr->add_tensors();
+        EncodeTensor(t, v);
+      }
+      break;
+    case AttributeKind::g: {
+      attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+      auto g = attr->mutable_g();
+      EncodeGraph(g, node->g(name));
+    } break;
+    case AttributeKind::gs:
+      attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
+      for(auto & v : node->gs(name)) {
+        auto g = attr->add_graphs();
+        EncodeGraph(g, v);
+      }
+      break;
+    default:
+      throw std::runtime_error("unexpected attribute kind");
+  }
+}
+
+void EncoderBase::EncodeTensor(
+    onnx::TensorProto *tensor_proto,
+    const at::Tensor &tensor,
+    const at::optional<std::string> external_ref) {
+  for(auto d : tensor.sizes()) {
+    tensor_proto->add_dims(d);
+  }
+  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
+  // CPU's HalfTensor doesn't have contiguous(), so call contiguous() first
+  auto t = tensor.contiguous().toBackend(at::kCPU);
+  // Add a buffer to the raw_data_export_map for the caller to dump into an
+  // external data store. If external_ref is not specified, we instead dump
+  // the contiguous data into the protobuf itself
+  if (defer_weight_export_ && external_ref) {
+    // For now, we use the name of the tensor as the external lookup name to
+    // avoid ONNX protobuf changes.
+    JIT_ASSERT(external_ref.value() == tensor_proto->name());
+    JIT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
+    raw_data_export_map_[external_ref.value()] = t;
+    tensor_proto->set_raw_data("__EXTERNAL");
+  } else {
+    JIT_ASSERT(t.is_contiguous());
+    tensor_proto->set_raw_data(std::string(static_cast<char*>(t.data_ptr()), t.type().elementSizeInBytes() * t.numel()));
+  }
+}
+
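+// Encodes a Graph and its initializers as a standard ONNX model proto,
+// as used by the ONNX exporter.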
+class GraphEncoder: public EncoderBase {
+ public:
+  GraphEncoder(onnx::ModelProto *model_proto,
+               const std::shared_ptr<Graph> &graph,
+               int64_t onnx_opset_version,
+               onnx_torch::OperatorExportTypes operator_export_type,
+               const std::vector<at::Tensor> &initializers,
+               bool defer_weight_export);
+
+};
+
+GraphEncoder::GraphEncoder(
+    onnx::ModelProto *model_proto,
+    const std::shared_ptr<Graph> &graph,
+    int64_t onnx_opset_version,
+    onnx_torch::OperatorExportTypes operator_export_type,
+    const std::vector<at::Tensor> &initializers,
+    bool defer_weight_export)
+    : EncoderBase(model_proto, operator_export_type, defer_weight_export) {
+  if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
+    validateGraph(graph, operator_export_type);
+  }
+
+  auto* imp = model_proto->add_opset_import();
+  // This is the version of ONNX operator set we are targeting
+  imp->set_version(onnx_opset_version);
+
+  EncodeGraph(model_proto->mutable_graph(), graph, initializers);
+}
+
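+// Encodes a script::Module, including its parameters, methods, submodules
+// and value types, reusing the ONNX proto schema as a container format.
+// The result is not valid ONNX; weight export is always deferred to the
+// raw data export map.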
+class ModuleEncoder: public EncoderBase {
+ public:
+  ModuleEncoder(onnx::ModelProto *model_proto,
+                const std::shared_ptr<script::Module> &module);
+
+ private:
+  void EncodeModule(onnx::GraphProto *graph_proto, const std::shared_ptr<script::Module> &module);
+
+  void EncodeParameters(onnx::GraphProto *graph_proto,
+                        const std::shared_ptr<script::Module> &module,
+                        const std::string prefix);
+
+  void EncodeParameter(onnx::TensorProto *tensor_proto,
+                       const script::NamedParameter &parameter,
+                       const std::string prefix);
+
+  void EncodeMethods(onnx::GraphProto *graph_proto,
+                     const std::shared_ptr<script::Module> &module,
+                     const std::string prefix);
+
+  void EncodeMethod(onnx::NodeProto *node_proto,
+                    const std::unique_ptr<script::Method> &method,
+                    const std::string prefix);
+
+  virtual void EncodeTensor(onnx::TensorProto *tensor_proto,
+                            const at::Tensor &tensor,
+                            const at::optional<std::string> external_ref) override;
+
+  virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
+                                           const Value* n) override;
+
+  virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
+                               onnx::ValueInfoProto* v,
+                               const Value* n) override;
+
+  void EncodeTypeInfo(onnx::GraphProto *graph_proto,
+                      onnx::ValueInfoProto* v,
+                      const TypePtr& type,
+                      const std::string& name);
+
+  // Used to deduplicate tensor storages
+  std::unordered_map<const void*, std::string> storage_dedup_map_;
+
+  // Used to keep track of Parameter names so Methods can refer to them
+  std::unordered_map<at::Tensor*, std::string> parameter_map_;
+
+  // Used to create sequential tensor storage names
+  size_t storage_counter_ = 0;
+
+  // Used to create sequential dummy names for node types
+  size_t type_counter_ = 0;
+};
+
+ModuleEncoder::ModuleEncoder(
+    onnx::ModelProto *model_proto,
+    const std::shared_ptr<script::Module> &module)
+    : EncoderBase(model_proto,
+                  onnx_torch::OperatorExportTypes::RAW,
+                  /*defer_weight_export*/ true) {
+  model_proto->set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
+  EncodeModule(model_proto->mutable_graph(), module);
+}
+
+void ModuleEncoder::EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, const Value *n) {
+  auto v = graph_proto->add_value_info();
+  EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
+}
+
+void ModuleEncoder::EncodeTypeInfo(
+    onnx::GraphProto *graph_proto,
+    onnx::ValueInfoProto* v,
+    const TypePtr& type,
+    const std::string& name) {
+  v->set_name(name);
+  onnx::TypeProto* type_proto = v->mutable_type();
+  onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
+  onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();
+
+  // Use TypeProto fields to encode types.
+  // The denotation field stores the type kind as a string.
+  auto kind = type->kind();
+  if (kind == TypeKind::DynamicType) {
+    type_proto->set_denotation("DynamicType");
+  } else if (kind == TypeKind::TensorType) {
+    type_proto->set_denotation("TensorType");
+    TensorTypePtr node_type = type->cast<TensorType>();
+    const std::vector<std::int64_t>& sizes = node_type->sizes();
+
+    // store the sizes and strides in the dims field of TensorShapeProto
+    for (size_t i = 0; i < sizes.size(); i++) {
+      shape_proto->add_dim();
+      shape_proto->mutable_dim(i)->set_dim_value(sizes[i]);
+    }
+    // strides follow the sizes, so offset the dim index by sizes.size()
+    const std::vector<std::int64_t>& strides = node_type->strides();
+    for (size_t i = 0; i < strides.size(); i++) {
+      shape_proto->add_dim();
+      shape_proto->mutable_dim(sizes.size() + i)->set_dim_value(strides[i]);
+    }
+    tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
+  } else if (kind == TypeKind::TupleType) {
+    type_proto->set_denotation("TupleType");
+    TupleTypePtr node_type = type->cast<TupleType>();
+    auto elements = node_type->elements();
+
+    // Generate a name for and encode each subtype in the value_info field of the GraphProto.
+    for (size_t i = 0; i < elements.size(); i++) {
+      std::string name = "#" + std::to_string(type_counter_++);
+      shape_proto->add_dim();
+      shape_proto->mutable_dim(i)->set_dim_param(name);
+      onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
+      EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
+    }
+  } else if (kind == TypeKind::ListType) {
+    type_proto->set_denotation("ListType");
+    ListTypePtr node_type = type->cast<ListType>();
+
+    // Generate a name for and encode the subtype in the value_info field of the GraphProto.
+    std::string name = "#" + std::to_string(type_counter_++);
+    shape_proto->add_dim();
+    shape_proto->mutable_dim(0)->set_dim_param(name);
+    onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
+    EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
+  } else if (kind == TypeKind::NumberType) {
+    type_proto->set_denotation("NumberType");
+  } else if (kind == TypeKind::FloatType) {
+    type_proto->set_denotation("FloatType");
+  } else if (kind == TypeKind::IntType) {
+    type_proto->set_denotation("IntType");
+  } else if (kind == TypeKind::NoneType) {
+    type_proto->set_denotation("NoneType");
+  }
+  else {
+    throw std::runtime_error("unexpected type kind");
+  }
+}
+
+void ModuleEncoder::EncodeValueInfo(
+    onnx::GraphProto *graph_proto,
+    onnx::ValueInfoProto* v,
+    const Value* n) {
+  EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
+}
+
+void ModuleEncoder::EncodeModule(
+    onnx::GraphProto *graph_proto,
+    const std::shared_ptr<script::Module> &module) {
+  EncodeParameters(graph_proto, module, "");
+  EncodeMethods(graph_proto, module, "");
+}
+
+void ModuleEncoder::EncodeParameters(
+    onnx::GraphProto *graph_proto,
+    const std::shared_ptr<script::Module> &module,
+    const std::string prefix) {
+  // Encode each parameter as an initializer in the proto
+  for (auto &parameter : module->get_parameters()) {
+    auto tensor_proto = graph_proto->add_initializer();
+    EncodeParameter(tensor_proto, parameter.value, prefix);
+  }
+
+  for (auto &submodule : module->get_modules()) {
+    EncodeParameters(graph_proto, submodule.value.module, prefix + submodule.key + ".");
+  }
+}
+
+void ModuleEncoder::EncodeParameter(
+    onnx::TensorProto *tensor_proto,
+    const script::NamedParameter &parameter,
+    const std::string prefix) {
+  auto tensor = parameter.slot();
+  // Name will be prefixed by the submodule path, e.g. submodule_foo.parameter_bar
+  auto name = prefix + parameter.name;
+
+  tensor_proto->set_name(name);
+  parameter_map_[tensor] = name;
+
+  // Parameters carry two extra leading int64_data entries (is_buffer,
+  // requires_grad) that plain tensors do not
+  tensor_proto->add_int64_data(parameter.is_buffer);
+  tensor_proto->add_int64_data(tensor->requires_grad());
+
+  EncodeTensor(tensor_proto, *tensor, name);
+}
+
+void ModuleEncoder::EncodeMethods(
+    onnx::GraphProto *graph_proto,
+    const std::shared_ptr<script::Module> &module,
+    const std::string prefix) {
+  // Encode each method as a node in the proto
+  for (auto &method : module->get_methods()) {
+    auto node_proto = graph_proto->add_node();
+    EncodeMethod(node_proto, method.value, prefix);
+  }
+
+  for (auto &submodule : module->get_modules()) {
+    EncodeMethods(graph_proto, submodule.value.module, prefix + submodule.key + ".");
+  }
+}
+
+void ModuleEncoder::EncodeMethod(
+    onnx::NodeProto *node_proto,
+    const std::unique_ptr<script::Method> &method,
+    const std::string prefix) {
+  node_proto->set_name(prefix + method->name());
+
+  // Store member_inputs of Method in input
+  for (auto &member_input : method->params()) {
+    auto it = parameter_map_.find(member_input);
+    JIT_ASSERT(it != parameter_map_.end());
+    node_proto->add_input(it->second);
+  }
+
+  auto attr_proto = node_proto->add_attribute();
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
+
+  for (auto node : method->graph()->nodes()) {
+    if (node->kind() == prim::PythonOp) {
+      auto py_node = static_cast<torch::jit::PythonOp*>(node);
+      throw std::runtime_error(
+          "Couldn't export Python operator " + py_node->name() +
+          "\n\nDefined at:\n" + getNodeStackTraceString(node));
+    }
+  }
+  EncodeBlock(attr_proto->mutable_g(), method->graph()->block(), {});
+}
+
+void ModuleEncoder::EncodeTensor(
+    onnx::TensorProto *tensor_proto,
+    const at::Tensor &tensor,
+    const at::optional<std::string> external_ref = {}) {
+  for (auto &d : tensor.sizes()) {
+    tensor_proto->add_dims(d);
+  }
+  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
+
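+  // Record view metadata in int64_data: the storage offset first, then the
+  // strides. The raw storage itself is exported via raw_data_export_map_.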
+  tensor_proto->add_int64_data(tensor.storage_offset());
+  for (auto &d : tensor.strides()) {
+    tensor_proto->add_int64_data(d);
+  }
+
+  auto storage_ptr = tensor.storage()->pImpl()->data();
+  auto dedup_it = storage_dedup_map_.find(storage_ptr);
+  if (dedup_it != storage_dedup_map_.end()) {
+    tensor_proto->set_doc_string(dedup_it->second);
+  } else {
+    std::string name;
+    if (external_ref) {
+      name = external_ref.value();
+    } else {
+      name = "$" + std::to_string(storage_counter_++);
+    }
+    tensor_proto->set_doc_string(name);
+    JIT_ASSERT(raw_data_export_map_.count(name) == 0);
+    storage_dedup_map_[storage_ptr] = name;
+
+    // NB: This new tensor is created to support cuda tensors.
+    // Storages can be mutated when converting tensors from cuda to cpu,
+    // and we need a cpu tensor to copy data from.
+    auto t = tensor.type().tensor(
+        *tensor.storage(),
+        /* storageOffset = */ 0,
+        /* size = */ { tensor.numel() },
+        /* strides = */ { 1 })
+    .toBackend(at::kCPU);
+    raw_data_export_map_[name] = t;
+  }
+}
+
 // Pretty printing
 namespace {
 constexpr char indent_char = ' ';
@@ -551,46 +825,8 @@
   dump(model, ss, 0);
   return ss.str();
 }
-
 }
 
-namespace {
-
-RawDataExportMap ToModelProto(
-    const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor> & initializers,
-    int64_t onnx_opset_version,
-    bool defer_weight_export,
-    onnx_torch::OperatorExportTypes operator_export_type,
-    onnx::ModelProto *model_proto) {
-  if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
-    validateGraph(graph, operator_export_type);
-  }
-
-  model_proto->set_producer_name("pytorch");
-  model_proto->set_producer_version("0.3");
-  model_proto->set_ir_version(onnx::IR_VERSION);
-  auto* imp = model_proto->add_opset_import();
-  // This is the version of ONNX operator set we are targeting
-  imp->set_version(onnx_opset_version);
-
-  // Map {external_data_ref -> raw data} for external serialization of weights
-  RawDataExportMap raw_data_export_map;
-
-  // Set up nanopb callbacks and compute the amount of space needed to store
-  // the resulting protobuf
-  if (defer_weight_export) {
-    encodeModel(model_proto, graph, initializers, &raw_data_export_map, operator_export_type);
-  } else {
-    encodeModel(model_proto, graph, initializers, nullptr, operator_export_type);
-  }
-
-  return raw_data_export_map;
-}
-
-}  // namespace
-
-
 std::string PrettyPrintExportedGraph(
                         const std::shared_ptr<Graph>& graph,
                         const std::vector<at::Tensor> & initializers,
@@ -598,10 +834,8 @@
                         bool defer_weight_export,
                         ::torch::onnx::OperatorExportTypes operator_export_type) {
   ::ONNX_NAMESPACE::ModelProto model_proto;
-  RawDataExportMap raw_data_export_map;
-  raw_data_export_map = ToModelProto(
-    graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
-    &model_proto);
+  auto graph_encoder = GraphEncoder(
+    &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export);
   return prettyPrint(model_proto);
 }
 
@@ -617,11 +851,15 @@
                         bool defer_weight_export,
                         ::torch::onnx::OperatorExportTypes operator_export_type) {
   ::ONNX_NAMESPACE::ModelProto model_proto;
-  RawDataExportMap raw_data_export_map;
-  raw_data_export_map = ToModelProto(
-    graph, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
-    &model_proto);
-  return std::make_tuple(model_proto.SerializeAsString(), raw_data_export_map);
+  auto graph_encoder = GraphEncoder(
+    &model_proto, graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export);
+  return std::make_tuple(model_proto.SerializeAsString(), graph_encoder.get_raw_data_export_map());
+}
+
+std::tuple<std::string, RawDataExportMap> ExportModule(const std::shared_ptr<script::Module>& module) {
+  ::ONNX_NAMESPACE::ModelProto model_proto;
+  auto module_encoder = ModuleEncoder(&model_proto, module);
+  return std::make_tuple(model_proto.SerializeAsString(), module_encoder.get_raw_data_export_map());
 }
 
 }}
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index d0c6212..9457762 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/script/module.h"
 #include "torch/csrc/onnx/onnx.h"
 
 namespace torch { namespace jit {
@@ -32,4 +33,7 @@
     ::torch::onnx::OperatorExportTypes operator_export_type
       = ::torch::onnx::OperatorExportTypes::ONNX);
 
+TORCH_API std::tuple<std::string, RawDataExportMap> ExportModule(
+    const std::shared_ptr<script::Module>& module);
+
 }}
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index a453925..6d8a4f1 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -16,44 +16,55 @@
 
 namespace {
 
-// IR graph construction
-
 namespace onnx = ::ONNX_NAMESPACE;
 
-at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) {
+// IR graph construction
 
-  at::Tensor tensor;
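+// Base class for decoding an onnx::ModelProto back into JIT IR; mirrors
+// EncoderBase in export.cpp.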
+class DecoderBase {
+ protected:
+  virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
 
-  switch(tensor_proto.data_type()) {
+  void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
+                   std::unordered_map<std::string, Value*>& value_map);
+
+  void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
+                   std::unordered_map<std::string, Value*>& value_map);
+
+  virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {}
+
+  virtual void buildIntermediateValue(Value* value, const std::string& name) {}
+
+  at::ScalarType onnxTypeToATenType(onnx::TensorProto_DataType onnx_type);
+
+  virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto);
+};
+
+at::ScalarType DecoderBase::onnxTypeToATenType(onnx::TensorProto_DataType onnx_type) {
+  switch(onnx_type) {
     case onnx::TensorProto_DataType_UINT8:
-      tensor = at::CPU(at::kByte).tensor();
-      break;
+      return at::kByte;
     case onnx::TensorProto_DataType_INT8:
-      tensor = at::CPU(at::kChar).tensor();
-      break;
+      return at::kChar;
     case onnx::TensorProto_DataType_INT16:
-      tensor = at::CPU(at::kShort).tensor();
-      break;
+      return at::kShort;
     case onnx::TensorProto_DataType_INT32:
-      tensor = at::CPU(at::kInt).tensor();
-      break;
+      return at::kInt;
     case onnx::TensorProto_DataType_INT64:
-      tensor = at::CPU(at::kLong).tensor();
-      break;
+      return at::kLong;
     case onnx::TensorProto_DataType_FLOAT16:
-      tensor = at::CPU(at::kHalf).tensor();
-      break;
+      return at::kHalf;
     case onnx::TensorProto_DataType_FLOAT:
-      tensor = at::CPU(at::kFloat).tensor();
-      break;
+      return at::kFloat;
     case onnx::TensorProto_DataType_DOUBLE:
-      tensor = at::CPU(at::kDouble).tensor();
-      break;
+      return at::kDouble;
     default:
       throw std::runtime_error("Unsupported data type");
   }
+}
 
-  std::vector<int64_t> sizes = {tensor_proto.dims().begin(), tensor_proto.dims().end()};
+at::Tensor DecoderBase::buildTensor(const onnx::TensorProto& tensor_proto) {
+  at::Tensor tensor = at::CPU(onnxTypeToATenType(tensor_proto.data_type())).tensor();
+  std::vector<int64_t> sizes = { tensor_proto.dims().begin(), tensor_proto.dims().end() };
   tensor.resize_(sizes);
 
   JIT_ASSERT(
@@ -62,22 +73,19 @@
       tensor_proto.raw_data().size());
 
   std::memcpy(tensor.data_ptr(), tensor_proto.raw_data().data(), tensor_proto.raw_data().size());
-
   return tensor;
 }
 
-void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
-                std::unordered_map<std::string, Value*>& value_map);
-
-void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
-                 std::unordered_map<std::string, Value*>& value_map) {
+void DecoderBase::buildBlocks(
+    const std::vector<onnx::GraphProto>& graphs_, Node* node,
+    std::unordered_map<std::string, Value*>& value_map) {
   for (auto g_ : graphs_) {
     auto block = node->addBlock();
     buildBlock(g_, block, value_map);
   }
 }
 
-std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) {
+std::shared_ptr<Graph> DecoderBase::buildGraph(const onnx::GraphProto& graph_proto) {
   auto graph = std::make_shared<Graph>();
   std::unordered_map<std::string, Value*> value_map;
 
@@ -86,11 +94,13 @@
   return graph;
 }
 
-void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
+void DecoderBase::buildBlock(const onnx::GraphProto& graph_proto, Block* block,
                 std::unordered_map<std::string, Value*>& value_map) {
 
   for (auto & input : graph_proto.input()) {
-    value_map[input.name()] = block->addInput();
+    auto value = block->addInput();
+    value_map[input.name()] = value;
+    buildValue(value, input);
   }
 
   for (auto & node_ : graph_proto.node()) {
@@ -131,14 +141,18 @@
           node->ss_(name, {attr.strings().begin(), attr.strings().end()});
           break;
         case onnx::AttributeProto_AttributeType_TENSORS:
-          node->ts_(name, fmap(attr.tensors(), [](const onnx::TensorProto& t) { return buildTensor(t); }));
+          node->ts_(name, fmap(attr.tensors(), [this](const onnx::TensorProto& t) {
+                                                 return buildTensor(t);
+                                               }));
           break;
         case onnx::AttributeProto_AttributeType_GRAPHS:
           if (attr.name() == "_blocks") {
             buildBlocks({attr.graphs().begin(), attr.graphs().end()}, node, value_map);
           }
           else {
-            node->gs_(name, fmap(attr.graphs(), [](const onnx::GraphProto& g_) { return buildGraph(g_); }));
+            node->gs_(name, fmap(attr.graphs(), [this](const onnx::GraphProto& g_) {
+                                                  return buildGraph(g_);
+                                                }));
           }
           break;
       }
@@ -151,6 +165,7 @@
 
     for (int i=0; i<node_.output().size(); i++) {
       value_map[node_.output(i)] = node->outputs()[i];
+      buildIntermediateValue(node->outputs()[i], node_.output(i));
     }
 
     block->appendNode(node);
@@ -158,23 +173,21 @@
 
   for (auto & output : graph_proto.output()) {
     Value* v = value_map.at(output.name());
+    buildValue(v, output);
     block->registerOutput(v);
   }
 }
 
-std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto, std::vector<at::Tensor>& initializers) {
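+// Decodes a standard ONNX model back into a Graph and its initializers.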
+class GraphDecoder : DecoderBase {
+ public:
+  std::shared_ptr<Graph> decode(const std::string& serialized_graph,
+                                std::vector<at::Tensor>& initializers);
 
-  auto graph = buildGraph(graph_proto);
-
-  for (auto tensor_ : graph_proto.initializer()) {
-    initializers.push_back(buildTensor(tensor_));
-  }
-
-  return graph;
-}
+  void reconstructOutputTypes(Block *b);
+};
 
 // TODO: this should be removed once we'll be able to serialize value types
-void reconstructOutputTypes(Block *b) {
+void GraphDecoder::reconstructOutputTypes(Block *b) {
   for (Node * n : b->nodes()) {
     if (n->kind() == prim::Constant) {
       switch (n->kindOf(attr::value)) {
@@ -211,18 +224,227 @@
   }
 }
 
-} // anonymous namespace
+std::shared_ptr<Graph> GraphDecoder::decode(
+    const std::string& serialized_graph,
+    std::vector<at::Tensor>& initializers) {
+  auto model_proto = onnx::ModelProto();
+  model_proto.ParseFromString(serialized_graph);
+
+  auto graph_proto = model_proto.graph();
+  auto graph = buildGraph(graph_proto);
+  for (auto &tensor_ : graph_proto.initializer()) {
+    initializers.push_back(buildTensor(tensor_));
+  }
+  reconstructOutputTypes(graph->block());
+  return graph;
+}
+
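+// Rebuilds a script::Module (parameters, methods, submodules) from a proto
+// produced by ModuleEncoder, pulling the raw tensor data out of the
+// externally supplied storage map.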
+class ModuleDecoder : DecoderBase {
+ public:
+  std::shared_ptr<script::Module> decode(
+      std::shared_ptr<script::Module> root_module,
+      const std::string& serialized_module,
+      const std::unordered_map<std::string, std::string>& storage_map);
+
+ private:
+  virtual std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto) override;
+
+  virtual at::Tensor buildTensor(const onnx::TensorProto& tensor_proto) override;
+
+  TypePtr buildType(const onnx::TypeProto& type_proto);
+
+  virtual void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) override;
+
+  virtual void buildIntermediateValue(Value* value, const std::string& name) override;
+
+  at::Tensor buildParameter(const onnx::TensorProto& tensor_proto);
+
+  at::Tensor buildTensorCommon(const onnx::TensorProto& tensor_proto,
+                               const int64_t storage_offset,
+                               const std::vector<int64_t>& strides);
+
+  std::pair<std::shared_ptr<script::Module>, std::string> parseFullName(
+      std::shared_ptr<script::Module> root_module,
+      const std::string fullname);
+
+  const std::unordered_map<std::string, std::string> *storage_export_map_;
+  std::unordered_map<std::string, std::shared_ptr<at::Tensor>> storage_map_;
+  std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
+};
+
+std::shared_ptr<Graph> ModuleDecoder::buildGraph(const onnx::GraphProto& graph_proto) {
+  for (auto &subtype : graph_proto.value_info()) {
+    value_type_map_[subtype.name()] = &subtype.type();
+  }
+  return DecoderBase::buildGraph(graph_proto);
+}
+
+TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
+  auto tensortype_proto = type_proto.tensor_type();
+  auto shape_proto = tensortype_proto.shape();
+  auto kind = type_proto.denotation();
+  if (kind == "DynamicType") {
+    return DynamicType::get();
+  } else if (kind == "TensorType") {
+    // TODO: Don't use DynamicType here
+    return DynamicType::get();
+  } else if (kind == "TupleType") {
+    std::vector<TypePtr> elems;
+    for (auto &subkind : shape_proto.dim()) {
+      auto it = value_type_map_.find(subkind.dim_param());
+      JIT_ASSERT(it != value_type_map_.end());
+      elems.push_back(buildType(*it->second));
+    }
+    return TupleType::create(elems);
+  } else if (kind == "ListType") {
+    auto subkind = shape_proto.dim(0);
+    auto it = value_type_map_.find(subkind.dim_param());
+    JIT_ASSERT(it != value_type_map_.end());
+    return ListType::create(buildType(*it->second));
+  } else if (kind == "NumberType") {
+    return NumberType::get();
+  } else if (kind == "FloatType") {
+    return FloatType::get();
+  } else if (kind == "IntType") {
+    return IntType::get();
+  } else if (kind == "NoneType") {
+    return NoneType::get();
+  } else {
+    throw std::runtime_error("unexpected string for type kind");
+  }
+}
+
+void ModuleDecoder::buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto) {
+  value->setType(buildType(valueinfo_proto.type()));
+}
+
+void ModuleDecoder::buildIntermediateValue(Value* value, const std::string& name) {
+  auto it = value_type_map_.find(name);
+  JIT_ASSERT(it != value_type_map_.end());
+  value->setType(buildType(*it->second));
+}
+
+at::Tensor ModuleDecoder::buildParameter(const onnx::TensorProto& tensor_proto) {
+  std::vector<int64_t> strides;
+  // int64_data holds (is_buffer, requires_grad, storage_offset) before the
+  // strides, so skip those three entries when collecting the strides
+  std::move(tensor_proto.int64_data().begin() + 3, tensor_proto.int64_data().end(), std::back_inserter(strides));
+  auto tensor = buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(2), strides);
+  autograd::Variable var = autograd::make_variable(tensor, /* requires_grad = */ tensor_proto.int64_data(1));
+  return var;
+}
+
+at::Tensor ModuleDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
+  std::vector<int64_t> strides;
+  // int64_data holds storage_offset before the strides, so skip that entry
+  std::move(tensor_proto.int64_data().begin() + 1, tensor_proto.int64_data().end(), std::back_inserter(strides));
+  return buildTensorCommon(tensor_proto, /* storage_offset = */ tensor_proto.int64_data(0), strides);
+}
+
+at::Tensor ModuleDecoder::buildTensorCommon(
+    const onnx::TensorProto& tensor_proto,
+    const int64_t storage_offset,
+    const std::vector<int64_t>& strides) {
+  // NB: storage_offset and strides are passed in separately because
+  // they are encoded differently for parameters and tensors
+  auto storage_name = tensor_proto.doc_string();
+  auto type = onnxTypeToATenType(tensor_proto.data_type());
+  std::vector<int64_t> dims;
+  std::move(tensor_proto.dims().begin(), tensor_proto.dims().end(), std::back_inserter(dims));
+
+  // Find or create the storage
+  at::Tensor *storage_tensor;
+  auto storage_it = storage_map_.find(storage_name);
+  if (storage_it == storage_map_.end()) {
+    auto storage = std::make_shared<at::Tensor>(at::CPU(type).tensor());
+    auto string_it = storage_export_map_->find(storage_name);
+    JIT_ASSERT(string_it != storage_export_map_->end());
+    storage->resize_({ static_cast<int64_t>(string_it->second.size()) });
+    std::memcpy(storage->storage()->pImpl()->data(), string_it->second.data(), string_it->second.size());
+    storage_map_.insert(std::make_pair(storage_name, storage));
+    storage_tensor = storage.get();
+  } else {
+    storage_tensor = storage_it->second.get();
+  }
+
+  return at::CPU(type).tensor(
+      *storage_tensor->storage().get(), storage_offset, dims, strides);
+}
+
+// Given a full name of a parameter or method,
+// return the parent submodule and local name
+std::pair<std::shared_ptr<script::Module>, std::string> ModuleDecoder::parseFullName(
+    std::shared_ptr<script::Module> root_module,
+    const std::string fullname) {
+  std::vector<std::string> vec;
+  std::stringstream ss(fullname);
+  std::string name;
+  while (std::getline(ss, name, '.')) {
+    vec.push_back(name);
+  }
+
+  std::shared_ptr<script::Module> curr = root_module;
+  for (size_t i = 0; i < vec.size() - 1; i++) {
+    if (curr->find_module(vec[i]) == nullptr) {
+      curr->register_module(vec[i], std::make_shared<script::Module>());
+    }
+    curr = curr->get_module(vec[i]);
+  }
+  return std::make_pair(curr, vec.back());
+}
+
+std::shared_ptr<script::Module> ModuleDecoder::decode(
+    const std::shared_ptr<script::Module> root_module,
+    const std::string &serialized_module,
+    const std::unordered_map<std::string, std::string> &storage_export_map) {
+  auto model_proto = onnx::ModelProto();
+  model_proto.ParseFromString(serialized_module);
+  auto graph_proto = model_proto.graph();
+
+  std::unordered_map<std::string, at::Tensor*> param_map;
+  storage_export_map_ = &storage_export_map;
+  storage_map_.clear();
+
+  for (auto &tensor_proto : graph_proto.initializer()) {
+    std::shared_ptr<script::Module> parent_module;
+    std::string name;
+    std::tie(parent_module, name) = parseFullName(root_module, tensor_proto.name());
+
+    auto param = buildParameter(tensor_proto);
+    parent_module->register_parameter(name, param, /* is_buffer = */ tensor_proto.int64_data(0));
+    param_map[tensor_proto.name()] = parent_module->parameter_slot(name);
+  }
+
+  for (auto &node_proto : graph_proto.node()) {
+    std::shared_ptr<script::Module> parent_module;
+    std::string name;
+    std::tie(parent_module, name) = parseFullName(root_module, node_proto.name());
+
+    std::vector<at::Tensor*> member_inputs;
+    for (auto &param_name : node_proto.input()) {
+      member_inputs.push_back(param_map[param_name]);
+    }
+
+    auto graph = buildGraph(node_proto.attribute(0).g());
+    parent_module->create_method(name, graph, member_inputs);
+  }
+
+  return root_module;
+}
+
+}  // namespace
 
 std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph,
                                      std::vector<at::Tensor>& initializers) {
-  auto model_proto = ::ONNX_NAMESPACE::ModelProto();
-  model_proto.ParseFromString(serialized_graph);
+  GraphDecoder decoder;
+  return decoder.decode(serialized_graph, initializers);
+}
 
-  auto graph = buildGraph(model_proto.graph(), initializers);
-
-  reconstructOutputTypes(graph->block());
-
-  return graph;
+void ImportIRModule(
+    const std::shared_ptr<script::Module> module,
+    const std::string& serialized_module,
+    const std::unordered_map<std::string, std::string>& storage_map) {
+  ModuleDecoder decoder;
+  decoder.decode(module, serialized_module, storage_map);
 }
 
 }}
diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h
index d593896f..56606cf 100644
--- a/torch/csrc/jit/import.h
+++ b/torch/csrc/jit/import.h
@@ -1,9 +1,15 @@
 #pragma once
 
 #include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/script/module.h"
 
 namespace torch { namespace jit {
 
 TORCH_API std::shared_ptr<Graph> ImportIRGraph(const std::string& serialized_graph, std::vector<at::Tensor> & initializers);
 
+TORCH_API void ImportIRModule(
+    const std::shared_ptr<script::Module> module,
+    const std::string& serialized_module,
+    const std::unordered_map<std::string, std::string>& storage_map);
+
 }}
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index b72fdb6..59884ed 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -496,5 +496,10 @@
     }
     return std::make_tuple(graph, variables);
   });
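+  // Rebuilds a ScriptModule in place from the output of ScriptModule.export().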
+  m.def("_jit_import_module", [](const std::shared_ptr<script::Module> module,
+                                 const std::string& serialized_module,
+                                 const std::unordered_map<std::string, std::string>& storages) {
+    ImportIRModule(module, serialized_module, storages);
+  });
 }
 }}
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index b91d348..1813327 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -3,6 +3,7 @@
 #include "torch/csrc/Device.h"
 #include "torch/csrc/Dtype.h"
 #include "torch/csrc/Layout.h"
+#include "torch/csrc/jit/export.h"
 #include "torch/csrc/jit/script/compiler.h"
 
 #include "torch/csrc/jit/python_tracer.h"
@@ -431,6 +432,21 @@
   // public.
   py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
       .def(py::init<>())
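+      // Serializes this module. Returns a tuple of the serialized proto and
+      // a map from storage name to raw tensor bytes.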
+      .def("export", [](const std::shared_ptr<Module> m) {
+        std::string module;
+        RawDataExportMap export_map;
+        std::tie(module, export_map) = ExportModule(m);
+        std::unordered_map<std::string, py::bytes> python_serialized_export_map;
+        for (auto& kv : export_map) {
+          auto t = kv.second;
+          size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
+          // TODO: this is an unnecessary copy. In theory we can directly return
+          // the map from identifier to Tensor, but we need some API in Python
+          // to get raw `bytes` containing the raw tensor data.
+          python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
+        }
+        return std::make_tuple(py::bytes(module), python_serialized_export_map);
+      })
       .def("_set_optimized", &Module::set_optimized)
       .def(
           "_define",