| import torch |
| from torch.autograd._functions.utils import check_onnx_broadcast # TODO: move me |
| from torch.nn.modules.utils import _single, _pair, _triple |
| from torch.nn.utils.rnn import PackedSequence |
| import warnings |
| |
| import torch.onnx |
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # |
| # - Parameter ordering does NOT necessarily match what is in VariableType.cpp; |
| # tensors are always first, then non-tensor arguments. |
| # - Parameter names must *exactly* match the names in VariableType.cpp, because |
| # dispatch is done with keyword arguments. |
| # - Looking for inplace ops? They're detected by the trailing underscore, and |
| # transparently dispatched to their non inplace versions in |
| # 'run_symbolic_function'. See Note [Export inplace] |
| |
| # --------------------------------------------------------------------- |
| # Helper functions |
| # --------------------------------------------------------------------- |
| |
| |
def _scalar(x):
    """Return the single element held by the one-element tensor ``x``.

    Symbolic functions receive PyTorch scalars wrapped in tensors with
    exactly one element; this unwraps that element (index 0).
    """
    element_count = x.numel()
    assert element_count == 1
    return x[0]
| |
| |
def _if_scalar_type_as(self, tensor):
    """
    Cast ``self`` to the scalar type of ``tensor`` when ``self`` is a scalar.

    Graph values (``torch._C.Value``) pass through untouched. Anything else is
    treated as a scalar tensor and converted by calling the method named after
    ``tensor``'s lower-cased scalar type (e.g. ``"Float"`` -> ``self.float()``).
    Only scalars are ever implicitly cast, so no ONNX Cast node is emitted.
    """
    if not isinstance(self, torch._C.Value):
        converter_name = tensor.type().scalarType().lower()
        convert = getattr(self, converter_name)
        return convert()
    return self
| |
| |
def _broadcast_if_scalar(x):
    """Return extra ``g.op`` kwargs that turn on ONNX broadcasting for scalars.

    Graph values need no extra kwargs. Anything else is a scalar operand and
    gets ``broadcast_i=1``.
    """
    is_graph_value = isinstance(x, torch._C.Value)
    return {} if is_graph_value else {"broadcast_i": 1}
| |
| |
def _unimplemented(op, msg):
    """Warn that exporting ``op`` failed because ``msg`` is not supported.

    Returns None; symbolic functions return this value to signal that no
    ONNX translation was produced.
    """
    pieces = ["ONNX export failed on ", op, " because ", msg, " not supported"]
    warnings.warn("".join(pieces))
| |
| |
| # --------------------------------------------------------------------- |
| # ONNX operator version |
| # --------------------------------------------------------------------- |
| |
| # READ ME BEFORE EDITING _onnx_opset_version: |
| # |
| # The variable below controls which ONNX operator set version we are |
| # targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking |
| # change occurred in version 8. As long as this variable < 8, you can |
| # export models targeting the old behavior. However, if you bump |
| # this variable to 8 or later, the breaking change will take into effect: |
| # you MUST adjust any symbolic affected by breaking changes. The ONNX |
| # spec publishes a *comprehensive* list of BC-breaking changes for every |
| # operator revision at: |
| # |
| # https://github.com/onnx/onnx/blob/master/docs/Changelog.md |
| # |
| # Please be sure to go through and check all of our implementations here before |
| # increasing this number. This includes symbolic definitions NOT in this |
| # file, so grep for "OpName" (with quotes) |
| |
# ONNX operator-set version targeted by the symbolics in this file (and, per
# the note above, by symbolic definitions elsewhere). Changing it has semantic
# effect: audit every affected symbolic before bumping it.
_onnx_opset_version = 2
| |
| |
| # --------------------------------------------------------------------- |
| # Symbolic definitions |
| # --------------------------------------------------------------------- |
| |
| |
| # Note [Pointwise by scalar] |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # What happens if you add a tensor with a constant (e.g., x + 2)? There are |
| # some moving parts to implementing the ONNX translation in this case: |
| # |
| # - By the time we get the scalar in a symbolic function here, it is no longer |
| # a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we |
| # want it to be a zero dim tensor but this change has not happened yet.) |
| # However, the type of this scalar is *exactly* what the user wrote in |
| # Python, which may not match the tensor it is being added to. PyTorch |
| # will do implicit conversions on scalars; however, ONNX will not, so |
| # we must do the conversion ourselves. This is what _if_scalar_type_as |
| # does. |
| # |
| # - Most of the time, the arguments to self/other are pre-expanded according |
| # to broadcasting. However, a scalar will NOT be broadcasted, so we have |
| # to enable broadcasting ONNX side. |
| # |
# - Dispatch to these functions takes advantage of an outrageous coincidence
| # between the tensor and scalar name. When we add two tensors together, |
| # you get the dispatch: |
| # |
| # add(*[self, other], **{"alpha": alpha}) |
| # |
| # When you add a tensor and a scalar, you get the dispatch: |
| # |
| # add(*[self], **{"other": other, "alpha": alpha}) |
| # |
| # By having the argument name line up with the name of the scalar attribute |
| # if it exists, we can write a single function for both overloads. |
| # |
| |
| # used to represent "missing" optional inputs |
def unused(g):
    """Emit an ``Undefined`` node that stands in for an omitted optional input."""
    placeholder = g.op("Undefined")
    return placeholder
| |
| |
def add(g, self, other, alpha):
    """Export ``self + alpha * other`` as ONNX ``Add``; only ``alpha == 1`` is supported.

    ``other`` may be a graph value or a scalar; see Note [Pointwise by scalar].
    """
    if _scalar(alpha) != 1:
        return _unimplemented("add", "alpha != 1")
    rhs = _if_scalar_type_as(other, self)
    broadcast = _broadcast_if_scalar(other)
    return g.op("Add", self, rhs, **broadcast)


def sub(g, self, other, alpha):
    """Export ``self - alpha * other`` as ONNX ``Sub``; only ``alpha == 1`` is supported.

    ``other`` may be a graph value or a scalar; see Note [Pointwise by scalar].
    """
    if _scalar(alpha) != 1:
        return _unimplemented("sub", "alpha != 1")
    rhs = _if_scalar_type_as(other, self)
    broadcast = _broadcast_if_scalar(other)
    return g.op("Sub", self, rhs, **broadcast)


def mul(g, self, other):
    """Export ``self * other`` as ONNX ``Mul`` (see Note [Pointwise by scalar])."""
    rhs = _if_scalar_type_as(other, self)
    broadcast = _broadcast_if_scalar(other)
    return g.op("Mul", self, rhs, **broadcast)


def div(g, self, other):
    """Export ``self / other`` as ONNX ``Div`` (see Note [Pointwise by scalar])."""
    rhs = _if_scalar_type_as(other, self)
    broadcast = _broadcast_if_scalar(other)
    return g.op("Div", self, rhs, **broadcast)
| |
| |
| # This syntax is Python 2 portable |
def cat(g, *tensors, **kwargs):
    """Export ``torch.cat`` as ONNX ``Concat`` along ``dim``.

    ``dim`` can only arrive as a keyword (``*tensors`` swallows every
    positional argument), hence the manual ``kwargs`` handling.

    Raises:
        KeyError: if ``dim`` is not supplied.
        TypeError: if any keyword other than ``dim`` is supplied.
    """
    dim = kwargs.pop("dim")
    # Explicit check rather than ``assert``: asserts are stripped under -O and
    # unexpected arguments would then be silently ignored.
    if kwargs:
        raise TypeError("cat() got unexpected keyword arguments: %s"
                        % ", ".join(sorted(kwargs)))
    return g.op("Concat", *tensors, axis_i=dim)
| |
| |
def mm(g, self, other):
    """Export ``torch.mm`` as ONNX ``Gemm`` computing ``1.0 * (self @ other) + 0.0 * C``.

    ``Gemm`` requires a third input ``C``. Because ``beta`` is 0, its value is
    irrelevant, so a one-element zero constant of ``self``'s scalar type is
    passed purely to satisfy the operator's signature.
    """
    scalar_type = self.type().scalarType().lower()
    dummy_c = g.constant(0, [1], scalar_type)
    return g.op("Gemm", self, other, dummy_c,
                alpha_f=1.0, beta_f=0.0, broadcast_i=True)
| |
| |
def bmm(g, self, other):
    """Export batched matrix multiply ``torch.bmm`` as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product


def matmul(g, self, other):
    """Export ``torch.matmul`` as ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
| |
| |
def addmm(g, self, mat1, mat2, beta, alpha):
    """Export ``beta * self + alpha * (mat1 @ mat2)`` as ONNX ``Gemm``.

    ``self`` is passed as ``Gemm``'s third input ``C``; ``beta`` and ``alpha``
    are unwrapped from their one-element tensors.
    """
    beta_value = _scalar(beta)
    alpha_value = _scalar(alpha)
    return g.op("Gemm", mat1, mat2, self, beta_f=beta_value, alpha_f=alpha_value)
| |
| |
def neg(g, self):
    """Export element-wise negation as ONNX ``Neg``."""
    node = g.op("Neg", self)
    return node


def sqrt(g, self):
    """Export element-wise square root as ONNX ``Sqrt``."""
    node = g.op("Sqrt", self)
    return node


def tanh(g, self):
    """Export element-wise hyperbolic tangent as ONNX ``Tanh``."""
    node = g.op("Tanh", self)
    return node


def sigmoid(g, self):
    """Export element-wise logistic sigmoid as ONNX ``Sigmoid``."""
    node = g.op("Sigmoid", self)
    return node
| |
| |
def _reduce_symbolic(g, onnx_op, self, dim, keepdim):
    """Emit reduction ``onnx_op`` over ``dim``, or over every axis when ``dim`` is None.

    ONNX defaults ``keepdims`` to 1 while PyTorch defaults ``keepdim`` to
    False, so an omitted ``keepdim`` is exported as 0.
    """
    if keepdim is None:
        keepdim = 0
    if dim is None:
        # Omitting ``axes`` makes an ONNX Reduce* op reduce over all dimensions.
        return g.op(onnx_op, self, keepdims_i=keepdim)
    return g.op(onnx_op, self, axes_i=[dim], keepdims_i=keepdim)


def mean(g, self, dim=None, keepdim=None):
    """Export ``torch.mean`` as ONNX ``ReduceMean``.

    NB: ONNX ``Mean`` is an element-wise mean *across several inputs*. With a
    single input it is the identity, so it cannot express a full reduction.
    """
    return _reduce_symbolic(g, "ReduceMean", self, dim, keepdim)


def sum(g, self, dim=None, keepdim=None):
    """Export ``torch.sum`` as ONNX ``ReduceSum`` (ONNX ``Sum`` is element-wise across inputs)."""
    return _reduce_symbolic(g, "ReduceSum", self, dim, keepdim)


def prod(g, self, dim=None, keepdim=None):
    """Export ``torch.prod`` as ONNX ``ReduceProd``."""
    return _reduce_symbolic(g, "ReduceProd", self, dim, keepdim)
| |
| |
def t(g, self):
    """Export a 2-D matrix transpose as ONNX ``Transpose`` with permutation (1, 0)."""
    swapped_axes = (1, 0)
    return g.op("Transpose", self, perm_i=swapped_axes)


def expand(g, self, size):
    """Export ``expand(size)`` as an ``Expand`` node carrying the target shape.

    TODO: This is not a real ONNX operator at the moment.
    """
    target_shape = size
    return g.op("Expand", self, shape_i=target_shape)


def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    """Export an embedding lookup as ONNX ``Gather`` of ``weight`` rows at ``indices``.

    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse`` are not used by
    this export.
    """
    lookup = g.op("Gather", weight, indices)
    return lookup
| |
| |
def transpose(g, self, dim0, dim1):
    """Export ``transpose(dim0, dim1)`` as ONNX ``Transpose``.

    ONNX ``Transpose`` is really a general permute. This builds the identity
    permutation for ``self``'s rank and swaps the two axes. Swapping an axis
    with itself is a no-op and returns ``self`` directly.
    """
    if dim0 != dim1:
        rank = len(self.type().sizes())
        perm = list(range(rank))
        perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
        return g.op("Transpose", self, perm_i=perm)
    return self


def permute(g, self, dims):
    """Export ``permute(dims)`` as ONNX ``Transpose``; an identity list returns ``self``."""
    identity = list(range(len(dims)))
    if dims != identity:
        return g.op("Transpose", self, perm_i=dims)
    return self
| |
| |
def view(g, self, size):
    """Export ``view(size)`` as ONNX ``Flatten`` or ``Reshape``.

    A 2-D target shape that keeps ``self``'s leading dimension collapses all
    remaining dimensions into one, which is exactly ``Flatten`` at axis 1.
    That form is preferred. Any other target shape goes to ``Reshape``
    verbatim.
    """
    # Check the target rank before indexing. An empty ``size`` (a view to a
    # scalar) used to make ``size[0]`` raise IndexError, and a 0-d input
    # would make ``sizes()[0]`` raise.
    if len(size) == 2:
        input_sizes = self.type().sizes()
        if input_sizes and input_sizes[0] == size[0]:
            return g.op("Flatten", self, axis_i=1)
    return g.op("Reshape", self, shape_i=size)
| |
| |
def split(g, self, split_size, dim):
    """Export ``split(split_size, dim)`` as a multi-output ONNX ``Split``.

    Emits as many full ``split_size`` pieces as fit along ``dim`` (read from
    ``self``'s static sizes), plus one trailing piece for any remainder.
    """
    dim_size = self.type().sizes()[dim]
    full_pieces, remainder = divmod(dim_size, split_size)
    splits = [split_size for _ in range(full_pieces)]
    if remainder:
        splits.append(remainder)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
| |
| |
| # TODO: It would be better to export this as a chunk directly, as this is |
| # less sensitive to changes in input size. |
| # TODO: Once we have proper scoping, stop reimplementing chunk, delete this |
| # method, and use the desugared version |
def chunk(g, self, chunks, dim):
    """Export ``chunk(chunks, dim)`` by rewriting it as an equivalent ``split``.

    Each piece is ``ceil(size / chunks)`` long, where ``size`` is ``self``'s
    static extent along ``dim``.
    """
    dim_size = self.type().sizes()[dim]
    piece_size = (dim_size + chunks - 1) // chunks  # ceiling division
    return split(g, self, piece_size, dim)
| |
| |
def select(g, self, dim, index):
    """Export ``select(dim, index)`` as ONNX ``Slice`` followed by ``Squeeze``.

    A negative ``dim`` or ``index`` is first normalized against ``self``'s
    static sizes. Previously ``index == -1`` produced ``ends == [0]``, which
    selects an empty range.
    """
    if dim < 0 or index < 0:
        sizes = self.type().sizes()
        if dim < 0:
            dim += len(sizes)
        if index < 0:
            index += sizes[dim]
    slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index], ends_i=[index + 1])
    return g.op("Squeeze", slice_node, axes_i=[dim])
| |
| |
def squeeze(g, self, dim=None):
    """Export ``squeeze`` as ONNX ``Squeeze``.

    With ``dim`` given, only that axis is listed. Otherwise every axis whose
    static size is 1 is listed explicitly.
    """
    if dim is not None:
        axes = [dim]
    else:
        axes = [axis for axis, extent in enumerate(self.type().sizes()) if extent == 1]
    return g.op("Squeeze", self, axes_i=axes)
| |
| |
def prelu(g, self, weight):
    """Export parametric ReLU as ONNX ``PRelu`` with slope tensor ``weight``."""
    activated = g.op("PRelu", self, weight)
    return activated


def threshold(g, self, threshold, value):
    """Export ``threshold`` as ONNX ``Relu``; only threshold == value == 0 is supported."""
    # See Note [Export inplace]
    problem = None
    if _scalar(threshold) != 0:
        problem = "non-zero threshold"
    elif _scalar(value) != 0:
        problem = "non-zero value"
    if problem is not None:
        return _unimplemented("threshold", problem)
    return g.op("Relu", self)


def leaky_relu(g, input, negative_slope, inplace=False):
    """Export leaky ReLU as ONNX ``LeakyRelu`` with ``alpha`` = ``negative_slope``."""
    # See Note [Export inplace]
    # TODO: Talk to ONNX about unconditional cast of scalar to float
    slope = _scalar(negative_slope)
    return g.op("LeakyRelu", input, alpha_f=slope)
| |
| |
def glu(g, input, dim):
    """Export the gated linear unit ``a * sigmoid(b)``.

    ``a`` and ``b`` are the two halves of ``input`` along ``dim``, so that
    dimension must have an even static size.
    """
    assert input.type().sizes()[dim] % 2 == 0

    value_half, gate_half = g.op('Split', input, axis_i=dim, outputs=2)
    gate = g.op('Sigmoid', gate_half)
    return g.op('Mul', value_half, gate)
| |
| |
def softmax(g, input, dim=None):
    """Export ``softmax`` along ``dim`` as ONNX ``Softmax``.

    Only a softmax over the last dimension can be expressed (see below).
    A negative ``dim`` is normalized first, so the common ``dim=-1`` works.
    Any other ``dim`` is reported through ``_unimplemented``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    rank = len(input.type().sizes())
    # PyTorch accepts negative dims counting from the end; without this,
    # dim=-1 (the last axis, i.e. the supported case) was rejected.
    if dim < 0:
        dim += rank
    if rank != dim + 1:
        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
    return g.op('Softmax', input, axis_i=dim)
| |
| |
def softplus(g, self, beta, threshold):
    """Export ``softplus`` as ONNX ``Softplus``; only ``beta == 1`` is supported.

    Scalars reach symbolic functions as one-element tensors (see
    Note [Pointwise by scalar]), so ``beta`` is unwrapped with ``_scalar``
    before comparing, as ``threshold``/``elu`` do. Comparing the raw tensor
    yields a tensor whose truth value is not a plain bool. ``threshold`` has
    no ONNX ``Softplus`` counterpart and is not exported.
    """
    if _scalar(beta) != 1:
        return _unimplemented("softplus", "beta != 1")
    return g.op('Softplus', self)
| |
| |
def max_pool1d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """Export 1-D max pooling as ONNX ``MaxPool``.

    ``ceil_mode`` and any dilation other than 1 are unsupported. A missing or
    empty ``stride`` defaults to ``kernel_size``. Symmetric padding is
    repeated because ONNX ``pads`` lists begin and end values.

    Returns:
        ``(output, None)``. The second slot is never produced here
        (presumably the pooling indices — confirm with the dispatcher).
    """
    if ceil_mode:
        return _unimplemented("max_pool1d", "ceil_mode")
    if set(_single(dilation)) != {1}:
        return _unimplemented("max_pool1d", "dilation")
    # Same falsiness test as max_pool2d/avg_pool*: an empty stride list must
    # also fall back to kernel_size, not only None.
    if not stride:
        stride = kernel_size
    r = g.op("MaxPool", input,
             kernel_shape_i=_single(kernel_size),
             pads_i=_single(padding) * 2,
             strides_i=_single(stride))
    return r, None
| |
| |
def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """Export 2-D max pooling as ONNX ``MaxPool``.

    ``ceil_mode`` and any dilation other than 1 are unsupported. A falsy
    ``stride`` falls back to ``kernel_size``. Returns ``(output, None)``.
    """
    if ceil_mode:
        return _unimplemented("max_pool2d", "ceil_mode")
    dilations = set(_pair(dilation))
    if dilations != {1}:
        return _unimplemented("max_pool2d", "dilation")
    effective_stride = stride if stride else kernel_size
    pooled = g.op("MaxPool", input,
                  kernel_shape_i=_pair(kernel_size),
                  pads_i=_pair(padding) * 2,
                  strides_i=_pair(effective_stride))
    return pooled, None
| |
| |
def avg_pool2d(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
    """Export 2-D average pooling as ONNX ``AveragePool``.

    ``ceil_mode`` is unsupported. A falsy ``stride`` falls back to
    ``kernel_size``. ``count_include_pad`` is currently ignored.
    """
    if ceil_mode:
        return _unimplemented("avg_pool2d", "ceil_mode")
    # TODO: What about count_include_pad?!
    strides = _pair(stride if stride else kernel_size)
    return g.op("AveragePool", input,
                kernel_shape_i=_pair(kernel_size),
                strides_i=strides,
                pads_i=_pair(padding) * 2)
| |
| |
def avg_pool3d(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
    """Export 3-D average pooling as ONNX ``AveragePool``.

    ``ceil_mode`` is unsupported. A falsy ``stride`` falls back to
    ``kernel_size``. ``count_include_pad`` is currently ignored.

    ONNX ``pads`` holds a begin *and* an end value for every spatial axis
    (six entries in 3-D). PyTorch's symmetric padding triple is therefore
    repeated, exactly as ``avg_pool2d`` repeats its pair. Previously only three
    entries were emitted.
    """
    if ceil_mode:
        return _unimplemented("avg_pool3d", "ceil_mode")
    if not stride:
        stride = kernel_size
    # TODO: What about count_include_pad?!
    return g.op("AveragePool", input,
                kernel_shape_i=_triple(kernel_size),
                strides_i=_triple(stride),
                pads_i=_triple(padding) * 2)
| |
| |
def reflection_pad(g, input, padding):
    """Export reflection padding as ONNX ``Pad`` with mode ``"reflect"``.

    ``padding`` is converted to ONNX's layout by ``prepare_onnx_paddings``
    using ``input``'s rank.
    """
    from torch.autograd._functions.utils import prepare_onnx_paddings
    rank = len(input.type().sizes())
    onnx_paddings = prepare_onnx_paddings(rank, padding)
    return g.op("Pad", input, pads_i=onnx_paddings, mode_s="reflect")


def replication_pad(g, input, padding):
    """Export replication padding as ONNX ``Pad`` with mode ``"edge"``.

    ``padding`` is converted to ONNX's layout by ``prepare_onnx_paddings``
    using ``input``'s rank.
    """
    from torch.autograd._functions.utils import prepare_onnx_paddings
    rank = len(input.type().sizes())
    onnx_paddings = prepare_onnx_paddings(rank, padding)
    return g.op("Pad", input, pads_i=onnx_paddings, mode_s="edge")


# The 1-D, 2-D and 3-D variants all share the rank-agnostic implementations.
reflection_pad1d = reflection_pad2d = reflection_pad3d = reflection_pad
replication_pad1d = replication_pad2d = replication_pad3d = replication_pad
| |
| |
def upsample_nearest2d(g, input, scale_factor):
    """Export nearest-neighbour 2-D upsampling as ONNX ``Upsample``.

    The same ``scale_factor`` is applied to both height and width.
    """
    attrs = dict(height_scale_f=scale_factor,
                 width_scale_f=scale_factor,
                 mode_s="nearest")
    return g.op("Upsample", input, **attrs)
| |
| |
def log_softmax(g, input, dim=None):
    """Export ``log_softmax`` along ``dim`` as ONNX ``LogSoftmax``.

    ONNX ``LogSoftmax`` coerces its input to 2-D around ``axis``, exactly like
    ``Softmax``. It therefore matches PyTorch's per-``dim`` semantics only
    when ``dim`` is the last dimension. A negative ``dim`` is normalized first.
    Any other ``dim`` is reported via ``_unimplemented`` instead of silently
    exporting a different computation. ``dim=None`` keeps its previous
    behaviour (the ``axis`` attribute is left to the exporter).
    """
    if dim is None:
        return g.op("LogSoftmax", input, axis_i=dim)
    rank = len(input.type().sizes())
    if dim < 0:
        dim += rank
    if rank != dim + 1:
        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
    return g.op("LogSoftmax", input, axis_i=dim)
| |
| |
def _convolution(g, input, weight, bias, stride, padding, dilation,
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
    """Export a (possibly transposed) convolution as ONNX ``Conv``/``ConvTranspose``.

    A bias whose node kind is ``Undefined`` means "no bias". A 1-D bias is
    passed as the operator's third input. A bias of any other rank is added
    afterwards with a broadcasting ``Add`` over axis 1. ``benchmark``,
    ``deterministic`` and ``cudnn_enabled`` are not used.
    """
    kernel_shape = weight.type().sizes()[2:]

    has_bias = bias.node().kind() != "Undefined"
    # ONNX only supports 1D bias as a Conv input.
    bias_is_1d = has_bias and len(bias.type().sizes()) == 1

    operands = [input, weight]
    if bias_is_1d:
        operands.append(bias)

    attrs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        # NB: ONNX pads list begin and end values separately, whereas PyTorch
        # padding is symmetric, so the list is repeated.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    needs_output_padding = any(extra != 0 for extra in output_padding)
    if needs_output_padding:
        # ONNX accepts either output_shape or output_padding (equally
        # expressive); output_padding maps directly, so it is used here.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        attrs["output_padding_i"] = output_padding

    op_name = "ConvTranspose" if transposed else "Conv"
    conv = g.op(op_name, *operands, **attrs)

    if has_bias and not bias_is_1d:
        return g.op("Add", conv, bias, broadcast_i=1, axis_i=1)
    return conv
| |
| |
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
    """Export batch normalization as ONNX ``BatchNormalization``.

    In inference mode the node's single output is returned. In training mode
    the node has five outputs. The updated running mean/var get the types of
    the incoming running stats. The saved mean/var are renamed with a
    ``batch_norm_dead_output-`` prefix. Only the normalized result is
    returned. ``momentum`` is exported as ``1 - momentum`` because PyTorch
    weights the new batch statistic while ONNX weights the running one.
    """
    output_count = 5 if training else 1
    out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
               is_test_i=not training,
               epsilon_f=eps,
               momentum_f=1 - momentum,
               consumed_inputs_i=(0, 0, 0, 1, 1),
               outputs=output_count)
    if training:
        res, new_running_mean, new_running_var, saved_mean, saved_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        for dead_output in (saved_mean, saved_var):
            dead_output.setUniqueName("batch_norm_dead_output-" + dead_output.uniqueName())
        return res
    return out
| |
| |
def unfold(g, input, dimension, size, step):
    """Export ``unfold`` as an ``ATen`` fallback node carrying its arguments as attributes."""
    attrs = dict(operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
    return g.op("ATen", input, **attrs)


def elu(g, input, alpha, inplace=False):
    """Export ELU as ONNX ``Elu``; ``alpha`` is unwrapped from its scalar tensor."""
    # See Note [Export inplace]
    alpha_value = _scalar(alpha)
    return g.op("Elu", input, alpha_f=alpha_value)


def selu(g, input):
    """Export SELU as ONNX ``Selu`` with ONNX's default constants."""
    activated = g.op("Selu", input)
    return activated


def index_select(g, self, index, dim):
    """Export ``index_select(dim, index)`` as ONNX ``Gather`` along ``dim``."""
    gathered = g.op("Gather", self, index, axis_i=dim)
    return gathered
| |
| |
def type_as(g, self, other):
    """Export ``type_as``; only the no-op case (matching scalar types) is supported."""
    if self.type().scalarType() != other.type().scalarType():
        # TODO: This should be pretty easy, just implement it with Cast
        return _unimplemented("type_as", "non no-op application")
    # Same scalar type already: nothing to emit.
    return self


def clone(g, input):
    """Drop ``clone``: clone operators inserted by PyTorch autograd export as the identity."""
    passthrough = input
    return passthrough


def abs(g, self):
    """Export element-wise absolute value as ONNX ``Abs``."""
    magnitude = g.op("Abs", self)
    return magnitude
| |
| |
def pow(g, self, exponent):
    """Export ``self ** exponent`` as ONNX ``Pow``.

    ``exponent`` may be a graph value or a scalar (e.g. ``x ** 2``). A scalar
    is treated like any other scalar operand (see Note [Pointwise by scalar]):
    it is cast to ``self``'s scalar type and broadcasting is enabled, since
    ONNX does neither implicitly. Graph values pass through unchanged.
    """
    return g.op("Pow", self, _if_scalar_type_as(exponent, self), **_broadcast_if_scalar(exponent))
| |
| |
def clamp(g, self, min, max):
    """Export ``clamp`` as ONNX ``Clip`` bounded by ``min`` and ``max``."""
    bounds = {"min_f": min, "max_f": max}
    return g.op("Clip", self, **bounds)


def max(g, self, other):
    """Export the element-wise maximum of two tensors as ONNX ``Max``."""
    operands = (self, other)
    return g.op("Max", *operands)


def min(g, self, other):
    """Export the element-wise minimum of two tensors as ONNX ``Min``."""
    operands = (self, other)
    return g.op("Min", *operands)


def eq(g, self, other):
    """Export element-wise equality as ONNX ``Equal``."""
    operands = (self, other)
    return g.op("Equal", *operands)


def exp(g, self):
    """Export element-wise exponential as ONNX ``Exp``."""
    result = g.op("Exp", self)
    return result


def conv_tbc(g, input, weight, bias, pad):
    """Export ``conv_tbc`` as an ``ATen`` fallback node with ``pad`` as an attribute."""
    attrs = dict(operator_s="conv_tbc", pad_i=pad)
    return g.op("ATen", input, weight, bias, **attrs)
| |
| |
def slice(g, self, dim, start, end, step):
    """Export ``self[start:end]`` along ``dim`` as ONNX ``Slice``.

    ONNX ``Slice`` at this opset has no step attribute, so ``step != 1`` is
    reported as unsupported and no node is produced. Previously the result of
    ``_unimplemented`` was discarded and a Slice that silently ignored the
    step was exported.
    """
    if step != 1:
        return _unimplemented("slice", "step != 1")
    return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
| |
| |
def instance_norm(g, input, **kwargs):
    """Export instance normalization as ONNX ``InstanceNormalization``.

    Optional keyword arguments:
        weight: per-channel scale tensor, embedded as a ``Constant``; defaults
            to ones of length ``input.type().sizes()[1]``.
        bias: per-channel shift tensor, embedded as a ``Constant``; defaults
            to zeros of the same length.
        eps: epsilon, defaults to 1e-5.
    """
    input_type = input.type().scalarType()
    weight = kwargs.get("weight", None)
    bias = kwargs.get("bias", None)
    eps = kwargs.get("eps", 1e-5)
    # Test for None explicitly. Truth-testing a multi-element tensor is
    # ambiguous (raises), and a supplied zero-valued one-element tensor must
    # not be silently replaced by the default.
    if weight is None:
        weight = g.constant(1.0, [input.type().sizes()[1]], input_type)
    else:
        weight = g.op('Constant', value_t=weight)
    if bias is None:
        bias = g.constant(0.0, [input.type().sizes()[1]], input_type)
    else:
        bias = g.op('Constant', value_t=bias)
    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
| |
| |
def RNN_symbolic_builder(cell_type, *args, **kwargs):
    """Dispatch to the symbolic builder for ``cell_type``.

    ``'LSTM'`` and ``'GRU'`` map to their builders. ``'RNN_<nonlinearity>'``
    maps to the Elman RNN builder with the suffix as its nonlinearity. Any
    other cell type yields a function that reports it via ``_unimplemented``.
    """
    if cell_type == 'LSTM':
        return LSTM_symbolic_builder(*args, **kwargs)
    if cell_type == 'GRU':
        return GRU_symbolic_builder(*args, **kwargs)
    if cell_type.startswith('RNN_'):
        nonlinearity = cell_type[len('RNN_'):]
        return Elman_RNN_symbolic_builder(nonlinearity, *args, **kwargs)

    def unsupported(*_ignored_args, **_ignored_kwargs):
        return _unimplemented("RNN", "cell type " + cell_type)
    return unsupported
| |
| |
def reform_weights(g, w, n, intervals):
    """Reorder gate blocks of ``w`` along axis 0.

    ``w`` is viewed as consecutive blocks of ``n`` rows. Each ``(x, y)`` in
    ``intervals`` selects rows ``[x * n, y * n)`` with an ONNX ``Slice``, and
    the selections are concatenated in ``intervals`` order.
    """
    pieces = []
    for begin, end in intervals:
        pieces.append(g.op('Slice', w, axes_i=[0], starts_i=[begin * n], ends_i=[end * n]))
    return g.op('Concat', *pieces, axis_i=0)
| |
| |
def Elman_RNN_symbolic_builder(
        nonlinearity, input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
    """Build a symbolic function that exports a multi-layer Elman RNN as stacked ONNX ``RNN`` nodes.

    Args:
        nonlinearity: activation name; lower-cased and passed as the ONNX
            ``activations`` attribute.
        input_size: not used by the returned symbolic.
        hidden_size: exported as the ``hidden_size`` attribute.
        num_layers: one ONNX ``RNN`` node is emitted per layer.
        batch_first: unsupported; reported via ``_unimplemented``.
        dropout: unsupported when ``kwargs['train']`` is truthy.
        bidirectional: adds ``direction="bidirectional"`` and concatenates the
            forward/backward weights of each layer.
        **kwargs: must contain ``'train'`` whenever ``dropout`` is truthy
            (it is indexed directly).

    Returns:
        ``symbolic(g, input, all_weights, h0, batch_sizes)``, which returns
        ``(output, h_n)``.
    """
    def symbolic(g, input, all_weights, h0, batch_sizes):
        if batch_first:
            return _unimplemented("RNN", "batch_first")
        if dropout and kwargs['train']:
            return _unimplemented("RNN", "dropout in training mode")

        unidirectional = not bidirectional

        prev_output = input
        h_outs = []

        # Without batch_sizes, an Undefined node stands in for the optional
        # ONNX sequence_lens input.
        sequence_lens = unused(g) if batch_sizes is None else batch_sizes

        for i in range(num_layers):
            if unidirectional:
                # NOTE(review): all_weights[i] is unpacked as
                # (weight_ih, weight_hh, bias_ih, bias_hh) -- presumably
                # PyTorch's per-layer parameter order; confirm.
                weight_ih, weight_hh, bias_ih, bias_hh = all_weights[i]
                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)

                # A single layer uses h0 whole; otherwise take this layer's slice.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
            else:
                # Forward weights are all_weights[2 * i], backward are
                # all_weights[2 * i + 1]; each kind is stacked along axis 0.
                weight_ih = g.op('Concat', all_weights[2 * i][0], all_weights[2 * i + 1][0], axis_i=0)
                weight_hh = g.op('Concat', all_weights[2 * i][1], all_weights[2 * i + 1][1], axis_i=0)
                bias_concat = g.op('Concat',
                                   all_weights[2 * i][2],
                                   all_weights[2 * i][3],
                                   all_weights[2 * i + 1][2],
                                   all_weights[2 * i + 1][3],
                                   axis_i=0)

                # Two h0 entries (forward, backward) per layer.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])

            # Positional order expected by ONNX RNN: X, W, R, B, sequence_lens, initial_h.
            inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
            # NOTE(review): the ONNX RNN ``Y`` output may carry an extra
            # num_directions axis; it is fed straight into the next layer and
            # returned as-is -- verify the shapes line up for the targeted opset.
            prev_output, h_out = g.op('RNN', *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      activations_s=[nonlinearity.lower()],
                                      **extra_kwargs)
            h_outs.append(h_out)
        # Final hidden states of all layers, concatenated along axis 0.
        h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
        return prev_output, h_outs

    return symbolic
| |
| |
def LSTM_symbolic_builder(input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
    """Build a symbolic function that exports a multi-layer LSTM as stacked ONNX ``LSTM`` nodes.

    Args:
        input_size: not used by the returned symbolic.
        hidden_size: exported as ``hidden_size``; also the gate block height
            used to reorder weights.
        num_layers: one ONNX ``LSTM`` node is emitted per layer.
        batch_first: unsupported; reported via ``_unimplemented``.
        dropout: unsupported when ``kwargs['train']`` is truthy.
        bidirectional: adds ``direction="bidirectional"`` and concatenates the
            forward/backward weights of each layer.
        **kwargs: must contain ``'train'`` whenever ``dropout`` is truthy.

    Returns:
        ``symbolic(g, input, all_weights, h0_and_c0, batch_sizes)``, which
        returns ``(output, h_n, None)``. The third slot is always None
        (presumably the final cell state; each ONNX node is created with
        ``outputs=2``, so no cell-state output is requested).
    """
    def symbolic(g, input, all_weights, h0_and_c0, batch_sizes):
        if batch_first:
            return _unimplemented("LSTM", "batch_first")
        if dropout and kwargs['train']:
            return _unimplemented("RNN", "dropout in training mode")

        unidirectional = not bidirectional

        h0, c0 = h0_and_c0

        prev_output = input
        h_outs = []

        # Without batch_sizes, an Undefined node stands in for the optional
        # ONNX sequence_lens input.
        sequence_lens = unused(g) if batch_sizes is None else batch_sizes

        for i in range(num_layers):
            if unidirectional:
                # pytorch is input, forget, cell, output.
                # onnx is input, output, forget, cell.
                # Gate blocks (0,1), (3,4), (1,3) pick input, output, then forget+cell.
                weight_ih, weight_hh, bias_ih, bias_hh = \
                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[i]]

                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)

                # A single layer uses h0/c0 whole; otherwise take this layer's slice.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
                c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
            else:
                # pytorch is input, forget, cell, output.
                # onnx is input, output, forget, cell.
                weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i]]
                weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
                    [reform_weights(g, w, hidden_size, [(0, 1), (3, 4), (1, 3)]) for w in all_weights[2 * i + 1]]

                # Stack forward then backward parameters along axis 0.
                weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
                weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
                bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)

                # Two state entries (forward, backward) per layer.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])
                c_in = c0 if num_layers == 1 else g.op('Slice', c0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])

            # Positional order expected by ONNX LSTM: X, W, R, B, sequence_lens, initial_h, initial_c.
            inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in, c_in]
            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
            # NOTE(review): as with the Elman RNN export, ``Y`` is fed to the
            # next layer and returned as-is -- verify its shape (possible
            # num_directions axis) against the targeted opset.
            prev_output, h_out = g.op('LSTM', *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      **extra_kwargs)
            h_outs.append(h_out)
        # Final hidden states of all layers, concatenated along axis 0.
        h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
        return prev_output, h_outs, None

    return symbolic
| |
| |
def GRU_symbolic_builder(input_size, hidden_size, num_layers, batch_first, dropout, bidirectional, **kwargs):
    """Build a symbolic function that exports a multi-layer GRU as stacked ONNX ``GRU`` nodes.

    Args:
        input_size: not used by the returned symbolic.
        hidden_size: exported as ``hidden_size``; also the gate block height
            used to reorder weights.
        num_layers: one ONNX ``GRU`` node is emitted per layer.
        batch_first: unsupported; reported via ``_unimplemented``.
        dropout: unsupported when ``kwargs['train']`` is truthy.
        bidirectional: adds ``direction="bidirectional"`` and concatenates the
            forward/backward weights of each layer.
        **kwargs: must contain ``'train'`` whenever ``dropout`` is truthy.

    Returns:
        ``symbolic(g, input, all_weights, h0, batch_sizes)``, which returns
        ``(output, h_n)``.
    """
    def symbolic(g, input, all_weights, h0, batch_sizes):
        if batch_first:
            return _unimplemented("GRU", "batch_first")
        if dropout and kwargs['train']:
            return _unimplemented("RNN", "dropout in training mode")

        unidirectional = not bidirectional

        prev_output = input
        h_outs = []

        # Without batch_sizes, an Undefined node stands in for the optional
        # ONNX sequence_lens input.
        sequence_lens = unused(g) if batch_sizes is None else batch_sizes

        for i in range(num_layers):
            if unidirectional:
                # pytorch is reset, input, hidden
                # onnx is input, reset, hidden
                # Gate blocks (1,2), (0,1), (2,3) swap the first two gates.
                weight_ih, weight_hh, bias_ih, bias_hh = \
                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[i]]

                bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)

                # A single layer uses h0 whole; otherwise take this layer's slice.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[i], ends_i=[i + 1])
            else:
                # pytorch is reset, input, hidden
                # onnx is input, reset, hidden
                weight_ih_f, weight_hh_f, bias_ih_f, bias_hh_f = \
                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i]]
                weight_ih_b, weight_hh_b, bias_ih_b, bias_hh_b = \
                    [reform_weights(g, w, hidden_size, [(1, 2), (0, 1), (2, 3)]) for w in all_weights[2 * i + 1]]

                # Stack forward then backward parameters along axis 0.
                weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
                weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
                bias_concat = g.op('Concat', bias_ih_f, bias_hh_f, bias_ih_b, bias_hh_b, axis_i=0)

                # Two h0 entries (forward, backward) per layer.
                h_in = h0 if num_layers == 1 else g.op('Slice', h0, axes_i=[0], starts_i=[2 * i], ends_i=[2 * i + 2])

            # Positional order expected by ONNX GRU: X, W, R, B, sequence_lens, initial_h.
            inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens, h_in]
            extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
            # linear_before_reset=1: PyTorch's GRU applies the reset gate to the
            # already-linearly-transformed hidden state, r * (W_hn h + b_hn).
            # NOTE(review): ``Y`` is fed to the next layer and returned as-is --
            # verify its shape (possible num_directions axis) for the targeted opset.
            prev_output, h_out = g.op('GRU', *inputs, outputs=2,
                                      hidden_size_i=hidden_size, linear_before_reset_i=1,
                                      **extra_kwargs)
            h_outs.append(h_out)
        # Final hidden states of all layers, concatenated along axis 0.
        h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
        return prev_output, h_outs

    return symbolic