blob: c0a2e1b13f265ff710d0f3fa0ca75e613331adea [file] [log] [blame]
import sys
import tempfile
import unittest
from copy import deepcopy
from itertools import product
import torch
import torch.cuda
from torch.autograd import Variable
from common import TestCase, to_gpu, freeze_rng_state
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, contiguous
import torch.backends.cudnn
# On Python 3.3 the tarfile module tries to obtain the name of the file object
# it is given, so a NamedTemporaryFile is used there; elsewhere a plain
# TemporaryFile is enough.
TemporaryFile = (tempfile.NamedTemporaryFile
                 if sys.version_info[:2] == (3, 3)
                 else tempfile.TemporaryFile)

# Availability flags used to decide whether GPU-dependent tests can run.
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.cuda.FloatTensor(1))
# cuDNN version number when cuDNN is usable, otherwise the falsy TEST_CUDNN value.
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()

# Largest absolute difference tolerated between analytical and numerical Jacobians.
PRECISION = 1e-5
def get_size_average(m):
    """Return the size-averaging flag of ``m``.

    The snake_case ``size_average`` attribute is consulted first. If it is
    missing or falsy, the result of looking up the camelCase ``sizeAverage``
    attribute (defaulting to False) is returned instead.
    """
    snake_case_flag = getattr(m, 'size_average', False)
    if snake_case_flag:
        return snake_case_flag
    return getattr(m, 'sizeAverage', False)
def get_weight(m):
    """Return ``m.weight`` if present and not None, else ``m.weights`` (or None)."""
    weight = getattr(m, 'weight', None)
    return weight if weight is not None else getattr(m, 'weights', None)
# Declarative table of module test cases; presumably each dict is turned into a
# ModuleTest by a test generator outside this view -- TODO confirm.
# Keys read by the classes in this file:
#   constructor_args -- positional constructor arguments (TestBase defaults it to ()).
#   input_size       -- shape passed to torch.randn to build the input
#                       (see TestBase._get_arg); input / input_fn are also accepted.
#   desc             -- suffix appended to the generated test name.
#   reference_fn     -- reference_fn(input, params) -> expected output, where params
#                       is test_case._get_parameters(module)[0] (ModuleTest.__call__).
#   test_cuda        -- False makes ModuleTest.test_cuda skip the case.
# Keys not read in this view (module_name, check_inplace, check_gradgrad) are
# presumably consumed by the test generator or its subclasses -- verify.
module_tests = [
    dict(
        module_name='Linear',
        constructor_args=(10, 8),
        input_size=(4, 10),
        # i @ p[0]^T + p[1], with p[1] broadcast over the batch of 4 rows.
        reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)
    ),
    dict(
        module_name='Linear',
        constructor_args=(10, 8, False),
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i, p: torch.mm(i, p[0].t())
    ),
    dict(
        module_name='Threshold',
        constructor_args=(2, 1),
        input_size=(2, 3, 4, 5),
        check_inplace=True,
        desc='threshold_value'
    ),
    dict(
        module_name='Threshold',
        constructor_args=(2, 10),
        input_size=(2, 3, 4, 5),
        desc='large_value'
    ),
    dict(
        module_name='ReLU',
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name='ReLU6',
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name='RReLU',
        input_size=(1, 2, 2),
        test_cuda=False,
    ),
    dict(
        module_name='RReLU',
        constructor_args=(0.1, 0.9),
        input_size=(4, 4, 5),
        desc='with_up_down',
        test_cuda=False,
    ),
    dict(
        module_name='Hardtanh',
        input_size=(3, 2, 5),
        reference_fn=lambda i, _: i.clamp(-1, 1),
    ),
    dict(
        module_name='Sigmoid',
        input_size=(2, 3, 4, 5)
    ),
    dict(
        module_name='Tanh',
        input_size=(2, 3, 4, 5)
    ),
    dict(
        module_name='Softmax',
        constructor_args=(1,),
        input_size=(10, 20),
        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
    ),
    dict(
        module_name='Softmax2d',
        input_size=(1, 3, 10, 20),
        # The sum over dim 1 drops that dim (keepdim=False); the (1, 10, 20) result
        # broadcasts against the (1, 3, 10, 20) input only because the batch size is 1.
        reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, False)),
    ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(1,),
        input_size=(10, 20),
        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
    ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(1,),
        input_size=(1, 3, 10, 20),
        # Same batch-size-1 broadcasting caveat as the Softmax2d reference above.
        reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
        desc='multiparam',
    ),
    dict(
        module_name='ELU',
        constructor_args=(2.,),
        input_size=(3, 2, 5),
    ),
    # TODO: reference function
    dict(
        module_name='Hardshrink',
        constructor_args=(2.,),
        input_size=(4, 3, 2, 4),
    ),
    dict(
        module_name='LeakyReLU',
        input_size=(3, 2, 5),
        check_inplace=True
    ),
    dict(
        module_name='LeakyReLU',
        constructor_args=(0.5,),
        input_size=(3, 2, 5),
        check_inplace=True,
        desc='with_negval'
    ),
    dict(
        module_name='LogSigmoid',
        input_size=(2, 3, 4),
        reference_fn=lambda i, _: i.sigmoid().log(),
    ),
    dict(
        module_name='Softplus',
        input_size=(10, 20),
        reference_fn=lambda i, _: torch.log(1 + torch.exp(i)),
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2,),
        input_size=(10, 20),
        reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
        desc='beta',
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2, -100),
        input_size=(10, 20),
        # Where 2 * i exceeds -100 the reference is the identity; elsewhere it is
        # (1 / 2) * log(1 + exp(2 * i)).
        reference_fn=(lambda i, _: ((i * 2) > -100).type_as(i) * i +
                      ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
        desc='beta_threshold',
    ),
    dict(
        module_name='Softshrink',
        input_size=(3, 2, 5),
    ),
    dict(
        module_name='Softshrink',
        constructor_args=(1,),
        input_size=(3, 2, 5),
        desc='lambda',
    ),
    dict(
        module_name='CrossMapLRN2d',
        constructor_args=(5, 5e-3, 1e-3, 2),
        input_size=(2, 3, 6, 6),
        check_gradgrad=False,
    ),
    # PReLU references: positive part plus negative part scaled by p[0][0].
    # The *_multiparam variants also use only p[0][0], which matches the module
    # only if all per-channel weights are equal (presumably true at initialization
    # -- TODO confirm).
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4),
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
        desc='1d',
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        input_size=(2, 3, 4),
        desc='1d_multiparam',
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4, 5),
        desc='2d',
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        input_size=(2, 3, 4, 5),
        desc='2d_multiparam',
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4, 5, 6),
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
        desc='3d',
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        input_size=(2, 3, 4, 5, 6),
        desc='3d_multiparam',
        reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='Softsign',
        input_size=(3, 2, 5),
        reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
    ),
    dict(
        module_name='Softmin',
        constructor_args=(1,),
        input_size=(10, 20),
    ),
    dict(
        module_name='Softmin',
        constructor_args=(1,),
        input_size=(2, 3, 5, 10),
        desc='multidim',
    ),
    dict(
        module_name='Tanhshrink',
        input_size=(2, 3, 4, 5)
    ),
]
def kldivloss_reference(input, target, size_average=True, reduce=True):
    """Reference KL-divergence loss: ``target * (log(target) - input)``.

    Entries with ``target <= 0`` contribute zero: their target is masked to 0
    and the log is taken of 1 instead of a non-positive number.

    Returns the per-element tensor when ``reduce`` is False, otherwise its
    mean (``size_average``) or sum.
    """
    positive_mask = (target > 0).type_as(target)
    safe_target = target * positive_mask
    # Add 1 where target <= 0 so log() sees 1 (giving 0) rather than 0 or a negative.
    non_positive_mask = (target <= 0).type_as(target)
    safe_target_log = (safe_target + non_positive_mask).log()
    pointwise = safe_target * (safe_target_log - input)

    if not reduce:
        return pointwise
    return pointwise.mean() if size_average else pointwise.sum()
def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
                        size_average=True, reduce=True):
    """Reference N-d negative log-likelihood loss (input must have >= 4 dims).

    ``input`` is indexed as (N, C, *spatial) and ``target`` as (N, *spatial).
    Each output position gets ``-input[n, target, ...] * weight[target]``.
    Positions whose target equals ``ignore_index`` get weight 0.

    Returns the per-position tensor when ``reduce`` is False. Otherwise it
    returns the weighted mean (sum divided by the total weight) when
    ``size_average`` is True, or the plain sum.
    """
    assert input.dim() >= 4
    batch_size = input.size(0)
    num_classes = input.size(1)
    out_size = (batch_size,) + input.size()[2:]
    output = torch.zeros(out_size).type_as(input)

    if isinstance(target, Variable):
        target = target.data
    if weight is None:
        weight = torch.ones(num_classes).type_as(input)

    total_weight_data = 0
    for position in product(*map(range, out_size)):
        target_class = target[position]
        norm = 0. if ignore_index == target_class else weight[target_class]
        # The class index goes into dim 1 of the input, between batch and spatial dims.
        input_index = position[:1] + (target_class,) + position[1:]
        output[position] = -input[input_index] * norm
        total_weight_data += norm

    if not reduce:
        return output
    if size_average:
        return output.sum() / total_weight_data
    return output.sum()
def nllloss_reference(input, target, weight=None, ignore_index=-100,
                      size_average=True, reduce=True):
    """Reference negative log-likelihood loss over rows of a 2-D ``input``.

    For each row ``i`` with target ``t``, the loss is ``-i[t] * weight[t]``
    (weight 1 when ``weight`` is None). Rows whose target equals
    ``ignore_index`` contribute loss 0 and weight 0.

    Returns the per-row losses when ``reduce`` is False. Otherwise it returns
    their weighted mean (sum of losses divided by sum of weights) when
    ``size_average`` is True, or their sum.
    """
    if isinstance(target, Variable):
        target = target.data

    def nll_loss_helper(input, target, weight, ignore_index):
        # Compare by value. The previous `is` identity check never matched
        # tensor elements, and failed for ints outside CPython's small-int cache
        # (e.g. the default -100), so "ignored" targets were still counted.
        if target == ignore_index:
            return (0, 0)
        norm = 1 if weight is None else weight[target]
        result = -input[target] * norm
        return (result, norm)

    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
                          for i, t in zip(input, target)]
    losses, weights = zip(*losses_and_weights)
    losses_tensor = torch.Tensor(losses).type_as(input)
    if reduce and size_average:
        return sum(losses_tensor) / sum(weights)
    elif reduce:
        return sum(losses_tensor)
    else:
        return losses_tensor
def smoothl1loss_reference(input, target, size_average=True, reduce=True):
    """Reference smooth-L1 (Huber) loss.

    Per element, with ``d = |input - target|``: ``0.5 * d**2`` when ``d < 1``,
    else ``d - 0.5``.

    Returns the per-element tensor when ``reduce`` is False, otherwise its
    mean (``size_average``) or sum.
    """
    distance = (input - target).abs()
    linear_part = (distance >= 1).type_as(distance) * (distance - 0.5)
    quadratic_part = (distance < 1).type_as(distance) * 0.5 * (distance ** 2)
    pointwise = linear_part + quadratic_part

    if not reduce:
        return pointwise
    return pointwise.mean() if size_average else pointwise.sum()
# Loss-module name -> pure-tensor reference implementation of that loss.
loss_reference_fns = dict(
    KLDivLoss=kldivloss_reference,
    NLLLoss=nllloss_reference,
    NLLLossNd=nlllossNd_reference,
    SmoothL1Loss=smoothl1loss_reference,
)
# Declarative table of criterion (loss) test cases; presumably each dict is
# turned into a CriterionTest by a test generator outside this view -- TODO confirm.
# Keys read by the classes in this file:
#   constructor_args / constructor_args_fn -- constructor arguments (value or factory).
#   input / input_fn / input_size          -- the input, a factory for it, or a shape.
#   target / target_fn / target_size       -- the same three options for the target.
#   desc                                   -- suffix appended to the generated test name.
#   reference_fn                           -- reference_fn(input, target, module) -> loss.
# Factories (*_fn) are called lazily, on first access in TestBase._get_arg.
# Keys not read in this view (module_name, check_no_size_average, check_gradgrad)
# are presumably consumed by the generator or its subclasses -- verify.
criterion_tests = [
    dict(
        module_name='L1Loss',
        input_size=(2, 3, 4),
        target_size=(2, 3, 4),
        reference_fn=lambda i, t, _: 1. / i.numel() *
            sum((a - b).abs().sum() for a, b in zip(i, t)),
    ),
    dict(
        module_name='NLLLoss',
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, size_average=get_size_average(m)),
        check_no_size_average=True
    ),
    dict(
        module_name='NLLLoss',
        constructor_args=(None, True, 2),
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
        desc='ignore_index'
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m)),
        desc='weights',
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), True, 2),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
        desc='weights_ignore_index'
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), True, -1),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        # Targets drawn from [-1, 9] so that some hit ignore_index=-1.
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
        desc='weights_ignore_index_neg'
    ),
    dict(
        module_name='KLDivLoss',
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10),
        reference_fn=lambda i, t, m:
            kldivloss_reference(i, t, get_size_average(m), reduce=True),
        check_no_size_average=True,
    ),
    dict(
        module_name='MSELoss',
        input_size=(2, 3, 4, 5),
        target_size=(2, 3, 4, 5),
        reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
        check_no_size_average=True,
    ),
    dict(
        module_name='BCELoss',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        check_gradgrad=False,
    ),
    dict(
        module_name='BCELoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        desc='weights',
        check_gradgrad=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name='CrossEntropyLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        desc='weights',
    ),
    dict(
        module_name='NLLLoss2d',
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            nlllossNd_reference(i, t, size_average=get_size_average(m)),
        check_no_size_average=True,
    ),
    dict(
        module_name='NLLLoss2d',
        constructor_args_fn=lambda: (torch.rand(3),),
        input_size=(2, 3, 5, 5),
        # Built lazily like every other entry. A plain `target=` value would be
        # created (and draw from the global RNG) at module import time.
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            nlllossNd_reference(i, t, weight=get_weight(m)),
        desc='weights',
    ),
    dict(
        module_name='NLLLoss2d',
        constructor_args=(None, True, 1),
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            nlllossNd_reference(i, t, ignore_index=1),
        desc='ignore_index',
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        constructor_args=(0.5,),
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        desc='margin',
        check_no_size_average=True,
    ),
    dict(
        module_name='MultiLabelMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
        check_no_size_average=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiLabelSoftMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiLabelSoftMarginLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        desc='weights',
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        check_gradgrad=False,
    ),
    dict(
        module_name='SmoothL1Loss',
        input_size=(5, 10),
        target_size=(5, 10),
        check_no_size_average=True,
        reference_fn=lambda i, t, m:
            smoothl1loss_reference(i, t, size_average=get_size_average(m)),
    ),
    dict(
        module_name='SoftMarginLoss',
        input_size=(5, 5),
        target_fn=lambda: torch.randn(5, 5).sign(),
        check_no_size_average=True,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        check_gradgrad=False,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        constructor_args=(0.7,),
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        desc='margin',
        check_gradgrad=False,
    ),
    dict(
        module_name='MarginRankingLoss',
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        check_no_size_average=True,
    ),
    dict(
        module_name='MarginRankingLoss',
        constructor_args=(2,),
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        desc='margin',
        check_no_size_average=True,
    ),
]
class NNTestCase(TestCase):
    """TestCase base class with Jacobian-checking helpers for modules and criteria.

    Subclasses must provide ``_forward``, ``_backward``, ``_get_parameters``,
    ``_zero_grad_parameters``, ``_forward_criterion`` and
    ``_backward_criterion``. They are called here but not defined in this class.
    """

    def _jacobian(self, input, num_out):
        """Allocate zero Jacobian buffer(s) mirroring the structure of ``input``.

        Tuples and lists are mapped element-wise. Each tensor leaf gets a
        ``(input.nelement(), num_out)`` zero tensor.
        """
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        """Flatten ``x`` (tensor, Variable, or iterable of them) to 1-D tensor(s).

        Sparse tensors are densified first. Iterables produce a tuple of
        flattened tensors.
        """
        if torch.is_tensor(x):
            if x.is_sparse:
                return x.to_dense().view(-1)
            else:
                return x.view(-1)
        elif isinstance(x, Variable):
            return self._flatten_tensors(x.data)
        else:
            return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        """Zero and detach the accumulated ``.grad`` of each Variable in ``input``.

        Plain tensors are left alone. Other iterables are walked recursively.
        """
        if isinstance(input, Variable):
            if input.requires_grad and input.grad is not None:
                input.grad.data.zero_()
                input.grad.detach_()
        elif torch.is_tensor(input):
            return
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
        """Build Jacobians column by column from backward passes.

        For each output element ``i``, a one-hot ``grad_output`` is
        back-propagated. The resulting input gradients (and, if requested,
        the flattened parameter gradients) fill column ``i``.

        Returns:
            A tuple holding the input Jacobian(s) (shaped like ``_jacobian``)
            if ``jacobian_input``, followed by the ``(num_params, output_size)``
            parameter Jacobian if ``jacobian_parameters``.
        """
        output = self._forward(module, input)
        output_size = output.nelement()
        output_t = output.data if isinstance(output, Variable) else output

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            _, d_param = self._get_parameters(module)
            d_out = torch.zeros_like(output_t)
            flat_d_out = d_out.view(-1)
            flat_d_out[i] = 1

            if jacobian_parameters:
                self._zero_grad_parameters(module)
            # Variables will accumulate gradient from multiple steps
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)

            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x
            if jacobian_parameters:
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res = tuple()
        if jacobian_input:
            res += jacobian_inp,
        if jacobian_parameters:
            res += jacobian_param,

        return res

    def _numerical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
        """Finite-difference counterpart of ``_analytical_jacobian`` (eps=1e-6).

        Returns a tuple laid out the same way: the input Jacobian if
        requested, then the parameter Jacobians concatenated along dim 0.
        """
        def fw(input):
            out = self._forward(module, input)
            if isinstance(out, Variable):
                return out.data
            return out

        res = tuple()
        input = contiguous(input)
        if jacobian_input:
            res += get_numerical_jacobian(fw, input, input, eps=1e-6),
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            res += torch.cat(list(get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param), 0),
        return res

    def check_jacobian(self, module, input, jacobian_input=True):
        """Assert analytical and numerical Jacobians agree to within PRECISION.

        The parameter Jacobian is only checked when the module has parameters.
        """
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
        analytical_t = iter_tensors(analytical)
        numerical_t = iter_tensors(numerical)
        # TODO: compare structure
        self.assertLessEqual(
            max((a - n).abs().max() for a, n in zip(analytical_t, numerical_t)),
            PRECISION
        )

    def check_criterion_jacobian(self, criterion, input, target):
        """Assert a criterion's input gradient matches central differences.

        Every input element is perturbed by +/-eps in place. The criterion is
        re-evaluated, and the element is restored afterwards.
        """
        eps = 1e-6
        self._forward_criterion(criterion, input, target)
        analytical_d_x = self._backward_criterion(criterion, input, target)
        numerical_d_x = deepcopy(analytical_d_x)

        input_t = iter_tensors(input)
        numerical_t = iter_tensors(numerical_d_x)
        for x, d_x in zip(input_t, numerical_t):
            x = x.view(-1)
            d_x = d_x.view(-1)
            for i in range(x.nelement()):
                # Snapshot the value as a Python number. Where x[i] returns a
                # 0-dim view into x, that view would change with the "+eps"
                # write below. The "-eps" probe would then land back on the
                # unperturbed point, roughly halving the estimated derivative.
                original = float(x[i])
                x[i] = original + eps
                fx1 = self._forward_criterion(criterion, input, target)
                x[i] = original - eps
                fx2 = self._forward_criterion(criterion, input, target)
                deriv = (fx1 - fx2) / (2. * eps)
                d_x[i] = deriv
                x[i] = original

        # TODO: check structure
        analytical_t = iter_tensors(analytical_d_x)
        numerical_t = iter_tensors(numerical_d_x)
        self.assertLessEqual(
            max((a - n).abs().max() for a, n in zip(analytical_t, numerical_t)),
            PRECISION
        )
class TestBase(object):
    """Lazily materialized description of one generated test.

    Each name in ``_required_arg_names`` can be supplied in one of three forms:

    * ``<name>``: the value itself.
    * ``<name>_fn``: a zero-argument callable that produces it.
    * ``<name>_size``: a shape (or list of shapes) passed to ``torch.randn``;
      a tensor given here is converted with ``.double()``.

    ``constructor_args`` defaults to ``()``. Any other missing required name
    raises ``ValueError``. Values are built on first access and cached.
    """

    _required_arg_names = {'constructor_args', 'input'}

    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
                if name == 'constructor_args':
                    kwargs['constructor_args'] = tuple()
                else:
                    raise ValueError("{}: Specify {} by a value, a function to generate it, or its size!"
                                     .format(self.get_name(), name))
        self._extra_kwargs = kwargs
        self._arg_cache = {}

    def get_name(self):
        """Return ``test_<fullname>``, or ``test_<constructor name>[_<desc>]``."""
        if self.fullname is not None:
            return 'test_' + self.fullname

        test_name = 'test_' + self.constructor.__name__
        if self.desc:
            test_name += '_' + self.desc
        return test_name

    def _unpack(self, value):
        """Replace Variables by their ``.data``, recursing into containers.

        Containers are rebuilt with their own type.
        """
        if isinstance(value, Variable):
            return value.data
        elif torch.is_tensor(value):
            return value
        else:
            return type(value)(self._unpack(v) for v in value)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args')

    def _get_arg(self, name):
        """Return (and cache) the value of required argument ``name``.

        Precedence: the direct value, then the ``_fn`` factory, then the
        ``_size`` shape(s).
        """
        assert name in self._required_arg_names

        if name not in self._arg_cache:
            fn_name = name + '_fn'
            size_name = name + '_size'

            if name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[name]
            elif fn_name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[fn_name]()
            else:
                assert size_name in self._extra_kwargs

                def map_tensor_sizes(sizes):
                    # A list means several inputs; a tensor is used as-is (as
                    # double); anything else is a shape for torch.randn.
                    if isinstance(sizes, list):
                        return [map_tensor_sizes(s) for s in sizes]
                    elif torch.is_tensor(sizes):
                        return sizes.double()
                    else:
                        return torch.randn(*sizes)

                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
        return self._arg_cache[name]

    def _get_input(self):
        return self._get_arg('input')

    def __call__(self, test_case):
        raise NotImplementedError
class ModuleTest(TestBase):
    """Generated test for an nn module.

    ``__call__`` runs these checks in order:

    1. Compares the forward output against ``reference_fn``, when one is given.
    2. Exercises non-contiguous inputs.
    3. Round-trips the module through ``torch.save``/``torch.load``, unless
       ``pickle=False``.
    4. Calls ``self._do_test``. That method is not defined in this class;
       presumably a subclass provides it -- TODO confirm.
    """

    def __init__(self, *args, **kwargs):
        super(ModuleTest, self).__init__(*args, **kwargs)
        # Optional switches; TestBase also keeps them in self._extra_kwargs.
        # jacobian_input is not read in this class (presumably used by _do_test).
        self.jacobian_input = kwargs.get('jacobian_input', True)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.should_test_pickle = kwargs.get('pickle', True)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            if isinstance(out, Variable):
                out = out.data
            # The reference receives an unwrapped deep copy so it cannot mutate `input`.
            ref_input = self._unpack(deepcopy(input))
            expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0])
            test_case.assertEqual(out, expected_out)
        self.test_noncontig(test_case, module, input)

        if self.should_test_pickle:
            # TODO: do this with in-memory files as soon as torch.save will support it
            with TemporaryFile() as f:
                test_case._forward(module, input)
                torch.save(module, f)
                f.seek(0)
                module_copy = torch.load(f)
                # The reloaded module must produce the same output as the original.
                test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))

        self._do_test(test_case, module, input)

    def noncontiguize(self, obj):
        """Return a non-contiguous copy of ``obj`` (tensor, Variable, or list of them).

        The tensor is stacked with a zeroed clone along a new trailing
        dimension, and index 1 of that dimension is selected. The result holds
        the same values but is not contiguous (asserted unless it has a single
        element). Variables are re-wrapped with the same ``requires_grad``.
        """
        if isinstance(obj, list):
            return [self.noncontiguize(o) for o in obj]
        tensor = obj.data if isinstance(obj, Variable) else obj
        ndim = tensor.dim()
        noncontig = torch.stack([tensor.clone().zero_(), tensor], ndim).select(ndim, 1)
        assert noncontig.numel() == 1 or not noncontig.is_contiguous()
        if isinstance(obj, Variable):
            return Variable(noncontig, requires_grad=obj.requires_grad)
        return noncontig

    def test_noncontig(self, test_case, module, input):
        """Check that results do not depend on input / grad_output contiguity.

        A baseline forward/backward is computed on the given data under
        ``freeze_rng_state``. Every combination of contiguous and
        non-contiguous input and grad_output is then replayed under a frozen
        RNG state and compared with the baseline: outputs, input gradients
        (tolerance 1e-4) and parameter gradients. Combinations whose forward
        raises are skipped.
        """
        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            grad_output = output
            if isinstance(grad_output, Variable):
                grad_output = grad_output.data.clone()
            else:
                grad_output = grad_output.clone()
                output = output.clone()
            grad_output.normal_()
            # Deep copies keep the baseline intact when later passes write into the same buffers.
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            go = grad_output if contig_g else nc_grad_output
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                try:
                    out = test_case._forward(module, i)
                except Exception:
                    # Some modules will fail because of non contiguous inputs and we're ok with that
                    continue
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, 1e-4)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)

    def test_cuda(self, test_case):
        """Compare CPU and CUDA forward/backward results (tolerance 2e-4).

        Skipped when CUDA is unavailable or ``test_cuda=False``. DoubleTensor
        inputs are converted to ``torch.cuda.FloatTensor``, and the GPU module
        is built with ``.float().cuda()`` and then given the CPU parameter
        values. NotImplementedError, and the missing CUDA ``scatter_``
        AttributeError, are tolerated.
        """
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')
        try:
            cpu_input = self._get_input()
            type_map = {torch.DoubleTensor: torch.cuda.FloatTensor}
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()
            cpu_param = test_case._get_parameters(cpu_module)
            gpu_param = test_case._get_parameters(gpu_module)
            # Copy the CPU parameter values into the GPU module so both start identical.
            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
                if isinstance(cpu_p, Variable):
                    cpu_p = cpu_p.data
                if isinstance(gpu_p, Variable):
                    gpu_p = gpu_p.data
                gpu_p.copy_(cpu_p)

            test_case._zero_grad_input(cpu_input)
            test_case._zero_grad_input(gpu_input)
            test_case._zero_grad_parameters(cpu_module)
            test_case._zero_grad_parameters(gpu_module)
            cpu_output = test_case._forward(cpu_module, cpu_input)
            gpu_output = test_case._forward(gpu_module, gpu_input)
            test_case.assertEqual(cpu_output, gpu_output, 2e-4)

            # Five backward passes, each with a fresh random 0/1 grad_output shared by CPU and GPU.
            for i in range(5):
                cpu_output_t = cpu_output.data if isinstance(cpu_output, Variable) else cpu_output
                cpu_gradOutput = cpu_output_t.clone().bernoulli_()
                gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
                cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)
                gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput)
                test_case.assertEqual(cpu_gradInput, gpu_gradInput, 2e-4)
                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, 2e-4)

            self.test_noncontig(test_case, gpu_module, gpu_input)
        except NotImplementedError:
            pass
        # TODO: remove this after CUDA scatter_ is implemented
        except AttributeError as e:
            if len(e.args) == 1 and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]:
                pass
            else:
                raise
class CriterionTest(TestBase):
    """Generated test for a loss (criterion) module.

    Besides TestBase's arguments, a ``target`` is required: a value, a
    ``target_fn`` factory, or a ``target_size`` shape. ``__call__`` checks the
    loss against ``reference_fn`` (if any), runs the numerical gradient check,
    and then calls ``_do_extra_tests``. ``test_cuda`` compares CPU and GPU
    results.
    """

    _required_arg_names = TestBase._required_arg_names | {'target'}

    def __init__(self, *args, **kwargs):
        super(CriterionTest, self).__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get('test_cuda', True)

    def _get_target(self):
        """Return the (cached) target, building it on first access."""
        return self._get_arg('target')

    def __call__(self, test_case):
        criterion = self.constructor(*self.constructor_args)
        inp = self._get_input()

        # Both string conversions must succeed on a freshly built criterion.
        criterion.__repr__()
        str(criterion)

        tgt = self._get_target()

        if self.reference_fn is not None:
            actual = test_case._forward_criterion(criterion, inp, tgt)
            # The reference works on unwrapped copies so it cannot touch the originals.
            ref_inp = deepcopy(self._unpack(inp))
            ref_tgt = deepcopy(self._unpack(tgt))
            test_case.assertEqual(actual, self.reference_fn(ref_inp, ref_tgt, criterion))

        test_case.check_criterion_jacobian(criterion, inp, tgt)
        self._do_extra_tests(test_case, criterion, inp, tgt)

    def test_cuda(self, test_case):
        """Compare CPU and CUDA loss values and input gradients (tolerance 4e-4).

        Skipped when CUDA is unavailable or ``test_cuda=False``.
        DoubleTensor data is converted to ``torch.cuda.FloatTensor`` for the
        GPU run. NotImplementedError is tolerated.
        """
        if not (TEST_CUDA and self.should_test_cuda):
            raise unittest.SkipTest('Excluded from CUDA tests')
        try:
            cpu_input = self._get_input()
            type_map = {torch.DoubleTensor: torch.cuda.FloatTensor}
            gpu_input = to_gpu(cpu_input, type_map=type_map)
            cpu_target = self._get_target()
            gpu_target = to_gpu(cpu_target, type_map=type_map)

            cpu_args = (self.constructor(*self.constructor_args), cpu_input, cpu_target)
            gpu_args = (self.constructor(*self.constructor_args).float().cuda(), gpu_input, gpu_target)

            # Forward values first, then input gradients; CPU runs before GPU each time.
            for run in (test_case._forward_criterion, test_case._backward_criterion):
                test_case.assertEqual(run(*cpu_args), run(*gpu_args), 4e-4)
        except NotImplementedError:
            pass

    def _do_extra_tests(self, test_case, module, input, target):
        """Hook for subclasses to add checks; does nothing by default."""
        pass