| """ |
| The torch.onnx module contains functions to export models into the ONNX |
| IR format. These models can be loaded with the ONNX library and then |
| converted to models which run on other deep learning frameworks. |
| """ |
| |
| import torch |
| import torch.jit |
| import torch.autograd |
| import torch.serialization |
| import re |
| import collections |
| import string |
| import json |
| import math |
| import contextlib |
| from ._utils import _range |
| from torch._six import string_classes |
| |
| |
@contextlib.contextmanager
def set_training(model, mode):
    """
    Context manager that temporarily puts `model` into training mode `mode`.

    On entry, if `mode` differs from `model.training`, `model.train(mode)`
    is called; on exit (normal or exceptional) the previous mode is
    restored with another `model.train` call. When the modes already agree
    no `train` call is made at all. Passing `mode=None` makes the whole
    context a no-op (`model.training` is not even read).
    """
    # Only read the current mode when we might actually need to change it.
    previous = None if mode is None else model.training
    switch = mode is not None and previous != mode

    if switch:
        model.train(mode)
    try:
        yield
    finally:
        if switch:
            model.train(previous)
| |
| |
def export(model, args, f, export_params=True, verbose=False, training=False):
    """
    Export a model into ONNX format. This exporter runs your model
    once in order to get a trace of its execution to be exported; at the
    moment, it does not support dynamic models (e.g., RNNs.)

    See also: :ref:`onnx-export`

    Arguments:
        model (torch.nn.Module): the model to be exported.
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model. Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args. If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported. Give us a shout if you need it.)
        f: a file-like object (has to implement fileno that returns a file descriptor)
            or a string containing a file name. A binary Protobuf will be written
            to this file.
        export_params (bool, default True): if specified, all parameters will
            be exported. Set this to False if you want to export an untrained model.
            In this case, the exported model will first take all of its parameters
            as arguments, the ordering as specified by ``model.state_dict().values()``.
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode. At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.

    Returns:
        The output of the model captured while tracing it (the ``torch_out``
        value produced by ``_export``). Previously this was discarded and
        ``None`` was returned; existing callers that ignore the result are
        unaffected.
    """
    return _export(model, args, f, export_params, verbose, training)
| |
| |
def _export(model, args, f, export_params=True, verbose=False, training=False):
    """
    Trace `model` on `args`, lower the trace to ONNX, and write the
    serialized protobuf to `f`.

    See `export` for the meaning of the arguments.

    Returns:
        `torch_out`, the second value returned by `torch.jit.trace` for
        this run of the model.

    Raises:
        RuntimeError: if the model's state_dict keys (or their order) differ
            before and after running the tracer.
    """
    # Special case for common case of passing a single Variable
    if isinstance(args, torch.autograd.Variable):
        args = (args, )

    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model. Fail fast!
    # Snapshot as a list so the comparison is order-sensitive on every
    # Python version: on Python 3 `keys()` is a view whose `==` ignores
    # ordering, but the exported parameters below are passed positionally
    # in `state_dict().values()` order, so a reordering must be caught too.
    orig_state_dict_keys = list(model.state_dict().keys())

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc. If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        trace, torch_out = torch.jit.trace(model, args)

    if orig_state_dict_keys != list(model.state_dict().keys()):
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    torch._C._jit_pass_onnx(trace)

    if verbose:
        print(trace)

    # TODO: Don't allocate a in-memory string for the protobuf
    if export_params:
        # NB: OrderedDict values is not actually a list, but trace.export is
        # not duck-typed and expects an actual list.
        proto = trace.export(list(model.state_dict().values()))
    else:
        proto = trace.export()

    # `_with_file_like` accepts either a path or a file-like object; the
    # lambda receives the opened handle (renamed to avoid shadowing `f`).
    torch.serialization._with_file_like(f, "wb", lambda fd: fd.write(proto))
    return torch_out
| |
| |
# Attribute keys have the form "<name>_<kind>", where <kind> is a single
# letter from {f, i, s, t, g} (see `_add_attribute`). Because `(.+)` is
# greedy, only the LAST underscore separates the name from the kind.
attr_pattern = re.compile(r"^(.+)_([ifstg])$")
| |
| |
def run_symbolic(op_name, symbolic_fn, args):
    """
    This trampoline function gets invoked for every symbolic call.

    Calls ``symbolic_fn(*args)`` and returns its result. If it raises a
    TypeError (typically because the arguments did not match the symbolic
    function's signature), the exception's message is annotated with
    ``op_name`` and the same exception is re-raised, preserving its
    traceback.
    """
    try:
        return symbolic_fn(*args)
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch
        # to symbolic_fn. Otherwise, the backtrace will have the clues
        # you need.
        context = "(occurred when translating {})".format(op_name)
        if e.args:
            e.args = ("{} {}".format(e.args[0], context), )
        else:
            # A bare `TypeError()` has no message; indexing e.args[0] here
            # would raise IndexError and hide the original error.
            e.args = (context, )
        raise
| |
| |
def _add_attribute(node, key, value):
    """
    Set attribute `key` on `node`, choosing the setter from the key's suffix.

    `key` must match `attr_pattern`, i.e. look like "<name>_<kind>" (e.g.
    "alpha_f"). If `value` is a non-string iterable, an "s" is appended to
    the kind (so "dims_i" with a list uses kind "is"). The call then becomes
    ``getattr(node, kind + '_')(name, value)``, and its result is returned.

    Raises:
        IndexError: if `key` does not carry a valid type suffix.
    """
    # `Iterable` lives in `collections.abc` on Python 3.3+; the old
    # `collections.Iterable` alias was removed in Python 3.10.
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2
        from collections import Iterable

    m = attr_pattern.match(key)
    if m is None:
        raise IndexError((
            "Invalid attribute specifier '{}': names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'").format(key))
    name, kind = m.group(1), m.group(2)
    # Strings are iterable too, but they must stay scalar "s" attributes.
    if not isinstance(value, string_classes) and isinstance(value, Iterable):
        kind += "s"
    return getattr(node, kind + '_')(name, value)
| |
| |
| def _newNode(self, opname, *args, **kwargs): |
| n = self.create(opname, args) |
| for k, v in sorted(kwargs.items()): |
| _add_attribute(n, k, v) |
| return n |
| |
| |
def _op(self, opname, *args, **kwargs):
    """
    Create an ONNX operator 'opname', taking 'args' as inputs
    and attributes 'kwargs', and append it to the current graph.
    Returns the node representing the single output of this
    operator (see the `outputs` keyword argument for multi-return
    nodes).

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This op is monkey-patched to be available on the 'Graph' object
    passed in as the first argument.

    Arguments:
        opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
        args (Node...): The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        kwargs: The attributes of the ONNX operator, with keys named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`. The valid type specifiers are
            `f` (float), `i` (int), `s` (string), `t` (Tensor) and `g`
            (also accepted by the key pattern; presumably a graph attribute,
            so verify against the Node API). Passing a non-string iterable
            as the value selects the list form of that type (e.g., you would
            say `dims_i` for a `dims` attribute that takes a list of integers).
        outputs (int, optional): The number of outputs this operator returns;
            by default an operator is assumed to return a single output.
            For any other value, this function returns a tuple of `outputs`
            select nodes, one per output of the ONNX operator, in
            positional order.
    """
    # `outputs` is a control argument for this helper, not an ONNX attribute,
    # so it must be removed before the remaining kwargs become attributes.
    outputs = kwargs.pop('outputs', 1)
    n = self.appendNode(_newNode(self, opname, *args, **kwargs))
    if outputs == 1:
        return n
    # Multi-output ops: expose output i through its own appended Select node.
    return tuple(self.appendNode(self.createSelect(n, i)) for i in _range(outputs))
| |
| |
| def _at(self, opname, *args, **kwargs): |
| return self.op("ATen", *args, operator_s=opname, **kwargs) |
| |
| |
| def _constant(self, value, dims, type=None, *args, **kwargs): |
| assert(isinstance(value, (int, long, float))) |
| # Infer the type based on value. |
| if type is None: |
| if isinstance(value, int): |
| type = "int" |
| elif isinstance(value, long): |
| type = "long" |
| elif isinstance(value, float): |
| type = "float" |
| |
| if type == "char": |
| tensor = torch.CharTensor(*dims) |
| elif type == "short": |
| tensor = torch.ShortTensor(*dims) |
| elif type == "int": |
| tensor = torch.IntTensor(*dims) |
| elif type == "long": |
| tensor = torch.LongTensor(*dims) |
| elif type == "half": |
| tensor = torch.HalfTensor(*dims) |
| elif type == "float": |
| tensor = torch.FloatTensor(*dims) |
| elif type == "double": |
| tensor = torch.DoubleTensor(*dims) |
| else: |
| raise ValueError("Unknown type, type should be one of the following strings:" |
| "char, short, int, long, half, float, double") |
| tensor.fill_(value) |
| return self.op("Constant", *args, value_t=tensor, **kwargs) |
| |
# Install the symbolic-building helpers as methods on the JIT Graph class so
# symbolic functions can write g.op(...), g.at(...) and g.constant(...).
for _attr_name, _impl in (("op", _op), ("at", _at), ("constant", _constant)):
    setattr(torch._C.Graph, _attr_name, _impl)
# Drop the loop temporaries so the module namespace matches the original.
del _attr_name, _impl