[JIT] End-to-end example-based robustness testing for hybrid frontend (#8451)
* End-to-end example-based robustness testing for hybrid frontend
* delete this
diff --git a/test/test_jit.py b/test/test_jit.py
index 9c3b684..5cb2da2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -133,6 +133,90 @@
trace.set_graph(graph)
return graph
+ def checkTrace(self, func, reference_tensors, input_tensors=None,
+ optimize=True, drop=None, allow_unused=False,
+ verbose=False, inputs_require_grads=True):
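+ """Check that tracing/compiling `func` preserves behavior.
+
+ Runs `func` both eagerly and through a GraphExecutor on the given
+ tensors, comparing forward outputs, first-order gradients, and
+ second-order (double backward) gradients.
+ """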
+ # TODO: check gradients for parameters, not just inputs
+ def allSum(vs):
+ # `drop` excludes the last `drop` values from the loss so we can
+ # test graphs with unused outputs
+ if drop is not None:
+ vs = vs[:-drop]
+ # we don't want the gradient of every output to be the same,
+ # so we scale each output's sum by a different constant
+ return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
+ if input_tensors is None:
+ input_tensors = reference_tensors
+
+ nograd_inputs = reference_tensors
+ if inputs_require_grads:
+ recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+ else:
+ recording_inputs = reference_tensors
+
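+ # Build the executor: a raw IR graph is wrapped in a GraphExecutor
+ # directly; a Python callable is traced with example inputs first.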
+ if isinstance(func, torch._C.Graph):
+ ge = torch._C.GraphExecutor(func, optimize)
+ else:
+ ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)
+
+ if verbose:
+ print(ge.__getattr__('forward').graph)
+
+ # test the no-gradient case
+
+ outputs = func(*nograd_inputs)
+ outputs_ge = ge(*nograd_inputs)
+ self.assertEqual(outputs, outputs_ge)
+
+ # test the single-grad case
+ outputs = func(*recording_inputs)
+ if inputs_require_grads:
+ grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+ allow_unused=allow_unused)
+
+ outputs_ge = ge(*recording_inputs)
+ if inputs_require_grads:
+ grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+ allow_unused=allow_unused)
+ self.assertEqual(outputs, outputs_ge)
+ if inputs_require_grads:
+ self.assertEqual(grads, grads_ge)
+
+ # test the grad-of-grad (double backward) case
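+ # allSum(grads) * l1 builds a scalar that depends on the first-order
+ # gradients, so differentiating it exercises double backward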
+
+ outputs = func(*recording_inputs)
+ l1 = allSum(outputs)
+ if inputs_require_grads:
+ grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+ allow_unused=allow_unused)
+ if inputs_require_grads:
+ l2 = (allSum(grads) * l1)
+ grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+
+ if inputs_require_grads:
+ recording_inputs = [Variable(t, requires_grad=True)
+ for t in reference_tensors]
+
+ outputs_ge = ge(*recording_inputs)
+ l1_ge = allSum(outputs_ge)
+ if inputs_require_grads:
+ grads_ge = torch.autograd.grad(
+ l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+
+ if inputs_require_grads:
+ l2_ge = (allSum(grads_ge) * l1_ge)
+ grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+
+ self.assertEqual(outputs, outputs_ge)
+ if inputs_require_grads:
+ self.assertEqual(grads, grads_ge)
+ for g2, g2_ge in zip(grads2, grads2_ge):
+ if g2 is None and g2_ge is None:
+ continue
+ self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
+
+ return ge
+
class TestJit(JitTestCase):
def assertExportImport(self, trace, inputs):
@@ -725,73 +809,6 @@
self.assertExpectedGraph(trace)
self.assertExportImport(trace, (x,))
- def checkTrace(self, func, reference_tensors, input_tensors=None,
- optimize=True, drop=None, allow_unused=False):
- def allSum(vs):
- # drop allows us to remove some values from ever being used
- # to test unused outputs
- if drop is not None:
- vs = vs[:-drop]
- # we don't want all the grad for all the outputs to be the same
- # so we multiply each by a constant
- return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
- if input_tensors is None:
- input_tensors = reference_tensors
-
- nograd_inputs = reference_tensors
- recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
-
- if isinstance(func, torch._C.Graph):
- ge = torch._C.GraphExecutor(func, optimize)
- else:
- ge = torch.jit.trace(*input_tensors, optimize=optimize)(func)
-
- # test no gradients case
-
- outputs = func(*nograd_inputs)
- outputs_ge = ge(*nograd_inputs)
- self.assertEqual(outputs, outputs_ge)
-
- # test single grad case
-
- outputs = func(*recording_inputs)
- grads = torch.autograd.grad(allSum(outputs), recording_inputs,
- allow_unused=allow_unused)
-
- outputs_ge = ge(*recording_inputs)
- grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
- allow_unused=allow_unused)
- self.assertEqual(outputs, outputs_ge)
- self.assertEqual(grads, grads_ge)
-
- # test the grad grad case
-
- outputs = func(*recording_inputs)
- l1 = allSum(outputs)
- grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
- allow_unused=allow_unused)
- l2 = (allSum(grads) * l1)
- grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
-
- recording_inputs = [Variable(t, requires_grad=True)
- for t in reference_tensors]
-
- outputs_ge = ge(*recording_inputs)
- l1_ge = allSum(outputs_ge)
- grads_ge = torch.autograd.grad(
- l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
- l2_ge = (allSum(grads_ge) * l1_ge)
- grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
-
- self.assertEqual(outputs, outputs_ge)
- self.assertEqual(grads, grads_ge)
- for g2, g2_ge in zip(grads2, grads2_ge):
- if g2 is None and g2_ge is None:
- continue
- self.assertTrue(torch.allclose(g2, g2_ge, atol=5e-4, rtol=1e-4))
-
- return ge
-
def run_ge_tests(self, optimize, use_cuda):
def rand(*args):
t = torch.rand(*args).float()
@@ -3062,6 +3079,392 @@
self.checkScript(fn, (torch.tensor(2),))
+class TestEndToEndHybridFrontendModels(JitTestCase):
+
+ def test_dcgan_models(self):
+ class DCGANGenerator(nn.Module):
+ def __init__(self, nz, ngf, nc):
+ super(DCGANGenerator, self).__init__()
+ self.main = nn.Sequential(
+ # input is Z, going into a convolution
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+ nn.BatchNorm2d(ngf * 8),
+ nn.ReLU(True),
+ # state size. (ngf*8) x 4 x 4
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf * 4),
+ nn.ReLU(True),
+ # state size. (ngf*4) x 8 x 8
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf * 2),
+ nn.ReLU(True),
+ # state size. (ngf*2) x 16 x 16
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf),
+ nn.ReLU(True),
+ # state size. (ngf) x 32 x 32
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+ nn.Tanh()
+ # state size. (nc) x 64 x 64
+ )
+
+ def forward(self, input):
+ return self.main(input)
+
+ class DCGANDiscriminator(nn.Module):
+ def __init__(self, nc, ndf):
+ super(DCGANDiscriminator, self).__init__()
+ self.main = nn.Sequential(
+ # input is (nc) x 64 x 64
+ nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf) x 32 x 32
+ nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 2),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*2) x 16 x 16
+ nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 4),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*4) x 8 x 8
+ nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 8),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*8) x 4 x 4
+ nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, input):
+ return self.main(input).view(-1, 1).squeeze(1)
+
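+ # Small, mutually distinct sizes; presumably chosen to exercise the
+ # tracer while keeping the smoke test fast.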
+ bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
+ self.checkTrace(DCGANGenerator(nz, ngf, nc), (torch.rand(bs, nz, 1, 1),))
+ example_input = DCGANGenerator(nz, ngf, nc)(torch.rand(bs, nz, 1, 1))
+ self.checkTrace(DCGANDiscriminator(nc, ndf), (example_input,))
+
+ @unittest.skip('https://github.com/pytorch/pytorch/issues/8439 InstanceNormalization bug')
+ def test_neural_style(self):
+ class TransformerNet(torch.nn.Module):
+ def __init__(self):
+ super(TransformerNet, self).__init__()
+ # Initial convolution layers
+ self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
+ self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
+ self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
+ self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
+ self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
+ self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
+ # Residual layers
+ self.res1 = ResidualBlock(128)
+ self.res2 = ResidualBlock(128)
+ self.res3 = ResidualBlock(128)
+ self.res4 = ResidualBlock(128)
+ self.res5 = ResidualBlock(128)
+ # Upsampling Layers
+ self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
+ self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
+ self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
+ self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
+ self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
+ # Non-linearities
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, X):
+ y = self.relu(self.in1(self.conv1(X)))
+ y = self.relu(self.in2(self.conv2(y)))
+ y = self.relu(self.in3(self.conv3(y)))
+ y = self.res1(y)
+ y = self.res2(y)
+ y = self.res3(y)
+ y = self.res4(y)
+ y = self.res5(y)
+ y = self.relu(self.in4(self.deconv1(y)))
+ y = self.relu(self.in5(self.deconv2(y)))
+ y = self.deconv3(y)
+ return y
+
+ class ConvLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
+ super(ConvLayer, self).__init__()
+ reflection_padding = kernel_size // 2
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+ def forward(self, x):
+ out = self.reflection_pad(x)
+ out = self.conv2d(out)
+ return out
+
+ class ResidualBlock(torch.nn.Module):
+ """ResidualBlock
+ introduced in: https://arxiv.org/abs/1512.03385
+ recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
+ """
+
+ def __init__(self, channels):
+ super(ResidualBlock, self).__init__()
+ self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+ self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
+ self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+ self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ residual = x
+ out = self.relu(self.in1(self.conv1(x)))
+ out = self.in2(self.conv2(out))
+ out = out + residual
+ return out
+
+ class UpsampleConvLayer(torch.nn.Module):
+ """UpsampleConvLayer
+ Upsamples the input and then applies a convolution. This gives
+ better results (fewer checkerboard artifacts) than ConvTranspose2d.
+ ref: http://distill.pub/2016/deconv-checkerboard/
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+ super(UpsampleConvLayer, self).__init__()
+ self.upsample = upsample
+ if upsample:
+ self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
+ reflection_padding = kernel_size // 2
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+
+ def forward(self, x):
+ x_in = x
+ if self.upsample:
+ x_in = self.upsample_layer(x_in)
+ out = self.reflection_pad(x_in)
+ out = self.conv2d(out)
+ return out
+
+ self.checkTrace(TransformerNet(), (torch.rand(5, 3, 224, 224),))
+
+ def test_mnist(self):
+ class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+ self.conv2_drop = nn.Dropout2d()
+ self.fc1 = nn.Linear(320, 50)
+ self.fc2 = nn.Linear(50, 10)
+
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.fc1(x))
+ x = F.dropout(x, training=self.training)
+ x = self.fc2(x)
+ return F.log_softmax(x, dim=1)
+
+ # FIXME: eval() works around the issue described in
+ # https://github.com/pytorch/pytorch/issues/8448
+ self.checkTrace(Net().eval(), (torch.rand(5, 1, 28, 28),))
+
+ def test_reinforcement_learning(self):
+ class Policy(nn.Module):
+ def __init__(self):
+ super(Policy, self).__init__()
+ self.affine1 = nn.Linear(4, 128)
+ self.affine2 = nn.Linear(128, 2)
+
+ def forward(self, x):
+ x = F.relu(self.affine1(x))
+ action_scores = self.affine2(x)
+ return F.softmax(action_scores, dim=1)
+
+ self.checkTrace(Policy(), (torch.rand(1, 4),))
+
+ def test_snli(self):
+ # TODO:
+ # 1) nn.LSTM is called as a Python function https://github.com/pytorch/pytorch/issues/8449
+ # 2) Dropout is called as a Python function https://github.com/pytorch/pytorch/issues/8450
+ class Bottle(nn.Module):
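+ """Mixin that lets a 2-D layer accept 3-D input: the leading
+ (seq_len, batch) dims are flattened before the wrapped layer's
+ forward and restored afterwards."""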
+
+ def forward(self, input):
+ if len(input.size()) <= 2:
+ return super(Bottle, self).forward(input)
+ size = input.size()[:2]
+ out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
+ return out.view(size[0], size[1], -1)
+
+ class Linear(Bottle, nn.Linear):
+ pass
+
+ class Encoder(nn.Module):
+
+ def __init__(self, config):
+ super(Encoder, self).__init__()
+ self.config = config
+ input_size = config.d_proj if config.projection else config.d_embed
+ dropout = 0 if config.n_layers == 1 else config.dp_ratio
+ self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
+ num_layers=config.n_layers, dropout=dropout,
+ bidirectional=config.birnn)
+
+ def forward(self, inputs):
+ batch_size = inputs.size()[1]
+ state_shape = self.config.n_cells, batch_size, self.config.d_hidden
+ h0 = c0 = inputs.new_zeros(state_shape)
+ outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
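+ # bidirectional: concatenate the final forward and backward hidden
+ # states for each batch element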
+ return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
+
+ class SNLIClassifier(nn.Module):
+
+ def __init__(self, config):
+ super(SNLIClassifier, self).__init__()
+ self.config = config
+ self.embed = nn.Embedding(config.n_embed, config.d_embed)
+ self.projection = Linear(config.d_embed, config.d_proj)
+ self.encoder = Encoder(config)
+ self.dropout = nn.Dropout(p=config.dp_ratio)
+ self.relu = nn.ReLU()
+ seq_in_size = 2 * config.d_hidden
+ if self.config.birnn:
+ seq_in_size *= 2
+ lin_config = [seq_in_size] * 2
+ self.out = nn.Sequential(
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(seq_in_size, config.d_out))
+
+ def forward(self, premise, hypothesis):
+ prem_embed = self.embed(premise)
+ hypo_embed = self.embed(hypothesis)
+ if self.config.fix_emb:
+ prem_embed = prem_embed.detach()
+ hypo_embed = hypo_embed.detach()
+ if self.config.projection:
+ prem_embed = self.relu(self.projection(prem_embed))
+ hypo_embed = self.relu(self.projection(hypo_embed))
+ premise = self.encoder(prem_embed)
+ hypothesis = self.encoder(hypo_embed)
+ scores = self.out(torch.cat([premise, hypothesis], 1))
+ return scores
+
+ class Config:
+ n_embed = 100
+ d_embed = 100
+ d_proj = 300
+ dp_ratio = 0.0 # dropout disabled for deterministic testing; TODO: fix the RNG seed in checkTrace instead?
+ d_hidden = 300
+ birnn = True
+ d_out = 300
+ fix_emb = True
+ projection = True
+ n_layers = 2
+ n_cells = 4 # 2 * n_layers because birnn = True
+
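+ # token-id inputs are integer tensors and cannot require grad, hence
+ # inputs_require_grads=False below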
+ premise = torch.LongTensor(48, 128).random_(0, 100)
+ hypothesis = torch.LongTensor(24, 128).random_(0, 100)
+
+ self.checkTrace(SNLIClassifier(Config()), (premise, hypothesis), inputs_require_grads=False)
+
+ def test_super_resolution(self):
+ import torch.nn.init as init
+
+ class Net(nn.Module):
+
+ def __init__(self, upscale_factor):
+ super(Net, self).__init__()
+
+ self.relu = nn.ReLU()
+ self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
+ self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
+ self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
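+ # conv4 emits upscale_factor ** 2 channels, which PixelShuffle
+ # rearranges into a single-channel output upscaled by that factor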
+ self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
+ self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ def forward(self, x):
+ x = self.relu(self.conv1(x))
+ x = self.relu(self.conv2(x))
+ x = self.relu(self.conv3(x))
+ x = self.pixel_shuffle(self.conv4(x))
+ return x
+
+ net = Net(upscale_factor=4)
+ self.checkTrace(net, (torch.rand(5, 1, 64, 64),))
+
+ @unittest.skip('This needs to be re-written into a script module')
+ def test_time_sequence_prediction(self):
+ class Sequence(nn.Module):
+ def __init__(self):
+ super(Sequence, self).__init__()
+ self.lstm1 = nn.LSTMCell(1, 51)
+ self.lstm2 = nn.LSTMCell(51, 51)
+ self.linear = nn.Linear(51, 1)
+
+ def forward(self, input, future=0):
+ outputs = []
+ h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
+ c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
+ h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
+ c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
+
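+ # unroll the input one timestep at a time through the stacked cells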
+ for _, input_t in enumerate(input.chunk(input.size(1), dim=1)):
+ h_t, c_t = self.lstm1(input_t, (h_t, c_t))
+ h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
+ output = self.linear(h_t2)
+ outputs += [output]
+ for _ in range(future): # if we should predict the future
+ h_t, c_t = self.lstm1(output, (h_t, c_t))
+ h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
+ output = self.linear(h_t2)
+ outputs += [output]
+ outputs = torch.stack(outputs, 1).squeeze(2)
+ return outputs
+
+ self.checkTrace(Sequence(), (torch.rand(97, 999),), verbose=True)
+
+ def test_vae(self):
+ class VAE(nn.Module):
+ def __init__(self):
+ super(VAE, self).__init__()
+
+ self.fc1 = nn.Linear(784, 400)
+ self.fc21 = nn.Linear(400, 20)
+ self.fc22 = nn.Linear(400, 20)
+ self.fc3 = nn.Linear(20, 400)
+ self.fc4 = nn.Linear(400, 784)
+
+ def encode(self, x):
+ h1 = F.relu(self.fc1(x))
+ return self.fc21(h1), self.fc22(h1)
+
+ def reparameterize(self, mu, logvar):
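+ # reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
+ # which keeps sampling differentiable w.r.t. mu and logvar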
+ if self.training:
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return eps.mul(std).add_(mu)
+ else:
+ return mu
+
+ def decode(self, z):
+ h3 = F.relu(self.fc3(z))
+ return F.sigmoid(self.fc4(h3))
+
+ def forward(self, x):
+ mu, logvar = self.encode(x.view(-1, 784))
+ z = self.reparameterize(mu, logvar)
+ return self.decode(z), mu, logvar
+
+ # FIXME: this fails under training because of the call to `randn_like`
+ # https://github.com/pytorch/pytorch/issues/8443
+ self.checkTrace(VAE().eval(), (torch.rand(128, 1, 28, 28),))
+
+
# Smoke tests for export methods
class TestPytorchExportModes(JitTestCase):
class MyModel(nn.Module):