[JIT] End-to-end example-based robustness testing for hybrid frontend (#8451)

* End-to-end example-based robustness testing for hybrid frontend

* delete this
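
The new `checkTrace` helper traces a function or module and verifies that the
resulting graph executor matches eager execution on outputs, first-order
gradients, and second-order gradients. A minimal usage sketch, with a
hypothetical `Scale` module for illustration:

    class Scale(nn.Module):
        def forward(self, x, y):
            # an arbitrary differentiable computation
            return x * y + y

    # inside a JitTestCase subclass: compare eager vs. traced outputs and grads
    self.checkTrace(Scale(), (torch.rand(3, 4), torch.rand(3, 4)))
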
diff --git a/test/test_jit.py b/test/test_jit.py
index 9c3b684..5cb2da2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -133,6 +133,90 @@
             trace.set_graph(graph)
         return graph
 
+    def checkTrace(self, func, reference_tensors, input_tensors=None,
+                   optimize=True, drop=None, allow_unused=False,
+                   verbose=False, inputs_require_grads=True):
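+        """Trace `func` (a Python callable or a raw IR graph) and check that
+        the resulting graph executor agrees with eager execution on outputs,
+        first-order gradients, and (within a small tolerance) second-order
+        gradients."""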
+        # TODO: check gradients for parameters, not just inputs
+        def allSum(vs):
+            # `drop` lets us exclude the last few outputs from ever being
+            # used, in order to test unused outputs
+            if drop is not None:
+                vs = vs[:-drop]
+            # we don't want the grads for all the outputs to be the same,
+            # so we scale each one by a different constant
+            return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
+        if input_tensors is None:
+            input_tensors = reference_tensors
+
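+        # each check runs twice: once with inputs that don't require grad and
+        # once with cloned inputs that record gradients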
+        nograd_inputs = reference_tensors
+        if inputs_require_grads:
+            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+        else:
+            recording_inputs = reference_tensors
+
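+        # `func` may be a raw IR graph or a Python callable; build a graph
+        # executor directly for the former and trace the latter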
+        if isinstance(func, torch._C.Graph):
+            ge = torch._C.GraphExecutor(func, optimize)
+        else:
+            ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)
+
+        if verbose:
+            print(ge.__getattr__('forward').graph)
+
+        # test no gradients case
+
+        outputs = func(*nograd_inputs)
+        outputs_ge = ge(*nograd_inputs)
+        self.assertEqual(outputs, outputs_ge)
+
+        # test single grad case
+        outputs = func(*recording_inputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+                                        allow_unused=allow_unused)
+
+        outputs_ge = ge(*recording_inputs)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+                                           allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge)
+
+        # test the grad grad case
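+        # (run double backward through both the reference function and the
+        # graph executor, then compare the second-order gradients)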
+
+        outputs = func(*recording_inputs)
+        l1 = allSum(outputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+                                        allow_unused=allow_unused)
+            l2 = (allSum(grads) * l1)
+            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            recording_inputs = [Variable(t, requires_grad=True)
+                                for t in reference_tensors]
+
+        outputs_ge = ge(*recording_inputs)
+        l1_ge = allSum(outputs_ge)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(
+                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+            l2_ge = (allSum(grads_ge) * l1_ge)
+            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge)
+            for g2, g2_ge in zip(grads2, grads2_ge):
+                if g2 is None and g2_ge is None:
+                    continue
+                self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
+
+        return ge
+
 
 class TestJit(JitTestCase):
     def assertExportImport(self, trace, inputs):
@@ -725,73 +809,6 @@
         self.assertExpectedGraph(trace)
         self.assertExportImport(trace, (x,))
 
-    def checkTrace(self, func, reference_tensors, input_tensors=None,
-                   optimize=True, drop=None, allow_unused=False):
-        def allSum(vs):
-            # drop allows us to remove some values from ever being used
-            # to test unused outputs
-            if drop is not None:
-                vs = vs[:-drop]
-            # we don't want all the grad for all the outputs to be the same
-            # so we multiply each by a constant
-            return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
-        if input_tensors is None:
-            input_tensors = reference_tensors
-
-        nograd_inputs = reference_tensors
-        recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
-
-        if isinstance(func, torch._C.Graph):
-            ge = torch._C.GraphExecutor(func, optimize)
-        else:
-            ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)
-
-        # test no gradients case
-
-        outputs = func(*nograd_inputs)
-        outputs_ge = ge(*nograd_inputs)
-        self.assertEqual(outputs, outputs_ge)
-
-        # test single grad case
-
-        outputs = func(*recording_inputs)
-        grads = torch.autograd.grad(allSum(outputs), recording_inputs,
-                                    allow_unused=allow_unused)
-
-        outputs_ge = ge(*recording_inputs)
-        grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
-                                       allow_unused=allow_unused)
-        self.assertEqual(outputs, outputs_ge)
-        self.assertEqual(grads, grads_ge)
-
-        # test the grad grad case
-
-        outputs = func(*recording_inputs)
-        l1 = allSum(outputs)
-        grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
-                                    allow_unused=allow_unused)
-        l2 = (allSum(grads) * l1)
-        grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
-
-        recording_inputs = [Variable(t, requires_grad=True)
-                            for t in reference_tensors]
-
-        outputs_ge = ge(*recording_inputs)
-        l1_ge = allSum(outputs_ge)
-        grads_ge = torch.autograd.grad(
-            l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
-        l2_ge = (allSum(grads_ge) * l1_ge)
-        grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
-
-        self.assertEqual(outputs, outputs_ge)
-        self.assertEqual(grads, grads_ge)
-        for g2, g2_ge in zip(grads2, grads2_ge):
-            if g2 is None and g2_ge is None:
-                continue
-            self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
-
-        return ge
-
     def run_ge_tests(self, optimize, use_cuda):
         def rand(*args):
             t = torch.rand(*args).float()
@@ -3062,6 +3079,392 @@
         self.checkScript(fn, (torch.tensor(2),))
 
 
+class TestEndToEndHybridFrontendModels(JitTestCase):
+
+    def test_dcgan_models(self):
+        class DCGANGenerator(nn.Module):
+            def __init__(self, nz, ngf, nc):
+                super(DCGANGenerator, self).__init__()
+                self.main = nn.Sequential(
+                    # input is Z, going into a convolution
+                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+                    nn.BatchNorm2d(ngf * 8),
+                    nn.ReLU(True),
+                    # state size. (ngf*8) x 4 x 4
+                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf * 4),
+                    nn.ReLU(True),
+                    # state size. (ngf*4) x 8 x 8
+                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf * 2),
+                    nn.ReLU(True),
+                    # state size. (ngf*2) x 16 x 16
+                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf),
+                    nn.ReLU(True),
+                    # state size. (ngf) x 32 x 32
+                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+                    nn.Tanh()
+                    # state size. (nc) x 64 x 64
+                )
+
+            def forward(self, input):
+                return self.main(input)
+
+        class DCGANDiscriminator(nn.Module):
+            def __init__(self, nc, ndf):
+                super(DCGANDiscriminator, self).__init__()
+                self.main = nn.Sequential(
+                    # input is (nc) x 64 x 64
+                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf) x 32 x 32
+                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 2),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*2) x 16 x 16
+                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 4),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*4) x 8 x 8
+                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 8),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*8) x 4 x 4
+                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
+                    nn.Sigmoid()
+                )
+
+            def forward(self, input):
+                return self.main(input).view(-1, 1).squeeze(1)
+
+        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
+        self.checkTrace(DCGANGenerator(nz, ngf, nc), (torch.rand(bs, nz, 1, 1),))
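+        # use a generator's output as a correctly-shaped example input for
+        # the discriminator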
+        example_input = DCGANGenerator(nz, ngf, nc)(torch.rand(bs, nz, 1, 1))
+        self.checkTrace(DCGANDiscriminator(nc, ndf), (example_input,))
+
+    @unittest.skip('https://github.com/pytorch/pytorch/issues/8439 InstanceNormalization bug')
+    def test_neural_style(self):
+        class TransformerNet(torch.nn.Module):
+            def __init__(self):
+                super(TransformerNet, self).__init__()
+                # Initial convolution layers
+                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
+                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
+                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
+                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
+                # Residual layers
+                self.res1 = ResidualBlock(128)
+                self.res2 = ResidualBlock(128)
+                self.res3 = ResidualBlock(128)
+                self.res4 = ResidualBlock(128)
+                self.res5 = ResidualBlock(128)
+                # Upsampling Layers
+                self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
+                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
+                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
+                # Non-linearities
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, X):
+                y = self.relu(self.in1(self.conv1(X)))
+                y = self.relu(self.in2(self.conv2(y)))
+                y = self.relu(self.in3(self.conv3(y)))
+                y = self.res1(y)
+                y = self.res2(y)
+                y = self.res3(y)
+                y = self.res4(y)
+                y = self.res5(y)
+                y = self.relu(self.in4(self.deconv1(y)))
+                y = self.relu(self.in5(self.deconv2(y)))
+                y = self.deconv3(y)
+                return y
+
+        class ConvLayer(torch.nn.Module):
+            def __init__(self, in_channels, out_channels, kernel_size, stride):
+                super(ConvLayer, self).__init__()
+                reflection_padding = kernel_size // 2
+                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+            def forward(self, x):
+                out = self.reflection_pad(x)
+                out = self.conv2d(out)
+                return out
+
+        class ResidualBlock(torch.nn.Module):
+            """ResidualBlock
+            introduced in: https://arxiv.org/abs/1512.03385
+            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
+            """
+
+            def __init__(self, channels):
+                super(ResidualBlock, self).__init__()
+                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
+                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                residual = x
+                out = self.relu(self.in1(self.conv1(x)))
+                out = self.in2(self.conv2(out))
+                out = out + residual
+                return out
+
+        class UpsampleConvLayer(torch.nn.Module):
+            """UpsampleConvLayer
+            Upsamples the input and then does a convolution. This method gives better results
+            compared to ConvTranspose2d.
+            ref: http://distill.pub/2016/deconv-checkerboard/
+            """
+
+            def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+                super(UpsampleConvLayer, self).__init__()
+                self.upsample = upsample
+                if upsample:
+                    self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
+                reflection_padding = kernel_size // 2
+                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+            def forward(self, x):
+                x_in = x
+                if self.upsample:
+                    x_in = self.upsample_layer(x_in)
+                out = self.reflection_pad(x_in)
+                out = self.conv2d(out)
+                return out
+
+        self.checkTrace(TransformerNet(), (torch.rand(5, 3, 224, 224),))
+
+    def test_mnist(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+                self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+                self.conv2_drop = nn.Dropout2d()
+                self.fc1 = nn.Linear(320, 50)
+                self.fc2 = nn.Linear(50, 10)
+
+            def forward(self, x):
+                x = F.relu(F.max_pool2d(self.conv1(x), 2))
+                x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+                x = x.view(-1, 320)
+                x = F.relu(self.fc1(x))
+                x = F.dropout(x, training=self.training)
+                x = self.fc2(x)
+                return F.log_softmax(x, dim=1)
+
+        # FIXME: eval() is present because it works around the issue described
+        # in https://github.com/pytorch/pytorch/issues/8448
+        self.checkTrace(Net().eval(), (torch.rand(5, 1, 28, 28),))
+
+    def test_reinforcement_learning(self):
+        class Policy(nn.Module):
+            def __init__(self):
+                super(Policy, self).__init__()
+                self.affine1 = nn.Linear(4, 128)
+                self.affine2 = nn.Linear(128, 2)
+
+            def forward(self, x):
+                x = F.relu(self.affine1(x))
+                action_scores = self.affine2(x)
+                return F.softmax(action_scores, dim=1)
+
+        self.checkTrace(Policy(), (torch.rand(1, 4),))
+
+    def test_snli(self):
+        # TODO:
+        #   1) nn.LSTM is called as a Python function https://github.com/pytorch/pytorch/issues/8449
+        #   2) Dropout is called as a Python function https://github.com/pytorch/pytorch/issues/8450
+        class Bottle(nn.Module):
+
+            def forward(self, input):
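+                # 2D inputs pass straight through to the wrapped layer; 3D
+                # inputs are flattened so nn.Linear applies, then reshaped back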
+                if len(input.size()) <= 2:
+                    return super(Bottle, self).forward(input)
+                size = input.size()[:2]
+                out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
+                return out.view(size[0], size[1], -1)
+
+        class Linear(Bottle, nn.Linear):
+            pass
+
+        class Encoder(nn.Module):
+
+            def __init__(self, config):
+                super(Encoder, self).__init__()
+                self.config = config
+                input_size = config.d_proj if config.projection else config.d_embed
+                dropout = 0 if config.n_layers == 1 else config.dp_ratio
+                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
+                                   num_layers=config.n_layers, dropout=dropout,
+                                   bidirectional=config.birnn)
+
+            def forward(self, inputs):
+                batch_size = inputs.size()[1]
+                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
+                h0 = c0 = inputs.new_zeros(state_shape)
+                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
+                return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
+
+        class SNLIClassifier(nn.Module):
+
+            def __init__(self, config):
+                super(SNLIClassifier, self).__init__()
+                self.config = config
+                self.embed = nn.Embedding(config.n_embed, config.d_embed)
+                self.projection = Linear(config.d_embed, config.d_proj)
+                self.encoder = Encoder(config)
+                self.dropout = nn.Dropout(p=config.dp_ratio)
+                self.relu = nn.ReLU()
+                seq_in_size = 2 * config.d_hidden
+                if self.config.birnn:
+                    seq_in_size *= 2
+                lin_config = [seq_in_size] * 2
+                self.out = nn.Sequential(
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(seq_in_size, config.d_out))
+
+            def forward(self, premise, hypothesis):
+                prem_embed = self.embed(premise)
+                hypo_embed = self.embed(hypothesis)
+                if self.config.fix_emb:
+                    prem_embed = prem_embed.detach()
+                    hypo_embed = hypo_embed.detach()
+                if self.config.projection:
+                    prem_embed = self.relu(self.projection(prem_embed))
+                    hypo_embed = self.relu(self.projection(hypo_embed))
+                premise = self.encoder(prem_embed)
+                hypothesis = self.encoder(hypo_embed)
+                scores = self.out(torch.cat([premise, hypothesis], 1))
+                return scores
+
+        class Config:
+            n_embed = 100
+            d_embed = 100
+            d_proj = 300
+            dp_ratio = 0.0  # 0 for deterministic testing (TODO: fix the seed in checkTrace instead?)
+            d_hidden = 300
+            birnn = True
+            d_out = 300
+            fix_emb = True
+            projection = True
+            n_layers = 2
+            n_cells = 4  # 2 * n_layers because birnn = True
+
+        premise = torch.LongTensor(48, 128).random_(0, 100)
+        hypothesis = torch.LongTensor(24, 128).random_(0, 100)
+
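+        # the inputs are integer (LongTensor) indices, which cannot require
+        # grad, so gradient checking is disabled for this model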
+        self.checkTrace(SNLIClassifier(Config()), (premise, hypothesis), inputs_require_grads=False)
+
+    def test_super_resolution(self):
+
+        class Net(nn.Module):
+
+            def __init__(self, upscale_factor):
+                super(Net, self).__init__()
+
+                self.relu = nn.ReLU()
+                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
+                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
+                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
+                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
+                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+            def forward(self, x):
+                x = self.relu(self.conv1(x))
+                x = self.relu(self.conv2(x))
+                x = self.relu(self.conv3(x))
+                x = self.pixel_shuffle(self.conv4(x))
+                return x
+
+        net = Net(upscale_factor=4)
+        self.checkTrace(net, (torch.rand(5, 1, 64, 64),))
+
+    @unittest.skip('This needs to be re-written into a script module')
+    def test_time_sequence_prediction(self):
+        class Sequence(nn.Module):
+            def __init__(self):
+                super(Sequence, self).__init__()
+                self.lstm1 = nn.LSTMCell(1, 51)
+                self.lstm2 = nn.LSTMCell(51, 51)
+                self.linear = nn.Linear(51, 1)
+
+            def forward(self, input, future=0):
+                outputs = []
+                h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
+                c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
+                h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
+                c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
+
+                for _, input_t in enumerate(input.chunk(input.size(1), dim=1)):
+                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
+                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
+                    output = self.linear(h_t2)
+                    outputs += [output]
+                for _ in range(future):  # if we should predict the future
+                    h_t, c_t = self.lstm1(output, (h_t, c_t))
+                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
+                    output = self.linear(h_t2)
+                    outputs += [output]
+                outputs = torch.stack(outputs, 1).squeeze(2)
+                return outputs
+
+        # the hidden states are hard-coded to double, so cast the model and
+        # input to double as well
+        self.checkTrace(Sequence().double(), (torch.rand(97, 999).double(),), verbose=True)
+
+    def test_vae(self):
+        class VAE(nn.Module):
+            def __init__(self):
+                super(VAE, self).__init__()
+
+                self.fc1 = nn.Linear(784, 400)
+                self.fc21 = nn.Linear(400, 20)
+                self.fc22 = nn.Linear(400, 20)
+                self.fc3 = nn.Linear(20, 400)
+                self.fc4 = nn.Linear(400, 784)
+
+            def encode(self, x):
+                h1 = F.relu(self.fc1(x))
+                return self.fc21(h1), self.fc22(h1)
+
+            def reparameterize(self, mu, logvar):
+                if self.training:
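+                    # reparameterization trick: sample eps ~ N(0, I), then
+                    # shift and scale it so gradients flow back to mu/logvar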
+                    std = torch.exp(0.5 * logvar)
+                    eps = torch.randn_like(std)
+                    return eps.mul(std).add_(mu)
+                else:
+                    return mu
+
+            def decode(self, z):
+                h3 = F.relu(self.fc3(z))
+                return F.sigmoid(self.fc4(h3))
+
+            def forward(self, x):
+                mu, logvar = self.encode(x.view(-1, 784))
+                z = self.reparameterize(mu, logvar)
+                return self.decode(z), mu, logvar
+
+        # FIXME: this fails under training because of the call to `randn_like`
+        # https://github.com/pytorch/pytorch/issues/8443
+        self.checkTrace(VAE().eval(), (torch.rand(128, 1, 28, 28),))
+
+
 # Smoke tests for export methods
 class TestPytorchExportModes(JitTestCase):
     class MyModel(nn.Module):