|  |  | 
|  |  | 
|  | import errno | 
|  | import os | 
|  | from subprocess import PIPE, Popen | 
|  |  | 
|  | import caffe2.python._import_c_extension as C | 
|  | from caffe2.proto import caffe2_pb2 | 
|  | from caffe2.python import core | 
|  |  | 
|  |  | 
class NNModule(object):
    """Python-side wrapper around a nomnigraph ``NNModule`` from the C extension.

    The wrapped object lives in ``self._NNModule``. Most methods forward to it
    directly, or to its data-flow graph (``self._NNModule.dataFlow()``).
    """

    def __init__(self, net=None, device_map=None):
        """Create an empty module, or build one from an existing Caffe2 net.

        Args:
            net: ``core.Net`` or ``caffe2_pb2.NetDef`` to convert. When None,
                an empty ``C.NNModule()`` is created and ``device_map`` is
                ignored.
            device_map: optional mapping whose values are protobuf messages.
                Each value is serialized with ``SerializeToString`` and the
                module is built via ``C.NNModuleFromProtobufDistributed``.
                Without it, ``C.NNModuleFromProtobuf`` is used and its second
                result is stored as ``self._OpList``.

        Raises:
            TypeError: if ``net`` is neither a ``core.Net`` nor a
                ``caffe2_pb2.NetDef``. TypeError subclasses Exception, so
                existing ``except Exception`` handlers keep working.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        # Validate the type before choosing a construction path. Otherwise the
        # distributed path would hand None to the C extension.
        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        # NOTE: an empty NetDef serializes to b"". The serialized bytes must
        # therefore never be truth-tested to decide whether `net` was valid.
        if device_map is not None:
            # Distributed: serialize every device_map value alongside the net.
            serialized_device_map = {
                key: device_map[key].SerializeToString() for key in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # Default: the C extension also returns an op list (kept as _OpList).
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """The module's data-flow graph (``self._NNModule.dataFlow()``)."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """The module's execution order (``self._NNModule.getExecutionOrder()``)."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """The ``nodes`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """The ``operators`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """The ``tensors`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a data-flow node for *val*; returns the binding's result."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete *node* from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create a data-flow edge from *a* to *b*; returns the binding's result."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete a data-flow edge.

        With *b* given, calls the two-argument ``deleteEdge(a, b)``. Otherwise
        calls the one-argument form ``deleteEdge(a)``; presumably *a* is then
        an edge object (confirm against the binding).
        """
        # Test against None explicitly: the node object's truthiness is
        # whatever the binding defines, not "was an argument passed".
        if b is not None:
            self._NNModule.dataFlow().deleteEdge(a, b)
        else:
            self._NNModule.dataFlow().deleteEdge(a)

    def replaceNode(self, old_node, new_node):
        """Replace *old_node* with *new_node* in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def replaceProducer(self, tensor, new_producer):
        """Forward to ``C.replaceProducer(tensor, new_producer)``."""
        C.replaceProducer(tensor, new_producer)

    def replaceAllUsesWith(self, old_tensor, new_tensor):
        """Forward to ``C.replaceAllUsesWith(old_tensor, new_tensor)``."""
        C.replaceAllUsesWith(old_tensor, new_tensor)

    def replaceAsConsumer(self, old_consumer, new_consumer):
        """Forward to ``C.replaceAsConsumer(old_consumer, new_consumer)``."""
        C.replaceAsConsumer(old_consumer, new_consumer)

    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
        """Forward to the wrapped module's ``replaceSubgraph``."""
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)

    def deleteSubgraph(self, subgraph):
        """Forward to the wrapped module's ``deleteSubgraph``."""
        self._NNModule.deleteSubgraph(subgraph)

    def createUniqueDataNode(self, prefix="_unique"):
        """Forward to the wrapped module's ``createUniqueDataNode(prefix)``."""
        return self._NNModule.createUniqueDataNode(prefix)

    def convertToCaffe2Proto(self, old_proto=None):
        """Convert the module back into a new ``caffe2_pb2.NetDef``.

        Args:
            old_proto: optional ``NetDef`` handed to the C conversion. A fresh
                empty ``NetDef`` is used when None.

        Returns:
            A new ``caffe2_pb2.NetDef`` parsed from the bytes the C extension
            returns.
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        output = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(output)
        return new_proto

    def match(self, pattern):
        """Lazily yield every truthy ``C.matchSubgraph(node, pattern)`` result.

        Candidate nodes come from ``self.dataFlow.getMutableNodes()``.
        """
        for n in self.dataFlow.getMutableNodes():
            m = C.matchSubgraph(n, pattern)
            if m:
                yield m
|  |  | 
|  |  | 
def _is_executable_on_path(name):
    """Return True if some ``$PATH`` directory holds an executable *file* ``name``.

    Directories are skipped. ``os.access(path, os.X_OK)`` alone is also true
    for a searchable directory, which ``Popen`` cannot execute.
    """
    for directory in os.getenv("PATH", "").split(os.pathsep):
        candidate = os.path.join(directory, name)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return True
    return False


def render(s):
    """Display ``str(s)`` through the ``graph-easy`` tool, or ``print`` it.

    When an executable ``graph-easy`` is found on ``$PATH``, the text is piped
    to it UTF-8 encoded. The tool writes to the inherited stdout, and this
    function waits for it to exit. Otherwise the text is printed. The same
    fallback applies if the process cannot be started.

    EPIPE / EINVAL while feeding the tool are treated as benign: the tool
    stopped reading before consuming all input. Any other I/O error is re-raised.
    """
    s = str(s)
    if not _is_executable_on_path("graph-easy"):
        print(s)
        return

    try:
        p = Popen(["graph-easy"], stdin=PIPE)
    except OSError:
        # Found on PATH but could not be launched; fall back like the no-tool case.
        print(s)
        return

    try:
        try:
            p.stdin.write(s.encode("utf-8"))
        finally:
            # stdin is buffered, so a small write may only fill the buffer.
            # The broken pipe then surfaces here on flush/close, which must be
            # covered by the same EPIPE/EINVAL handling as the write.
            p.stdin.close()
    except IOError as e:
        if e.errno not in (errno.EPIPE, errno.EINVAL):
            # Raise any other error.
            raise
    finally:
        p.wait()
|  |  | 
|  |  | 
# Re-export the C extension's nomnigraph types at module level.
# `Operator` and `Data` are shorter aliases bound to the same objects.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Graph = C.Graph
Annotation = C.Annotation