| import io | 
 | import torch.onnx | 
 | import onnx | 
 | from caffe2.python.onnx.backend import Caffe2Backend | 
 | from caffe2.python.core import BlobReference, Net | 
 |  | 
 |  | 
# Module-level counter consumed by PyTorchModule: each call made without an
# explicit prefix_name uses the current value to build the default blob
# prefix 'pytorch_import_<N>/' and then increments it.
_next_idx = 0
# Net.Clone takes its blob remapping as a dict rather than as a callable.
# A callable would be more flexible, so we wrap a function in a minimal
# dict-like object (_FakeDict) and pass that to Clone instead.
 |  | 
 |  | 
 | class _FakeDict(object): | 
 |     def __init__(self, fn): | 
 |         self.fn = fn | 
 |  | 
 |     def get(self, name, _): | 
 |         return self.fn(name) | 
 |  | 
 |  | 
 | def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None): | 
 |     """ | 
 |     Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built. | 
 |  | 
 |     Arguments: | 
 |         helper (caffe2.python.core.ModelHelder): the model helper where | 
 |             this imported network should be inserted | 
 |         model (torch.nn.Module): the model to be exported | 
 |         sample_arguments (tuple of arguments): the inputs to | 
 |             the model, e.g., such that ``model(*args)`` is a valid | 
 |             invocation of the model.  Any non-Variable arguments will | 
 |             be hard-coded into the exported model; any Variable arguments | 
 |             will become inputs of the exported model, in the order they | 
 |             occur in args.  If args is a Variable, this is equivalent | 
 |             to having called it with a 1-ary tuple of that Variable. | 
 |             (Note: passing keyword arguments to the model is not currently | 
 |             supported.  Give us a shout if you need it.) | 
 |         caffe2_inputs (list of str or caffe2.python.core.BlobReference): the | 
 |            caffe2 Blobs that should be inputs to this network. Must be | 
 |            the same length as sample_arguments | 
 |         prefix_name: prefix name to add to each member of the blob, if None then | 
 |            a fresh prefix pytorch_input_N/ is used | 
 |     Returns: | 
 |         A tuple of caffe2.python.core.BlobReference objects referring to the | 
 |         models outputs, or a single BlobReference when the model returns a single | 
 |         value. | 
 |     """ | 
 |     if prefix_name is None: | 
 |         global _next_idx | 
 |         prefix_name = 'pytorch_import_' + str(_next_idx) + '/' | 
 |         _next_idx += 1 | 
 |  | 
 |     # TODO: handle the case where model cannot be exported | 
 |     # and embed as a Python op in Caffe2 | 
 |     f = io.BytesIO() | 
 |     torch.onnx.export( | 
 |         model, sample_arguments, f, export_params=True) | 
 |     onnx_model = onnx.load(io.BytesIO(f.getvalue())) | 
 |     init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net( | 
 |         onnx_model) | 
 |  | 
 |     initialized = set([x.name for x in onnx_model.graph.initializer]) | 
 |     uninitialized_inputs = {x.name: i for i, x in enumerate( | 
 |         onnx_model.graph.input) if x.name not in initialized} | 
 |  | 
 |     if(len(uninitialized_inputs) != len(caffe2_inputs)): | 
 |         raise ValueError('Expected {} inputs but found {}'.format( | 
 |             len(uninitialized_inputs), len(caffe2_inputs))) | 
 |  | 
 |     def remap_blob_name(name): | 
 |         if name in uninitialized_inputs: | 
 |             idx = uninitialized_inputs[name] | 
 |             return str(caffe2_inputs[idx]) | 
 |         return prefix_name + name | 
 |  | 
 |     predict_net = Net(predict_net).Clone('anon', _FakeDict(remap_blob_name)) | 
 |     helper.net.AppendNet(predict_net) | 
 |  | 
 |     init_net = Net(init_net).Clone('anon', _FakeDict(remap_blob_name)) | 
 |     helper.param_init_net.AppendNet(init_net) | 
 |  | 
 |     results = tuple([BlobReference(remap_blob_name(x.name), helper.net) | 
 |                      for x in onnx_model.graph.output]) | 
 |     return results |