# blob: 47c3bc2e39ba4707365d10e046913c6fac0e81cc [file] [log] [blame]
import torch
import torch.nn as nn
class DummyNet(nn.Module):
    """Small activation / normalization / pooling stack with a flattened output.

    The ``features`` pipeline applies, in order, ``LeakyReLU`` (negative slope
    0.02), ``BatchNorm2d`` over 3 channels and a 3x3 ``AvgPool2d`` with
    stride 2 and padding 1.

    Args:
        num_classes: Accepted by the constructor but not used anywhere in
            this module.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Layers are handed to nn.Sequential positionally so the submodules
        # stay registered under the indices "0", "1" and "2".
        stages = (
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(3),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        )
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through ``features`` and return the result as a 1-D tensor."""
        pooled = self.features(x)
        # Reshape to a single column, then drop that column dimension.
        as_column = pooled.view(-1, 1)
        return as_column.squeeze(1)
class ConcatNet(nn.Module):
    """Parameter-free module that concatenates its inputs along dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return the tensors in ``inputs`` joined along dimension 1.

        Args:
            inputs: Sequence of tensors, forwarded unchanged to ``torch.cat``.
        """
        joined = torch.cat(inputs, dim=1)
        return joined
class PermuteNet(nn.Module):
    """Parameter-free module that reorders the dimensions of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` with its dimensions reordered to ``(2, 3, 0, 1)``.

        The permutation names four dimensions, so ``input`` must be 4-D.
        """
        axis_order = (2, 3, 0, 1)
        return input.permute(*axis_order)
class PReluNet(nn.Module):
    """Module wrapping a single ``PReLU`` layer with 3 learnable slopes."""

    def __init__(self):
        super().__init__()
        # Kept inside nn.Sequential so the layer stays registered as
        # "features.0".
        activation = nn.PReLU(num_parameters=3)
        self.features = nn.Sequential(activation)

    def forward(self, x):
        """Apply the ``PReLU`` pipeline to ``x`` and return the result."""
        return self.features(x)