update nn loss tests to use new reduction arg (#9118)
Summary:
The tests were still passing the deprecated `size_average` and `reduce` args, which caused them to emit a large number of deprecation warnings.
closes #9103.
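
For reference, a minimal sketch (not part of this diff) of the new-style calls the tests are ported to; the old-to-new mapping follows `_Reduction.legacy_get_string`:

```python
import torch
import torch.nn.functional as F

# Rough mapping from the deprecated flags to the new `reduction` strings:
#   size_average=True,  reduce=True  -> reduction='elementwise_mean'  (default)
#   size_average=False, reduce=True  -> reduction='sum'
#   reduce=False                     -> reduction='none'
x = torch.randn(5, 10, requires_grad=True)
y = torch.randn(5, 10)

mean_loss = F.mse_loss(x, y, reduction='elementwise_mean')  # scalar mean over all elements
sum_loss = F.mse_loss(x, y, reduction='sum')                # scalar sum over all elements
per_elem = F.mse_loss(x, y, reduction='none')               # unreduced, same shape as input
```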
Reviewed By: ezyang
Differential Revision: D8720581
Pulled By: li-roy
fbshipit-source-id: 3b79527f6fe862fb48b99a6394e8d7b89fc7a8c8
diff --git a/test/common_nn.py b/test/common_nn.py
index 7f6c2ac..2f5b6a2 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -21,11 +21,12 @@
PRECISION = 1e-5
-def get_size_average(m):
+def get_reduction(m):
result = getattr(m, 'reduction', None)
- if result is not None:
- return result is 'elementwise_mean'
- return getattr(m, 'sizeAverage', None)
+ if result is None:
+ result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
+ assert result is not None
+ return result
def get_weight(m):
@@ -246,19 +247,19 @@
]
-def kldivloss_reference(input, target, size_average=True, reduce=True):
+def kldivloss_reference(input, target, reduction='elementwise_mean'):
safe_target = target * (target > 0).type_as(target)
safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
result = safe_target * (safe_target_log - input)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return result.mean()
- elif reduce:
+ elif reduction == 'sum':
return result.sum()
return result
def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
- size_average=True, reduce=True):
+ reduction='elementwise_mean'):
assert input.dim() >= 3
N = input.size(0)
C = input.size(1)
@@ -276,15 +277,15 @@
output[tup] = -input[tuple(input_index)] * norm
total_weight += norm
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.sum() / total_weight
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
def nllloss_reference(input, target, weight=None, ignore_index=-100,
- size_average=True, reduce=True):
+ reduction='elementwise_mean'):
def nll_loss_helper(input, target, weight, ignore_index):
if target == ignore_index:
@@ -297,22 +298,22 @@
for i, t in zip(input, target)]
losses, weights = zip(*losses_and_weights)
losses_tensor = input.new_tensor(losses)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return sum(losses_tensor) / sum(weights)
- elif reduce:
+ elif reduction == 'sum':
return sum(losses_tensor)
else:
return losses_tensor
-def smoothl1loss_reference(input, target, size_average=True, reduce=True):
+def smoothl1loss_reference(input, target, reduction='elementwise_mean'):
abs_diff = (input - target).abs()
ge_one_mask = (abs_diff >= 1).type_as(abs_diff)
lt_one_mask = (abs_diff < 1).type_as(abs_diff)
output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
@@ -333,7 +334,7 @@
return sum
-def multilabelmarginloss_reference(input, target, size_average=True, reduce=True):
+def multilabelmarginloss_reference(input, target, reduction='elementwise_mean'):
if input.dim() == 1:
n = 1
dim = input.size(0)
@@ -346,30 +347,30 @@
for i in range(0, n):
output[i] = _multilabelmarginloss_reference(input[i], target[i])
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean() / dim
- elif reduce:
+ elif reduction == 'sum':
return output.sum() / dim
return output / dim
-def hingeembeddingloss_reference(input, target, margin=1.0, size_average=True, reduce=True):
+def hingeembeddingloss_reference(input, target, margin=1.0, reduction='elementwise_mean'):
margin_clamp = (margin - input).clamp(min=0).type_as(input)
output = torch.where(target == 1, input, margin_clamp)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
-def softmarginloss_reference(input, target, size_average=True, reduce=True):
+def softmarginloss_reference(input, target, reduction='elementwise_mean'):
output = (1 + (-input * target).exp()).log()
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
@@ -385,8 +386,7 @@
return output
-def multimarginloss_reference(input, target, p=1, margin=1, weight=None, size_average=True,
- reduce=True):
+def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='elementwise_mean'):
if input.dim() == 1:
n = 1
dim = input.size(0)
@@ -400,14 +400,14 @@
for x in range(0, n):
output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean() / dim
- elif reduce:
+ elif reduction == 'sum':
return output.sum() / dim
return output / dim
-def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True):
+def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'):
def _cos(a, b):
cos = a.new(a.size(0))
for i in range(0, a.size(0)):
@@ -416,15 +416,15 @@
output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
- size_average=True, reduce=True):
+ reduction='elementwise_mean'):
d_p = torch.pairwise_distance(anchor, positive, p, eps)
d_n = torch.pairwise_distance(anchor, negative, p, eps)
if swap:
@@ -432,18 +432,18 @@
d_n = torch.min(d_n, d_s)
output = torch.clamp(margin + d_p - d_n, min=0.0)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
-def marginrankingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True):
+def marginrankingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'):
output = (-target * (input1 - input2) + margin).clamp(min=0)
- if reduce and size_average:
+ if reduction == 'elementwise_mean':
return output.mean()
- elif reduce:
+ elif reduction == 'sum':
return output.sum()
return output
@@ -476,12 +476,12 @@
input_fn=lambda: torch.rand(15, 10).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, m:
- nllloss_reference(i, t, size_average=get_size_average(m)),
- check_no_size_average=True
+ nllloss_reference(i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True
),
dict(
module_name='NLLLoss',
- constructor_args=(None, True, 2),
+ constructor_args=(None, None, 2),
input_fn=lambda: torch.rand(15, 10).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
@@ -498,7 +498,7 @@
),
dict(
module_name='NLLLoss',
- constructor_args_fn=lambda: (torch.rand(10), True, 2),
+ constructor_args_fn=lambda: (torch.rand(10), None, 2),
input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
reference_fn=lambda i, t, m:
@@ -507,7 +507,7 @@
),
dict(
module_name='NLLLoss',
- constructor_args_fn=lambda: (torch.rand(10), True, -1),
+ constructor_args_fn=lambda: (torch.rand(10), None, -1),
input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
reference_fn=lambda i, t, m:
@@ -519,22 +519,23 @@
input_fn=lambda: torch.rand(10, 10).log(),
target_fn=lambda: torch.rand(10, 10),
reference_fn=lambda i, t, m:
- kldivloss_reference(i, t, get_size_average(m), reduce=True),
- check_no_size_average=True,
+ kldivloss_reference(i, t, get_reduction(m)),
+ check_sum_reduction=True,
),
dict(
module_name='MSELoss',
input_size=(2, 3, 4, 5),
target_size=(2, 3, 4, 5),
- reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
- check_no_size_average=True,
+ reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
+ if get_reduction(m) == 'elementwise_mean' else 1)),
+ check_sum_reduction=True,
),
dict(
module_name='BCELoss',
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
- (i.numel() if get_size_average(m) else 1),
+ (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
check_gradgrad=False,
),
dict(
@@ -543,7 +544,7 @@
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.randn(15, 10).gt(0).double(),
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
- (i.numel() if get_size_average(m) else 1),
+ (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
desc='weights',
check_gradgrad=False,
),
@@ -564,8 +565,8 @@
input_size=(10,),
target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
reference_fn=lambda i, t, m:
- hingeembeddingloss_reference(i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
),
dict(
module_name='HingeEmbeddingLoss',
@@ -573,18 +574,18 @@
input_size=(10,),
target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
reference_fn=lambda i, t, m:
- hingeembeddingloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
+ hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
desc='margin',
- check_no_size_average=True,
+ check_sum_reduction=True,
),
dict(
module_name='MultiLabelMarginLoss',
input_size=(10,),
target_fn=lambda: torch.rand(10).mul(10).floor().long(),
reference_fn=lambda i, t, m:
- multilabelmarginloss_reference(i, t, size_average=get_size_average(m)),
+ multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
desc="1d",
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -592,8 +593,8 @@
input_size=(5, 10),
target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
reference_fn=lambda i, t, m:
- multilabelmarginloss_reference(i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -608,8 +609,8 @@
input_size=(5, 10),
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
reference_fn=lambda i, t, m:
- multimarginloss_reference(i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ multimarginloss_reference(i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -617,9 +618,9 @@
input_size=(10,),
target_fn=lambda: torch.rand(1).mul(8).floor().long(),
reference_fn=lambda i, t, m:
- multimarginloss_reference(i, t, size_average=get_size_average(m)),
+ multimarginloss_reference(i, t, reduction=get_reduction(m)),
desc='1d',
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -628,9 +629,9 @@
input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
reference_fn=lambda i, t, m:
- multimarginloss_reference(i, t, p=2, size_average=get_size_average(m)),
+ multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)),
desc='p',
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -640,9 +641,9 @@
input_size=(5, 10),
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
reference_fn=lambda i, t, m:
- multimarginloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
+ multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
desc='margin',
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
@@ -652,34 +653,34 @@
input_size=(5, 10),
target_fn=lambda: torch.rand(5).mul(8).floor().long(),
reference_fn=lambda i, t, m:
- multimarginloss_reference(i, t, weight=get_weight(m), size_average=get_size_average(m)),
+ multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
desc='weights',
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
dict(
module_name='SmoothL1Loss',
input_size=(5, 10),
target_size=(5, 10),
- check_no_size_average=True,
+ check_sum_reduction=True,
reference_fn=lambda i, t, m:
- smoothl1loss_reference(i, t, size_average=get_size_average(m)),
+ smoothl1loss_reference(i, t, reduction=get_reduction(m)),
),
dict(
module_name='SoftMarginLoss',
input_size=(5, 5),
target_fn=lambda: torch.randn(5, 5).sign(),
reference_fn=lambda i, t, m:
- softmarginloss_reference(i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ softmarginloss_reference(i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
),
dict(
module_name='CosineEmbeddingLoss',
input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
target_fn=lambda: torch.randn(15).sign(),
reference_fn=lambda i, t, m:
- cosineembeddingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
),
dict(
module_name='CosineEmbeddingLoss',
@@ -687,17 +688,17 @@
input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
target_fn=lambda: torch.randn(15).sign(),
reference_fn=lambda i, t, m:
- cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, size_average=get_size_average(m)),
+ cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
desc='margin',
- check_no_size_average=True,
+ check_sum_reduction=True,
),
dict(
module_name='MarginRankingLoss',
input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
target_fn=lambda: torch.randn(50).sign(),
reference_fn=lambda i, t, m:
- marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
),
dict(
module_name='MarginRankingLoss',
@@ -705,9 +706,9 @@
input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
target_fn=lambda: torch.randn(50).sign(),
reference_fn=lambda i, t, m:
- marginrankingloss_reference(i[0], i[1], t, margin=0.5, size_average=get_size_average(m)),
+ marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
desc='margin',
- check_no_size_average=True,
+ check_sum_reduction=True,
),
]
diff --git a/test/test_nn.py b/test/test_nn.py
index 8c52bd1..c25863e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -34,7 +34,7 @@
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
- module_tests, criterion_tests, loss_reference_fns, get_size_average, \
+ module_tests, criterion_tests, loss_reference_fns, get_reduction, \
get_weight, smoothl1loss_reference, kldivloss_reference
@@ -4270,8 +4270,8 @@
self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
- self.assertEqual(nn.BCEWithLogitsLoss(reduce=False)(output, target),
- nn.BCELoss(reduce=False)(sigmoid(output), target))
+ self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target),
+ nn.BCELoss(reduction='none')(sigmoid(output), target))
weight = torch.rand(1, dtype=torch.float)
self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
@@ -4279,7 +4279,7 @@
def test_bce_with_logits_has_correct_grad_at_zero(self):
output = torch.zeros(3, 1, requires_grad=True)
target = torch.zeros(3, 1)
- nn.BCEWithLogitsLoss(size_average=False)(output, target).backward()
+ nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
expected_grad = torch.empty(3, 1).fill_(0.5)
self.assertEqual(output.grad, expected_grad)
@@ -4330,10 +4330,9 @@
output = torch.zeros(3, 1, requires_grad=True)
target = torch.zeros(3, 1)
pos_weight = torch.ones(3, 1)
- nn.BCEWithLogitsLoss(pos_weight=pos_weight, size_average=False)(output, target).backward()
+ nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
expected_grad = torch.empty(3, 1).fill_(0.5)
grad = output.grad
- print(grad)
self.assertEqual(grad, expected_grad)
def test_bce_loss_broadcasts_weights(self):
@@ -4560,36 +4559,37 @@
input2 = torch.randn(15, 10, requires_grad=True)
target = torch.randn(15).sign()
self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
- x, y, z, reduce=False), (input1, input2, target)))
- self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduce=False),
- loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduce=False))
+ x, y, z, reduction='none'), (input1, input2, target)))
+ self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'),
+ loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none'))
def test_cosine_embedding_loss_margin_no_reduce(self):
input1 = torch.randn(15, 10, requires_grad=True)
input2 = torch.randn(15, 10, requires_grad=True)
target = torch.randn(15).sign()
self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
- x, y, z, margin=0.5, reduce=False), (input1, input2, target)))
- self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduce=False),
- loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, margin=0.5, reduce=False))
+ x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
+ self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'),
+ loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
+ margin=0.5, reduction='none'))
def test_margin_ranking_loss_no_reduce(self):
input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
target = torch.randn(15).sign()
self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
- x, y, z, reduce=False), (input1, input2, target)))
- self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduce=False),
- loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduce=False))
+ x, y, z, reduction='none'), (input1, input2, target)))
+ self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'),
+ loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none'))
def test_margin_ranking_loss_margin_no_reduce(self):
input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
target = torch.randn(15).sign()
self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
- x, y, z, margin=0.5, reduce=False), (input1, input2, target)))
- self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduce=False),
- loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduce=False))
+ x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
+ self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
+ loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none'))
def test_triplet_margin_loss(self):
input1 = torch.randn(5, 10, requires_grad=True)
@@ -4614,39 +4614,18 @@
input2 = torch.randn(5, 10, requires_grad=True)
input3 = torch.randn(5, 10, requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
- x1, x2, x3, reduce=False), (input1, input2, input3)))
- self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduce=False),
- loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False))
+ x1, x2, x3, reduction='none'), (input1, input2, input3)))
+ self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
+ loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none'))
def test_triplet_margin_loss_swap_no_reduce(self):
input1 = torch.randn(5, 10, requires_grad=True)
input2 = torch.randn(5, 10, requires_grad=True)
input3 = torch.randn(5, 10, requires_grad=True)
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
- x1, x2, x3, swap=True, reduce=False), (input1, input2, input3)))
- self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduce=False),
- loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduce=False))
-
- def test_loss_reduction_arg(self):
- # NB: This is a sanity test to check that the new reduction arg works the same as size_average and reduce
- # Remove this when size_average and reduce are deprecated and tests are ported to the new arg
- input1 = torch.randn(5, 10, requires_grad=True)
- input2 = torch.randn(5, 10, requires_grad=True)
- input3 = torch.randn(5, 10, requires_grad=True)
- self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
- x1, x2, x3, reduction='elementwise_mean'), (input1, input2, input3)))
- self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='elementwise_mean'),
- loss_reference_fns['TripletMarginLoss'](input1, input2, input3))
-
- self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
- x1, x2, x3, reduction='sum'), (input1, input2, input3)))
- self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='sum'),
- loss_reference_fns['TripletMarginLoss'](input1, input2, input3, size_average=False))
-
- self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
- x1, x2, x3, reduction='none'), (input1, input2, input3)))
- self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
- loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False))
+ x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3)))
+ self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
+ loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
def test_cosine_similarity(self):
input1 = torch.randn(4, 4, requires_grad=True)
@@ -5774,8 +5753,8 @@
input_size=(2, 3, 5, 5),
target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
reference_fn=lambda i, t, m:
- loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
desc='2d'
),
dict(
@@ -5789,7 +5768,7 @@
),
dict(
module_name='NLLLoss',
- constructor_args=(None, True, 1),
+ constructor_args=(None, None, 1),
input_size=(2, 3, 5, 5),
target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
reference_fn=lambda i, t, m:
@@ -5801,8 +5780,8 @@
input_size=(2, 3, 5, 5, 2, 2),
target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
reference_fn=lambda i, t, m:
- loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
desc='higher_dim'
),
dict(
@@ -5810,8 +5789,8 @@
input_size=(2, 3, 5),
target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
reference_fn=lambda i, t, m:
- loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
- check_no_size_average=True,
+ loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+ check_sum_reduction=True,
desc='dim_is_3'
),
dict(
@@ -5822,7 +5801,7 @@
),
dict(
module_name='PoissonNLLLoss',
- constructor_args=(False, True, True),
+ constructor_args=(False,),
input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
desc='full_loss', # with sterling approx
@@ -5839,16 +5818,17 @@
input_fn=lambda: torch.rand(()).log(),
target_fn=lambda: torch.rand(()),
reference_fn=lambda i, t, m:
- kldivloss_reference(i, t, get_size_average(m), reduce=True),
- check_no_size_average=True,
+ kldivloss_reference(i, t, get_reduction(m)),
+ check_sum_reduction=True,
desc='scalar',
),
dict(
module_name='MSELoss',
input_size=(),
target_size=(),
- reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
- check_no_size_average=True,
+ reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
+ (i.numel() if get_reduction(m) == 'elementwise_mean' else 1)),
+ check_sum_reduction=True,
desc='scalar'
),
dict(
@@ -5857,7 +5837,7 @@
input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch.rand(()).gt(0).double(),
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
- (i.numel() if get_size_average(m) else 1),
+ (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
desc='scalar_weights',
check_gradgrad=False,
),
@@ -5867,15 +5847,15 @@
input_size=(),
target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1),
desc='scalar_margin',
- check_no_size_average=True,
+ check_sum_reduction=True,
),
dict(
module_name='SmoothL1Loss',
input_size=(),
target_size=(),
- check_no_size_average=True,
+ check_sum_reduction=True,
reference_fn=lambda i, t, m:
- smoothl1loss_reference(i, t, size_average=get_size_average(m)),
+ smoothl1loss_reference(i, t, reduction=get_reduction(m)),
desc='scalar',
),
dict(
@@ -5884,9 +5864,9 @@
input_fn=lambda: torch.randn(5, 10),
target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() /
- (i.numel() if get_size_average(m) else 1),
+ (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
desc='weights',
- check_no_size_average=True,
+ check_sum_reduction=True,
check_gradgrad=False,
),
]
@@ -5897,7 +5877,7 @@
return dict(
fullname='PoissonNLLLLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(10, 10),
pickle=False)
@@ -5907,7 +5887,7 @@
return dict(
fullname='BCELoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)),
+ lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
check_gradgrad=False,
@@ -5919,7 +5899,7 @@
return dict(
fullname='BCELoss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)),
+ lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
check_gradgrad=False,
@@ -5933,7 +5913,7 @@
fullname='BCELoss_weights_no_reduce',
constructor=wrap_functional(
lambda i: F.binary_cross_entropy(i, t.type_as(i),
- weight=weights.type_as(i), reduce=False)),
+ weight=weights.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
check_gradgrad=False,
@@ -5947,7 +5927,7 @@
fullname='BCELoss_weights_no_reduce_scalar',
constructor=wrap_functional(
lambda i: F.binary_cross_entropy(i, t.type_as(i),
- weight=weights.type_as(i), reduce=False)),
+ weight=weights.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
check_gradgrad=False,
@@ -5960,7 +5940,7 @@
return dict(
fullname='BCEWithLogitsLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+ lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
check_gradgrad=False,
@@ -5973,7 +5953,7 @@
return dict(
fullname='BCEWithLogitsLoss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+ lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
check_gradgrad=False,
@@ -5985,10 +5965,10 @@
return dict(
fullname='KLDivLoss_with_target_no_reduce',
constructor=wrap_functional(
- lambda t: F.kl_div(i.type_as(t), t, reduce=False)),
+ lambda t: F.kl_div(i.type_as(t), t, reduction='none')),
input_fn=lambda: torch.rand(10, 10),
reference_fn=lambda t, _:
- loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduce=False),
+ loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduction='none'),
pickle=False)
@@ -5997,10 +5977,10 @@
return dict(
fullname='KLDivLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.kl_div(i, t.type_as(i), reduce=False)),
+ lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(10, 10).log(),
reference_fn=lambda i, _:
- loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False),
+ loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
pickle=False)
@@ -6009,10 +5989,10 @@
return dict(
fullname='KLDivLoss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.kl_div(i, t.type_as(i), reduce=False)),
+ lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.rand(()).log(),
reference_fn=lambda i, _:
- loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False),
+ loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
pickle=False)
@@ -6021,7 +6001,7 @@
return dict(
fullname='L1Loss_no_reduce',
constructor=wrap_functional(
- lambda i: F.l1_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(2, 3, 4),
reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
pickle=False)
@@ -6032,7 +6012,7 @@
return dict(
fullname='L1Loss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.l1_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(()),
reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
pickle=False)
@@ -6044,7 +6024,7 @@
return dict(
fullname='MSELoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.mse_loss(i, target.type_as(i), reduce=False)),
+ lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
input_size=input_size,
reference_fn=lambda i, m: (i - target).pow(2),
pickle=False)
@@ -6056,7 +6036,7 @@
return dict(
fullname='MSELoss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.mse_loss(i, target.type_as(i), reduce=False)),
+ lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
input_size=input_size,
reference_fn=lambda i, m: (i - target).pow(2),
pickle=False)
@@ -6064,7 +6044,7 @@
def nllloss_no_reduce_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
- kwargs = {'reduce': False}
+ kwargs = {'reduction': 'none'}
return dict(
fullname='NLLLoss_no_reduce',
constructor=wrap_functional(
@@ -6077,7 +6057,7 @@
def nllloss_no_reduce_ignore_index_test():
t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
- kwargs = {'ignore_index': 2, 'reduce': False}
+ kwargs = {'ignore_index': 2, 'reduction': 'none'}
return dict(
fullname='NLLLoss_no_reduce_ignore_index',
constructor=wrap_functional(
@@ -6093,7 +6073,7 @@
weight = torch.rand(10)
def kwargs(i):
- return {'weight': weight.type_as(i), 'reduce': False}
+ return {'weight': weight.type_as(i), 'reduction': 'none'}
return dict(
fullname='NLLLoss_no_reduce_weights',
@@ -6110,7 +6090,7 @@
weight = torch.rand(10)
def kwargs(i):
- return {'weight': weight.type_as(i), 'reduce': False,
+ return {'weight': weight.type_as(i), 'reduction': 'none',
'ignore_index': 2}
return dict(
@@ -6128,7 +6108,7 @@
weight = torch.rand(10)
def kwargs(i):
- return {'weight': weight.type_as(i), 'reduce': False,
+ return {'weight': weight.type_as(i), 'reduction': 'none',
'ignore_index': -1}
return dict(
@@ -6143,7 +6123,7 @@
def nllloss2d_no_reduce_test():
t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
- kwargs = {'reduce': False}
+ kwargs = {'reduction': 'none'}
return dict(
fullname='NLLLoss2d_no_reduce',
constructor=wrap_functional(
@@ -6156,7 +6136,7 @@
def nllloss2d_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
- kwargs = {'ignore_index': 1, 'reduce': False}
+ kwargs = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLoss2d_no_reduce_ignore_index',
constructor=wrap_functional(
@@ -6172,7 +6152,7 @@
weight = torch.rand(3)
def kwargs(i):
- return {'weight': weight.type_as(i), 'reduce': False}
+ return {'weight': weight.type_as(i), 'reduction': 'none'}
return dict(
fullname='NLLLoss2d_no_reduce_weights',
@@ -6186,7 +6166,7 @@
def nlllossNd_no_reduce_test():
t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
- kwargs = {'reduce': False}
+ kwargs = {'reduction': 'none'}
return dict(
fullname='NLLLossNd_no_reduce',
constructor=wrap_functional(
@@ -6199,7 +6179,7 @@
def nlllossNd_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
- kwargs = {'ignore_index': 1, 'reduce': False}
+ kwargs = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLossNd_no_reduce_ignore_index',
constructor=wrap_functional(
@@ -6215,7 +6195,7 @@
weight = torch.rand(3)
def kwargs(i):
- return {'weight': weight.type_as(i), 'reduce': False}
+ return {'weight': weight.type_as(i), 'reduction': 'none'}
return dict(
fullname='NLLLossNd_no_reduce_weights',
@@ -6232,10 +6212,10 @@
return dict(
fullname='SmoothL1Loss_no_reduce',
constructor=wrap_functional(
- lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(2, 3, 4),
reference_fn=lambda i, _:
- loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False),
+ loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
pickle=False)
@@ -6244,10 +6224,10 @@
return dict(
fullname='SmoothL1Loss_no_reduce_scalar',
constructor=wrap_functional(
- lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(()),
reference_fn=lambda i, _:
- loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False),
+ loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
pickle=False)
@@ -6256,11 +6236,11 @@
return dict(
fullname='MultiLabelMarginLoss_1d_no_reduce',
constructor=wrap_functional(
- lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+ lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
input_fn=lambda: torch.randn(10),
reference_fn=lambda i, _:
- loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6270,11 +6250,11 @@
return dict(
fullname='MultiLabelMarginLoss_index_neg',
constructor=wrap_functional(
- lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+ lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
input_fn=lambda: torch.randn(5, 10),
reference_fn=lambda i, _:
- loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6284,11 +6264,11 @@
return dict(
fullname='MultiLabelMarginLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+ lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
input_fn=lambda: torch.randn(5, 10),
reference_fn=lambda i, _:
- loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6298,11 +6278,11 @@
return dict(
fullname='HingeEmbeddingLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(10),
reference_fn=lambda i, _:
- loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
+ check_sum_reduction=True,
pickle=False)
@@ -6311,11 +6291,11 @@
return dict(
fullname='HingeEmbeddingLoss_margin_no_reduce',
constructor=wrap_functional(
- lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduce=False)),
+ lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
input_fn=lambda: torch.randn(10),
reference_fn=lambda i, _:
- loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
+ check_sum_reduction=True,
pickle=False)
@@ -6324,11 +6304,10 @@
return dict(
fullname='SoftMarginLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.soft_margin_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(5, 5),
reference_fn=lambda i, _:
- loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
pickle=False)
@@ -6337,11 +6316,9 @@
return dict(
fullname='MultiLabelSoftMarginLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduce=False)),
+ lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(5, 10),
- reference_fn=lambda i, m: (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) /
- (i.numel() if get_size_average(m) else 1)),
- check_no_size_average=True,
+ reference_fn=lambda i, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()),
check_gradgrad=False,
pickle=False)
@@ -6353,11 +6330,10 @@
fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
constructor=wrap_functional(
lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
- weight=weights.type_as(i), reduce=False)),
+ weight=weights.type_as(i), reduction='none')),
input_fn=lambda: torch.randn(5, 10),
- reference_fn=lambda i, m: (-((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights) /
- (i.numel() if get_size_average(m) else 1)),
- check_no_size_average=True,
+ reference_fn=lambda i, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6367,11 +6343,11 @@
return dict(
fullname='MultiMarginLoss_no_reduce',
constructor=wrap_functional(
- lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)),
+ lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
input_fn=lambda: torch.randn(5, 10),
reference_fn=lambda i, _:
- loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6381,11 +6357,11 @@
return dict(
fullname='MultiMarginLoss_1d_no_reduce',
constructor=wrap_functional(
- lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)),
+ lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
input_fn=lambda: torch.randn(10),
reference_fn=lambda i, _:
- loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6395,11 +6371,11 @@
return dict(
fullname='MultiMarginLoss_p_no_reduce',
constructor=wrap_functional(
- lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduce=False)),
+ lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
reference_fn=lambda i, _:
- loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduce=False),
- check_no_size_average=True,
+ loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6409,12 +6385,12 @@
return dict(
fullname='MultiMarginLoss_margin_no_reduce',
constructor=wrap_functional(
- lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduce=False)),
+ lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
input_fn=lambda: torch.randn(5, 10),
reference_fn=lambda i, _:
loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
- margin=0.5, reduce=False),
- check_no_size_average=True,
+ margin=0.5, reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -6426,12 +6402,12 @@
fullname='MultiMarginLoss_weights_no_reduce',
constructor=wrap_functional(
lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
- reduce=False)),
+ reduction='none')),
input_fn=lambda: torch.randn(5, 10),
reference_fn=lambda i, _:
loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
- weight=weights, reduce=False),
- check_no_size_average=True,
+ weight=weights, reduction='none'),
+ check_sum_reduction=True,
check_gradgrad=False,
pickle=False)
@@ -7734,18 +7710,18 @@
test = NewCriterionTest(**test_params)
decorator = test_params.pop('decorator', None)
add_test(test, decorator)
- if 'check_no_size_average' in test_params:
+ if 'check_sum_reduction' in test_params:
desc = test_params.get('desc', None)
- test_params['desc'] = 'no_size_average' if desc is None else desc + '_no_size_average'
+ test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction'
- def gen_no_size_average_constructor(constructor):
- def no_size_average_constructor(*args, **kwargs):
- cons = constructor(*args, size_average=False, **kwargs)
+ def gen_sum_reduction_constructor(constructor):
+ def sum_reduction_constructor(*args, **kwargs):
+ cons = constructor(*args, reduction='sum', **kwargs)
return cons
- no_size_average_constructor.__name__ = constructor.__name__
- return no_size_average_constructor
+ sum_reduction_constructor.__name__ = constructor.__name__
+ return sum_reduction_constructor
- test_params['constructor'] = gen_no_size_average_constructor(test_params['constructor'])
+ test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
test = NewCriterionTest(**test_params)
add_test(test, decorator)
diff --git a/torch/legacy/nn/AbsCriterion.py b/torch/legacy/nn/AbsCriterion.py
index 4a2ea68..66f7615 100644
--- a/torch/legacy/nn/AbsCriterion.py
+++ b/torch/legacy/nn/AbsCriterion.py
@@ -18,7 +18,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -31,6 +31,6 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/ClassNLLCriterion.py b/torch/legacy/nn/ClassNLLCriterion.py
index 50ddcfd..33c28e5 100644
--- a/torch/legacy/nn/ClassNLLCriterion.py
+++ b/torch/legacy/nn/ClassNLLCriterion.py
@@ -25,7 +25,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.weights,
self.total_weight_tensor,
self.ignore_index,
@@ -44,7 +44,7 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.weights,
self.total_weight_tensor,
self.ignore_index,
diff --git a/torch/legacy/nn/ClassSimplexCriterion.py b/torch/legacy/nn/ClassSimplexCriterion.py
index b28ce67..1de5851 100644
--- a/torch/legacy/nn/ClassSimplexCriterion.py
+++ b/torch/legacy/nn/ClassSimplexCriterion.py
@@ -81,7 +81,7 @@
input,
self._target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -95,7 +95,7 @@
self._target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/DistKLDivCriterion.py b/torch/legacy/nn/DistKLDivCriterion.py
index 8c18cf1..5aa1756 100644
--- a/torch/legacy/nn/DistKLDivCriterion.py
+++ b/torch/legacy/nn/DistKLDivCriterion.py
@@ -19,7 +19,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -33,6 +33,6 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/MSECriterion.py b/torch/legacy/nn/MSECriterion.py
index 2422e07..2079d36 100644
--- a/torch/legacy/nn/MSECriterion.py
+++ b/torch/legacy/nn/MSECriterion.py
@@ -18,7 +18,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -32,6 +32,6 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/MultiLabelMarginCriterion.py b/torch/legacy/nn/MultiLabelMarginCriterion.py
index 1de12bf..9ca2a23 100644
--- a/torch/legacy/nn/MultiLabelMarginCriterion.py
+++ b/torch/legacy/nn/MultiLabelMarginCriterion.py
@@ -21,7 +21,7 @@
target,
self.output_tensor,
self.isTarget,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -36,6 +36,6 @@
implicit_gradOutput,
self.gradInput,
self.isTarget,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/MultiMarginCriterion.py b/torch/legacy/nn/MultiMarginCriterion.py
index 26b9cff..cc9835c 100644
--- a/torch/legacy/nn/MultiMarginCriterion.py
+++ b/torch/legacy/nn/MultiMarginCriterion.py
@@ -26,7 +26,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.p,
self.weights,
self.margin,
@@ -43,7 +43,7 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.p,
self.weights,
self.margin,
diff --git a/torch/legacy/nn/SmoothL1Criterion.py b/torch/legacy/nn/SmoothL1Criterion.py
index c02e7a2..714d0b6 100644
--- a/torch/legacy/nn/SmoothL1Criterion.py
+++ b/torch/legacy/nn/SmoothL1Criterion.py
@@ -18,7 +18,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -31,6 +31,6 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/SoftMarginCriterion.py b/torch/legacy/nn/SoftMarginCriterion.py
index e56d871..4bfa371 100644
--- a/torch/legacy/nn/SoftMarginCriterion.py
+++ b/torch/legacy/nn/SoftMarginCriterion.py
@@ -18,7 +18,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -31,6 +31,6 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/legacy/nn/SpatialClassNLLCriterion.py b/torch/legacy/nn/SpatialClassNLLCriterion.py
index 382cfea..8a7e15c 100644
--- a/torch/legacy/nn/SpatialClassNLLCriterion.py
+++ b/torch/legacy/nn/SpatialClassNLLCriterion.py
@@ -23,7 +23,7 @@
input,
target,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.weights,
self.total_weight_tensor,
self.ignore_index,
@@ -40,7 +40,7 @@
target,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
self.weights,
self.total_weight_tensor,
self.ignore_index,
diff --git a/torch/legacy/nn/WeightedMSECriterion.py b/torch/legacy/nn/WeightedMSECriterion.py
index 4e03439..2f0da29 100644
--- a/torch/legacy/nn/WeightedMSECriterion.py
+++ b/torch/legacy/nn/WeightedMSECriterion.py
@@ -29,7 +29,7 @@
input,
self.buffer,
self.output_tensor,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
self.output = self.output_tensor[0].item()
return self.output
@@ -50,6 +50,6 @@
self.buffer,
implicit_gradOutput,
self.gradInput,
- _Reduction.legacy_get_enum(self.sizeAverage, True),
+ _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
)
return self.gradInput
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 379cc11..a6c00b1 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -31,8 +31,10 @@
# In order to support previous versions, accept boolean size_average and reduce
# and convert them into the new constants for now
+
+ # We use these functions in torch/legacy as well, in which case we'll silence the warning
@staticmethod
- def legacy_get_string(size_average, reduce):
+ def legacy_get_string(size_average, reduce, emit_warning=True):
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
if size_average is None:
@@ -41,18 +43,18 @@
reduce = True
if size_average and reduce:
- warnings.warn(warning.format('elementwise_mean'))
- return 'elementwise_mean'
+ ret = 'elementwise_mean'
elif reduce:
- warnings.warn(warning.format('sum'))
- return 'sum'
+ ret = 'sum'
else:
- warnings.warn(warning.format('none'))
- return 'none'
+ ret = 'none'
+ if emit_warning:
+ warnings.warn(warning.format(ret))
+ return ret
@staticmethod
- def legacy_get_enum(size_average, reduce):
- return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce))
+ def legacy_get_enum(size_average, reduce, emit_warning=True):
+ return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce, emit_warning))
conv1d = _add_docstr(torch.conv1d, r"""