| from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, flatten |
| import test_onnx_common |
| |
| import torch |
| import torch.onnx |
| from torch.autograd import Variable, Function |
| from torch.nn import Module |
| import torch.nn as nn |
| |
| import onnx |
| import onnx.checker |
| import onnx.helper |
| |
| import google.protobuf.text_format |
| |
| import itertools |
| import io |
| import unittest |
| import inspect |
| import argparse |
| import glob |
| import os |
| import shutil |
| import sys |
| import common |
| from onnx import numpy_helper |
| |
| _onnx_test = False |
| |
| |
| def export_to_string(model, inputs, *args, **kwargs): |
| f = io.BytesIO() |
| with torch.no_grad(): |
| torch.onnx.export(model, inputs, f, *args, **kwargs) |
| return f.getvalue() |
| |
| |
| class FuncModule(Module): |
| def __init__(self, f, params=tuple()): |
| super(FuncModule, self).__init__() |
| self.f = f |
| self.params = nn.ParameterList(list(params)) |
| |
| def forward(self, *args): |
| return self.f(*itertools.chain(args, self.params)) |
| |
| |
| class TestOperators(TestCase): |
| |
| def assertONNXExpected(self, binary_pb, subname=None): |
| model_def = onnx.ModelProto.FromString(binary_pb) |
| onnx.checker.check_model(model_def) |
| # doc_string contains stack trace in it, strip it |
| onnx.helper.strip_doc_string(model_def) |
| self.assertExpected(google.protobuf.text_format.MessageToString(model_def, float_format='.15g'), subname) |
| return model_def |
| |
| def assertONNX(self, f, args, params=tuple(), **kwargs): |
| if isinstance(f, nn.Module): |
| m = f |
| else: |
| m = FuncModule(f, params) |
| onnx_model_pb = export_to_string(m, args, **kwargs) |
| model_def = self.assertONNXExpected(onnx_model_pb) |
| if _onnx_test: |
| test_function = inspect.stack()[1][0].f_code.co_name |
| test_name = test_function[0:4] + "_operator" + test_function[4:] |
| output_dir = os.path.join(test_onnx_common.pytorch_operator_dir, test_name) |
| # Assume: |
| # 1) the old test should be delete before the test. |
| # 2) only one assertONNX in each test, otherwise will override the data. |
| assert not os.path.exists(output_dir), "{} should not exist!".format(output_dir) |
| os.makedirs(output_dir) |
| with open(os.path.join(output_dir, "model.onnx"), 'wb') as file: |
| file.write(model_def.SerializeToString()) |
| data_dir = os.path.join(output_dir, "test_data_set_0") |
| os.makedirs(data_dir) |
| if isinstance(args, Variable): |
| args = (args,) |
| for index, var in enumerate(flatten(args)): |
| tensor = numpy_helper.from_array(var.data.numpy()) |
| with open(os.path.join(data_dir, "input_{}.pb".format(index)), 'wb') as file: |
| file.write(tensor.SerializeToString()) |
| outputs = m(*args) |
| if isinstance(outputs, Variable): |
| outputs = (outputs,) |
| for index, var in enumerate(flatten(outputs)): |
| tensor = numpy_helper.from_array(var.data.numpy()) |
| with open(os.path.join(data_dir, "output_{}.pb".format(index)), 'wb') as file: |
| file.write(tensor.SerializeToString()) |
| |
| def assertONNXRaises(self, err, f, args, params=tuple(), **kwargs): |
| if isinstance(f, nn.Module): |
| m = f |
| else: |
| m = FuncModule(f, params) |
| self.assertExpectedRaises(err, lambda: export_to_string(m, args, **kwargs)) |
| |
| def assertONNXRaisesRegex(self, err, reg, f, args, params=tuple(), **kwargs): |
| if isinstance(f, nn.Module): |
| m = f |
| else: |
| m = FuncModule(f, params) |
| with self.assertRaisesRegex(err, reg): |
| export_to_string(m, args, **kwargs) |
| |
| def test_basic(self): |
| x = Variable(torch.Tensor([0.4]), requires_grad=True) |
| y = Variable(torch.Tensor([0.7]), requires_grad=True) |
| self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y)) |
| |
| def test_view(self): |
| x = Variable(torch.Tensor([0]), requires_grad=True) |
| self.assertONNX(lambda x: x.view(1, 1), x) |
| |
| def test_index(self): |
| x = Variable(torch.Tensor([[0]]), requires_grad=True) |
| self.assertONNX(lambda x: x[0], x) |
| |
| def test_type_as(self): |
| x = Variable(torch.Tensor([0]), requires_grad=True) |
| self.assertONNX(lambda x: x.type_as(x), x) |
| |
| def test_addconstant(self): |
| x = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| self.assertONNX(lambda x: x + 1, x) |
| |
| def test_add_broadcast(self): |
| x = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| y = Variable(torch.DoubleTensor(3), requires_grad=True) |
| self.assertONNX(lambda x, y: x + y, (x, y)) |
| |
| def test_add_left_broadcast(self): |
| x = Variable(torch.DoubleTensor(3), requires_grad=True) |
| y = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| self.assertONNXRaisesRegex(RuntimeError, |
| r"ONNX export failed: Could not export a broadcasted operation.*", |
| lambda x, y: x + y, (x, y), verbose=True) |
| |
| def test_add_size1_broadcast(self): |
| x = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| y = Variable(torch.DoubleTensor(2, 1), requires_grad=True) |
| self.assertONNX(lambda x, y: x + y, (x, y)) |
| |
| def test_add_size1_right_broadcast(self): |
| x = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| y = Variable(torch.DoubleTensor(3), requires_grad=True) |
| self.assertONNX(lambda x, y: x + y, (x, y)) |
| |
| def test_add_size1_singleton_broadcast(self): |
| x = Variable(torch.DoubleTensor(2, 3), requires_grad=True) |
| y = Variable(torch.DoubleTensor(1, 3), requires_grad=True) |
| self.assertONNX(lambda x, y: x + y, (x, y)) |
| |
| def test_transpose(self): |
| x = Variable(torch.Tensor([[0, 1], [2, 3]]), requires_grad=True) |
| self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x) |
| |
| def test_chunk(self): |
| x = Variable(torch.Tensor([0, 1, 2]), requires_grad=True) |
| self.assertONNX(lambda x: x.chunk(2), x) |
| |
| def test_concat2(self): |
| x = Variable(torch.randn(2, 3)) |
| y = Variable(torch.randn(2, 3)) |
| self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),)) |
| |
| def test_mm(self): |
| m1 = Variable(torch.randn(2, 3), requires_grad=True) |
| m2 = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(torch.mm, (m1, m2)) |
| |
| def test_addmm(self): |
| m1 = Variable(torch.randn(2, 3), requires_grad=True) |
| m2 = Variable(torch.randn(3, 4), requires_grad=True) |
| m3 = Variable(torch.randn(4), requires_grad=True) |
| self.assertONNX(lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3)) |
| |
| def test_permute2(self): |
| x = Variable(torch.Tensor([[[[[[0]]]]]]), requires_grad=True) |
| self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x) |
| |
| def test_pad(self): |
| x = Variable(torch.Tensor([[[[0, 1, 1, 1], [2, 3, 7, 7]]]]), requires_grad=True) |
| self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x) |
| |
| def test_params(self): |
| x = Variable(torch.Tensor([[1, 2], [3, 4]]), requires_grad=True) |
| y = nn.Parameter(torch.Tensor([[1, 2], [3, 4]]), requires_grad=True) |
| self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), x, params=(y, )) |
| |
| def test_symbolic_mismatch(self): |
| class MyFun(Function): |
| @staticmethod |
| def symbolic(g, x): |
| # The inside of this function should never be invoked, because |
| # we will fail due to an argument mismatch first. |
| assert False |
| |
| @staticmethod |
| def forward(ctx, x, y): |
| return x + y |
| |
| x = Variable(torch.randn(2, 2).fill_(1.0)) |
| y = Variable(torch.randn(2, 2).fill_(1.0)) |
| # NB: Don't use expect test here, the type error wobbles depending |
| # on Python version |
| with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"): |
| export_to_string(FuncModule(MyFun().apply), (x, y)) |
| |
| # TODO: Do an nn style test for these |
| def test_batchnorm(self): |
| x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True) |
| self.assertONNX(nn.BatchNorm2d(2), x) |
| |
| def test_batchnorm_1d(self): |
| x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True) |
| self.assertONNX(nn.BatchNorm1d(2), x) |
| |
| def test_batchnorm_training(self): |
| x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True) |
| self.assertONNX(nn.BatchNorm2d(2), x, training=True) |
| |
| def test_conv(self): |
| x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True) |
| self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x) |
| |
| def test_convtranspose(self): |
| x = Variable(torch.randn(2, 3, 4, 5).fill_(1.0), requires_grad=True) |
| self.assertONNX(nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False, |
| padding=1, output_padding=2), x) |
| |
| def test_maxpool(self): |
| x = Variable(torch.randn(20, 16, 50)) |
| self.assertONNX(nn.MaxPool1d(3, stride=2), x) |
| |
| def test_at_op(self): |
| x = Variable(torch.randn(3, 4)) |
| |
| class MyFun(Function): |
| |
| @staticmethod |
| def symbolic(g, x): |
| return g.at("add", x, x) |
| |
| @staticmethod |
| def forward(ctx, x): |
| return x + x |
| |
| class MyModule(Module): |
| def forward(self, x): |
| return MyFun.apply(x) |
| |
| self.assertONNX(MyModule(), x) |
| |
| def test_clip(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x) |
| |
| def test_clip_min(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.clamp(min=-0.1), x) |
| |
| def test_clip_max(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.clamp(max=0.1), x) |
| |
| def test_hardtanh(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x) |
| |
| def test_max(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| y = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x, y: torch.max(x, y), (x, y)) |
| |
| def test_min(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| y = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x, y: torch.min(x, y), (x, y)) |
| |
| def test_mean(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.mean(x), x) |
| |
| def test_reduced_mean(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.mean(x, dim=2), x) |
| |
| def test_reduced_mean_keepdim(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.mean(x, dim=2, keepdim=True), x) |
| |
| def test_sum(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.sum(x), x) |
| |
| def test_reduced_sum(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.sum(x, dim=2), x) |
| |
| def test_reduced_sum_keepdim(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x) |
| |
| def test_prod(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.prod(x), x) |
| |
| def test_reduced_prod(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.prod(x, dim=2), x) |
| |
| def test_reduced_prod_keepdim(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x) |
| |
| def test_sqrt(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x: torch.sqrt(x), x) |
| |
| def test_equal(self): |
| x = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| y = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| self.assertONNX(lambda x, y: x == y, (x, y)) |
| |
| def test_lt(self): |
| x = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| y = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| self.assertONNX(lambda x, y: x < y, (x, y)) |
| |
| def test_gt(self): |
| x = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| y = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| self.assertONNX(lambda x, y: x > y, (x, y)) |
| |
| def test_le(self): |
| x = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| y = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| self.assertONNX(lambda x, y: x <= y, (x, y)) |
| |
| def test_ge(self): |
| x = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| y = Variable(torch.randn(3, 4).int(), requires_grad=False) |
| self.assertONNX(lambda x, y: x >= y, (x, y)) |
| |
| def test_exp(self): |
| x = Variable(torch.randn(3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.exp(), x) |
| |
| def test_flatten(self): |
| # Flatten is a special case of Reshape when the output is a 2-D tensor. |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x) |
| |
| def test_logsoftmax(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(nn.LogSoftmax(dim=2), x) |
| |
| def test_pow(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| y = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x, y: x.pow(y), (x, y)) |
| |
| def test_selu(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(nn.SELU(), x) |
| |
| def test_repeat(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x) |
| |
| def test_repeat_dim_overflow(self): |
| x = Variable(torch.randn(1, 2), requires_grad=True) |
| self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x) |
| |
| def test_norm(self): |
| x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True) |
| self.assertONNX(lambda x: x.norm(dim=2), (x)) |
| |
| def test_symbolic_override(self): |
| """Lifted from fast-neural-style: custom implementation of instance norm |
| to be mapped to ONNX operator""" |
| |
| class CustomInstanceNorm(torch.nn.Module): |
| def __init__(self, dim, eps=1e-9): |
| super(CustomInstanceNorm, self).__init__() |
| self.scale = nn.Parameter(torch.FloatTensor(dim).uniform_()) |
| self.shift = nn.Parameter(torch.FloatTensor(dim).zero_()) |
| self.eps = eps |
| |
| def forward(self, x): |
| return self._run_forward(x, self.scale, self.shift, eps=self.eps) |
| |
| @staticmethod |
| @torch.onnx.symbolic_override( |
| lambda g, x, scale, shift, eps: g.op( |
| 'InstanceNormalization', x, scale, shift, epsilon_f=eps) |
| ) |
| def _run_forward(x, scale, shift, eps): |
| # since we hand-roll instance norm it doesn't perform well all in fp16 |
| n = x.size(2) * x.size(3) |
| t = x.view(x.size(0), x.size(1), n) |
| mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) |
| # Calculate the biased var. torch.var returns unbiased var |
| var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((float(n) - 1) / float(n)) |
| scale_broadcast = scale.unsqueeze(1).unsqueeze(1).unsqueeze(0) |
| scale_broadcast = scale_broadcast.expand_as(x) |
| shift_broadcast = shift.unsqueeze(1).unsqueeze(1).unsqueeze(0) |
| shift_broadcast = shift_broadcast.expand_as(x) |
| out = (x - mean) / torch.sqrt(var + eps) |
| out = out * scale_broadcast + shift_broadcast |
| return out |
| |
| instnorm = CustomInstanceNorm(10) |
| x = Variable(torch.randn(2, 10, 32, 32)) |
| self.assertONNX(instnorm, x) |
| |
| """ |
| def test_rnn(self): |
| rnn = nn.RNN(30, 20, 2) |
| input = Variable(torch.randn(10, 32, 30)) |
| output, hidden = rnn(input) |
| self.assertONNX(rnn, input) |
| """ |
| |
| def test_symbolic_override_nested(self): |
| def symb(g, x, y): |
| assert isinstance(x, torch._C.Value) |
| assert isinstance(y[0], torch._C.Value) |
| assert isinstance(y[1], torch._C.Value) |
| return g.op('Sum', x, y[0], y[1]), ( |
| g.op('Neg', x), g.op('Neg', y[0])) |
| |
| @torch.onnx.symbolic_override(symb) |
| def foo(x, y): |
| return x + y[0] + y[1], (-x, -y[0]) |
| |
| class BigModule(torch.nn.Module): |
| def forward(self, x, y): |
| return foo(x, y) |
| |
| inp = (Variable(torch.FloatTensor([1])), |
| (Variable(torch.FloatTensor([2])), |
| Variable(torch.FloatTensor([3])))) |
| BigModule()(*inp) |
| self.assertONNX(BigModule(), inp) |
| |
| |
| if __name__ == '__main__': |
| onnx_test_flag = '--onnx-test' |
| _onnx_test = onnx_test_flag in common.UNITTEST_ARGS |
| if onnx_test_flag in common.UNITTEST_ARGS: |
| common.UNITTEST_ARGS.remove(onnx_test_flag) |
| if _onnx_test: |
| for d in glob.glob(os.path.join(test_onnx_common.pytorch_operator_dir, "test_operator_*")): |
| shutil.rmtree(d) |
| run_tests() |