Add convertToCaffe2Proto to python API
Summary: Closing the gap a bit on API, allowing users to go NetDef -> nomnigraph -> NetDef in python now
Reviewed By: duc0
Differential Revision: D9670495
fbshipit-source-id: 6497518ffc05a186deb0d657e06317980d39ddd5
diff --git a/caffe2/python/nomnigraph.py b/caffe2/python/nomnigraph.py
index 8414235..708eae6 100644
--- a/caffe2/python/nomnigraph.py
+++ b/caffe2/python/nomnigraph.py
@@ -29,6 +29,14 @@
def dataFlow(self):
return self._NNModule.dataFlow()
+ def convertToCaffe2Proto(self, old_proto=None):
+ if not old_proto:
+ old_proto = caffe2_pb2.NetDef()
+ output = self._NNModule.convertToCaffe2Proto(old_proto)
+ new_proto = caffe2_pb2.NetDef()
+ new_proto.ParseFromString(output)
+ return new_proto
+
def match(self, pattern):
for n in self.dataFlow.getMutableNodes():
m = C.matchSubgraph(n, pattern)
diff --git a/caffe2/python/nomnigraph_test.py b/caffe2/python/nomnigraph_test.py
index 0c9fd34..7739ac0 100644
--- a/caffe2/python/nomnigraph_test.py
+++ b/caffe2/python/nomnigraph_test.py
@@ -151,3 +151,26 @@
n2 = g.createNode("hello2")
e = g.createEdge(n1, n2)
ng.render(g)
+
+ def test_convertToProto(self):
+ net = core.Net("name")
+ net.FC(["X", "W"], ["Y"])
+ nn = ng.NNModule(net)
+ new_netdef = nn.convertToCaffe2Proto()
+ print(new_netdef)
+ print(net.Proto())
+ assert len(new_netdef.op) == len(net.Proto().op)
+ for i in range(len(new_netdef.op)):
+ op = net.Proto().op[i]
+ new_op = new_netdef.op[i]
+ assert op.type == new_op.type
+ assert len(op.input) == len(new_op.input)
+ assert len(op.output) == len(new_op.output)
+ for a, b in zip(op.input, new_op.input):
+ assert a == b
+ for a, b in zip(op.output, new_op.output):
+ assert a == b
+ for a, b in zip(new_netdef.external_input, net.Proto().external_input):
+ assert a == b
+ for a, b in zip(new_netdef.external_output, net.Proto().external_output):
+ assert a == b
diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc
index e7e4dfc..bca9b69 100644
--- a/caffe2/python/pybind_state_nomni.cc
+++ b/caffe2/python/pybind_state_nomni.cc
@@ -104,7 +104,19 @@
.def(
"dataFlow",
[](NNModule* nn) -> NNGraph* { return &nn->dataFlow; },
- py::return_value_policy::reference_internal);
+ py::return_value_policy::reference_internal)
+ .def("convertToCaffe2Proto", [](NNModule& nn, py::object def) {
+ auto attr = def.attr("SerializeToString");
+ CAFFE_ENFORCE(
+ attr, "convertToCaffe2Proto takes either no args", "a NetDef");
+ auto str = attr();
+ caffe2::NetDef proto;
+ proto.ParseFromString(py::bytes(str));
+ auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
+ std::string out;
+ new_proto.SerializeToString(&out);
+ return py::bytes(out);
+ });
// NNGraph methods
py::class_<NNGraph> nngraph(m, "NNGraph");