Add convertToCaffe2Proto to python API

Summary: Closing the gap a bit on API, allowing users to go NetDef -> nomnigraph -> NetDef in python now

Reviewed By: duc0

Differential Revision: D9670495

fbshipit-source-id: 6497518ffc05a186deb0d657e06317980d39ddd5
diff --git a/caffe2/python/nomnigraph.py b/caffe2/python/nomnigraph.py
index 8414235..708eae6 100644
--- a/caffe2/python/nomnigraph.py
+++ b/caffe2/python/nomnigraph.py
@@ -29,6 +29,14 @@
     def dataFlow(self):
         return self._NNModule.dataFlow()
 
+    def convertToCaffe2Proto(self, old_proto=None):
+        if not old_proto:
+            old_proto = caffe2_pb2.NetDef()
+        output = self._NNModule.convertToCaffe2Proto(old_proto)
+        new_proto = caffe2_pb2.NetDef()
+        new_proto.ParseFromString(output)
+        return new_proto
+
     def match(self, pattern):
         for n in self.dataFlow.getMutableNodes():
             m = C.matchSubgraph(n, pattern)
diff --git a/caffe2/python/nomnigraph_test.py b/caffe2/python/nomnigraph_test.py
index 0c9fd34..7739ac0 100644
--- a/caffe2/python/nomnigraph_test.py
+++ b/caffe2/python/nomnigraph_test.py
@@ -151,3 +151,26 @@
         n2 = g.createNode("hello2")
         e = g.createEdge(n1, n2)
         ng.render(g)
+
+    def test_convertToProto(self):
+        net = core.Net("name")
+        net.FC(["X", "W"], ["Y"])
+        nn = ng.NNModule(net)
+        new_netdef = nn.convertToCaffe2Proto()
+        print(new_netdef)
+        print(net.Proto())
+        assert len(new_netdef.op) == len(net.Proto().op)
+        for i in range(len(new_netdef.op)):
+            op = net.Proto().op[i]
+            new_op = new_netdef.op[i]
+            assert op.type == new_op.type
+            assert len(op.input) == len(new_op.input)
+            assert len(op.output) == len(new_op.output)
+            for a, b in zip(op.input, new_op.input):
+                assert a == b
+            for a, b in zip(op.output, new_op.output):
+                assert a == b
+        for a, b in zip(new_netdef.external_input, net.Proto().external_input):
+            assert a == b
+        for a, b in zip(new_netdef.external_output, net.Proto().external_output):
+            assert a == b
diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc
index e7e4dfc..bca9b69 100644
--- a/caffe2/python/pybind_state_nomni.cc
+++ b/caffe2/python/pybind_state_nomni.cc
@@ -104,7 +104,19 @@
       .def(
           "dataFlow",
           [](NNModule* nn) -> NNGraph* { return &nn->dataFlow; },
-          py::return_value_policy::reference_internal);
+          py::return_value_policy::reference_internal)
+      .def("convertToCaffe2Proto", [](NNModule& nn, py::object def) {
+        auto attr = def.attr("SerializeToString");
+        CAFFE_ENFORCE(
+            attr, "convertToCaffe2Proto takes either no args", "a NetDef");
+        auto str = attr();
+        caffe2::NetDef proto;
+        proto.ParseFromString(py::bytes(str));
+        auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
+        std::string out;
+        new_proto.SerializeToString(&out);
+        return py::bytes(out);
+      });
 
   // NNGraph methods
   py::class_<NNGraph> nngraph(m, "NNGraph");