Improve Python API with the addition of pythonic setters/getters
Summary:
Simple additions that make it vastly easier to use nomnigraph in
python
Reviewed By: duc0
Differential Revision: D10383027
fbshipit-source-id: 441a883b84d4c53cca4f9c6fcc70e58692b8f782
diff --git a/caffe2/python/nomnigraph.py b/caffe2/python/nomnigraph.py
index 407e92b..100a373 100644
--- a/caffe2/python/nomnigraph.py
+++ b/caffe2/python/nomnigraph.py
@@ -41,6 +41,22 @@
def dataFlow(self):
return self._NNModule.dataFlow()
+ @property
+ def controlFlow(self):
+ return self._NNModule.getExecutionOrder()
+
+ @property
+ def nodes(self):
+ return self._NNModule.dataFlow().nodes
+
+ @property
+ def operators(self):
+ return self._NNModule.dataFlow().operators
+
+ @property
+ def tensors(self):
+ return self._NNModule.dataFlow().tensors
+
def convertToCaffe2Proto(self, old_proto=None):
if not old_proto:
old_proto = caffe2_pb2.NetDef()
diff --git a/caffe2/python/nomnigraph_test.py b/caffe2/python/nomnigraph_test.py
index a42c31a..c977116 100644
--- a/caffe2/python/nomnigraph_test.py
+++ b/caffe2/python/nomnigraph_test.py
@@ -26,9 +26,38 @@
nn = ng.NNModule(net)
for node in nn.dataFlow.getMutableNodes():
if node.isOperator():
- assert node.getOperator().getName() == "FC"
+ assert node.getName() == "FC"
elif node.isTensor():
- assert node.getTensor().getName() in ["X", "W", "Y"]
+ assert node.getName() in ["X", "W", "Y"]
+
+ def test_core_net_controlflow(self):
+ net = core.Net("name")
+ net.FC(["X", "W"], ["Y"])
+ net.Relu(["Y"], ["Z"])
+ nn = ng.NNModule(net)
+ assert len(nn.controlFlow) == 2
+ for instr in nn.controlFlow:
+ assert instr.getType() == "Operator"
+ assert nn.controlFlow[0].getName() == "FC"
+ assert nn.controlFlow[1].getName() == "Relu"
+
+ def test_core_net_nn_accessors(self):
+ net = core.Net("name")
+ net.FC(["X", "W"], ["Y"])
+ net.Relu(["Y"], ["Z"])
+ nn = ng.NNModule(net)
+ tensors = set()
+ for t in nn.tensors:
+ tensors.add(t.name)
+ assert tensors == set(["X", "W", "Y", "Z"])
+ ops = set()
+ for op in nn.operators:
+ ops.add(op.name)
+ assert ops == set(["FC", "Relu"])
+ nodes = set()
+ for node in nn.nodes:
+ nodes.add(node.name)
+ assert nodes == (ops | tensors)
def test_netdef_simple(self):
net = core.Net("name")
@@ -85,6 +114,20 @@
if bool(random.getrandbits(1)):
dfg.createEdge(data[i], ops[j])
+ def test_traversal(self):
+ net = core.Net("test")
+ net.FC(["X", "W"], ["Y"])
+ net.Relu(["Y"], ["Z"])
+ nn = ng.NNModule(net)
+ fc = nn.controlFlow[0]
+ relu = nn.controlFlow[1]
+ assert fc.inputs[0].name == "X"
+ assert fc.inputs[1].name == "W"
+ assert relu.outputs[0].name == "Z"
+ assert relu.inputs[0].name == "Y"
+ assert relu.inputs[0].producer.name == "FC"
+ assert fc.outputs[0].consumers[0].name == "Relu"
+
def test_debug(self):
nn = ng.NNModule()
dfg = nn.dataFlow
diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc
index c093227..6efdfb3 100644
--- a/caffe2/python/pybind_state_nomni.cc
+++ b/caffe2/python/pybind_state_nomni.cc
@@ -124,19 +124,44 @@
"dataFlow",
[](NNModule* nn) -> NNGraph* { return &nn->dataFlow; },
py::return_value_policy::reference_internal)
- .def("convertToCaffe2Proto", [](NNModule& nn, py::object def) {
- CAFFE_ENFORCE(
- pybind11::hasattr(def, "SerializeToString"),
- "convertToCaffe2Proto takes either no args", "a NetDef");
- auto str = def.attr("SerializeToString")();
- caffe2::NetDef proto;
- proto.ParseFromString(py::bytes(str));
- auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
- std::string out;
- new_proto.SerializeToString(&out);
- return py::bytes(out);
- });
+ .def(
+ "convertToCaffe2Proto",
+ [](NNModule& nn, py::object def) {
+ CAFFE_ENFORCE(
+ pybind11::hasattr(def, "SerializeToString"),
+ "convertToCaffe2Proto takes either no args",
+ "a NetDef");
+ auto str = def.attr("SerializeToString")();
+ caffe2::NetDef proto;
+ proto.ParseFromString(py::bytes(str));
+ auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
+ std::string out;
+ new_proto.SerializeToString(&out);
+ return py::bytes(out);
+ })
+ .def(
+ "getExecutionOrder",
+ [](NNModule& nn) {
+ nn::coalesceInsertedDataDependencies(&nn);
+ std::vector<NNGraph::NodeRef> out;
+ auto sccs = nom::algorithm::tarjans(&nn.controlFlow);
+ for (const auto& scc : sccs) {
+ for (const auto& bb : scc.getNodes()) {
+ for (const auto& instr : bb->data().getInstructions()) {
+ out.emplace_back(instr);
+ }
+ }
+ }
+ return out;
+ },
+ py::return_value_policy::reference_internal);
+ auto getTensors = [](NNGraph* g) {
+ return nn::nodeIterator<nom::repr::Tensor>(*g);
+ };
+ auto getOperators = [](NNGraph* g) {
+ return nn::nodeIterator<NeuralNetOperator>(*g);
+ };
// NNGraph methods
py::class_<NNGraph> nngraph(m, "NNGraph");
nngraph
@@ -190,12 +215,62 @@
py::return_value_policy::reference_internal)
.def(
"getMutableNodes",
- [](NNGraph* g) { return g->getMutableNodes(); },
- py::return_value_policy::reference_internal);
+ &NNGraph::getMutableNodes,
+ py::return_value_policy::reference_internal)
+ .def_property_readonly(
+ "nodes",
+ &NNGraph::getMutableNodes,
+ py::return_value_policy::reference_internal)
+ .def_property_readonly(
+ "operators",
+ getOperators,
+ py::return_value_policy::reference_internal)
+ .def_property_readonly(
+ "tensors", getTensors, py::return_value_policy::reference_internal);
// Node level methods
using NodeType = nom::Node<std::unique_ptr<nom::repr::Value>>;
py::class_<NodeType> noderef(m, "NodeRef");
+ auto getName = [](NNGraph::NodeRef n) {
+ if (nn::is<nom::repr::Tensor>(n)) {
+ return nn::get<nom::repr::Tensor>(n)->getName();
+ } else if (nn::is<NeuralNetOperator>(n)) {
+ return nn::get<NeuralNetOperator>(n)->getName();
+ }
+ return std::string("Unknown");
+ };
+ auto getType = [](NNGraph::NodeRef n) {
+ if (nn::is<nom::repr::Tensor>(n)) {
+ return "Tensor";
+ } else if (nn::is<NeuralNetOperator>(n)) {
+ return "Operator";
+ }
+ return "Unknown";
+ };
+ auto getOperator = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<NeuralNetOperator>(n));
+ return nn::get<NeuralNetOperator>(n);
+ };
+ auto getTensor = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<nom::repr::Tensor>(n));
+ return nn::get<nom::repr::Tensor>(n);
+ };
+ auto getInputs = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<NeuralNetOperator>(n));
+ return nn::getInputs(n);
+ };
+ auto getOutputs = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<NeuralNetOperator>(n));
+ return nn::getOutputs(n);
+ };
+ auto getProducer = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<NeuralNetData>(n));
+ return nn::getProducer(n);
+ };
+ auto getConsumers = [](NNGraph::NodeRef n) {
+ CAFFE_ENFORCE(nn::is<NeuralNetData>(n));
+ return nn::getConsumers(n);
+ };
noderef
.def(
@@ -204,20 +279,31 @@
.def(
"isTensor",
[](NNGraph::NodeRef n) { return nn::is<nom::repr::Tensor>(n); })
+ .def("getType", getType)
+ .def_property_readonly("type", getType)
+ .def("getName", getName)
+ .def_property_readonly("name", getName)
.def(
"getOperator",
- [](NNGraph::NodeRef n) {
- CAFFE_ENFORCE(nn::is<NeuralNetOperator>(n));
- return nn::get<NeuralNetOperator>(n);
- },
+ getOperator,
py::return_value_policy::reference_internal)
- .def(
- "getTensor",
- [](NNGraph::NodeRef n) {
- CAFFE_ENFORCE(nn::is<nom::repr::Tensor>(n));
- return nn::get<nom::repr::Tensor>(n);
- },
- py::return_value_policy::reference_internal)
+ .def("getTensor", getTensor, py::return_value_policy::reference_internal)
+ .def_property_readonly(
+ "operator", getOperator, py::return_value_policy::reference)
+ .def_property_readonly(
+ "tensor", getTensor, py::return_value_policy::reference)
+ .def("getInputs", getInputs, py::return_value_policy::reference)
+ .def("getOutputs", getOutputs, py::return_value_policy::reference)
+ .def("getProducer", getProducer, py::return_value_policy::reference)
+ .def("getConsumers", getConsumers, py::return_value_policy::reference)
+ .def_property_readonly(
+ "inputs", getInputs, py::return_value_policy::reference)
+ .def_property_readonly(
+ "outputs", getOutputs, py::return_value_policy::reference)
+ .def_property_readonly(
+ "producer", getProducer, py::return_value_policy::reference)
+ .def_property_readonly(
+ "consumers", getConsumers, py::return_value_policy::reference)
.def_property(
"annotation",
[](NNGraph::NodeRef n) { return getOrAddCaffe2Annotation(n); },
@@ -275,7 +361,20 @@
// Subgraph matching API
py::class_<NNSubgraph> nnsubgraph(m, "NNSubgraph");
- nnsubgraph.def("__len__", [](NNSubgraph& s) { return s.getNodes().size(); });
+ nnsubgraph.def("__len__", [](NNSubgraph& s) { return s.getNodes().size(); })
+ .def_property_readonly(
+ "nodes",
+ [](NNSubgraph& s) {
+ std::vector<NNGraph::NodeRef> out;
+ for (auto n : s.getNodes()) {
+ out.emplace_back(n);
+ }
+ return out;
+ },
+ py::return_value_policy::reference)
+ .def("hasNode", [](NNSubgraph& s, NNGraph::NodeRef n) {
+ return s.hasNode(n);
+ });
py::class_<nn::NNMatchGraph> nnMatchGraph(m, "NNMatchGraph");
nnMatchGraph.def(py::init<>());