update nn loss tests to use new reduction arg (#9118)

Summary:
The tests were still using the old size_average and reduce args, which caused them to emit a lot of deprecation warnings.

Closes #9103.
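
As a minimal, illustrative sketch (not taken from this diff; `input`/`target` are placeholder tensors and `F` is torch.nn.functional), the migration looks like:

    # old args, now deprecated (emit a warning)
    loss = F.kl_div(input, target, size_average=False, reduce=True)
    # equivalent call using the new reduction arg
    loss = F.kl_div(input, target, reduction='sum')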

Reviewed By: ezyang

Differential Revision: D8720581

Pulled By: li-roy

fbshipit-source-id: 3b79527f6fe862fb48b99a6394e8d7b89fc7a8c8
diff --git a/test/common_nn.py b/test/common_nn.py
index 7f6c2ac..2f5b6a2 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -21,11 +21,12 @@
 PRECISION = 1e-5
 
 
-def get_size_average(m):
+def get_reduction(m):
     result = getattr(m, 'reduction', None)
-    if result is not None:
-        return result is 'elementwise_mean'
-    return getattr(m, 'sizeAverage', None)
+    if result is None:
+        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
+    assert result is not None
+    return result
 
 
 def get_weight(m):
@@ -246,19 +247,19 @@
 ]
 
 
-def kldivloss_reference(input, target, size_average=True, reduce=True):
+def kldivloss_reference(input, target, reduction='elementwise_mean'):
     safe_target = target * (target > 0).type_as(target)
     safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
     result = safe_target * (safe_target_log - input)
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return result.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return result.sum()
     return result
 
 
 def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
-                        size_average=True, reduce=True):
+                        reduction='elementwise_mean'):
     assert input.dim() >= 3
     N = input.size(0)
     C = input.size(1)
@@ -276,15 +277,15 @@
         output[tup] = -input[tuple(input_index)] * norm
         total_weight += norm
 
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.sum() / total_weight
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
 
 def nllloss_reference(input, target, weight=None, ignore_index=-100,
-                      size_average=True, reduce=True):
+                      reduction='elementwise_mean'):
 
     def nll_loss_helper(input, target, weight, ignore_index):
         if target == ignore_index:
@@ -297,22 +298,22 @@
                           for i, t in zip(input, target)]
     losses, weights = zip(*losses_and_weights)
     losses_tensor = input.new_tensor(losses)
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return sum(losses_tensor) / sum(weights)
-    elif reduce:
+    elif reduction == 'sum':
         return sum(losses_tensor)
     else:
         return losses_tensor
 
 
-def smoothl1loss_reference(input, target, size_average=True, reduce=True):
+def smoothl1loss_reference(input, target, reduction='elementwise_mean'):
     abs_diff = (input - target).abs()
     ge_one_mask = (abs_diff >= 1).type_as(abs_diff)
     lt_one_mask = (abs_diff < 1).type_as(abs_diff)
     output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2)
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
@@ -333,7 +334,7 @@
     return sum
 
 
-def multilabelmarginloss_reference(input, target, size_average=True, reduce=True):
+def multilabelmarginloss_reference(input, target, reduction='elementwise_mean'):
     if input.dim() == 1:
         n = 1
         dim = input.size(0)
@@ -346,30 +347,30 @@
         for i in range(0, n):
             output[i] = _multilabelmarginloss_reference(input[i], target[i])
 
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean() / dim
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum() / dim
     return output / dim
 
 
-def hingeembeddingloss_reference(input, target, margin=1.0, size_average=True, reduce=True):
+def hingeembeddingloss_reference(input, target, margin=1.0, reduction='elementwise_mean'):
     margin_clamp = (margin - input).clamp(min=0).type_as(input)
     output = torch.where(target == 1, input, margin_clamp)
 
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
 
-def softmarginloss_reference(input, target, size_average=True, reduce=True):
+def softmarginloss_reference(input, target, reduction='elementwise_mean'):
     output = (1 + (-input * target).exp()).log()
 
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
@@ -385,8 +386,7 @@
     return output
 
 
-def multimarginloss_reference(input, target, p=1, margin=1, weight=None, size_average=True,
-                              reduce=True):
+def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='elementwise_mean'):
     if input.dim() == 1:
         n = 1
         dim = input.size(0)
@@ -400,14 +400,14 @@
         for x in range(0, n):
             output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
 
-        if reduce and size_average:
+        if reduction == 'elementwise_mean':
             return output.mean() / dim
-        elif reduce:
+        elif reduction == 'sum':
             return output.sum() / dim
         return output / dim
 
 
-def cosineembeddingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True):
+def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'):
     def _cos(a, b):
         cos = a.new(a.size(0))
         for i in range(0, a.size(0)):
@@ -416,15 +416,15 @@
 
     output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
 
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
 
 def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
-                                size_average=True, reduce=True):
+                                reduction='elementwise_mean'):
     d_p = torch.pairwise_distance(anchor, positive, p, eps)
     d_n = torch.pairwise_distance(anchor, negative, p, eps)
     if swap:
@@ -432,18 +432,18 @@
         d_n = torch.min(d_n, d_s)
 
     output = torch.clamp(margin + d_p - d_n, min=0.0)
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
 
-def marginrankingloss_reference(input1, input2, target, margin=0, size_average=True, reduce=True):
+def marginrankingloss_reference(input1, input2, target, margin=0, reduction='elementwise_mean'):
     output = (-target * (input1 - input2) + margin).clamp(min=0)
-    if reduce and size_average:
+    if reduction == 'elementwise_mean':
         return output.mean()
-    elif reduce:
+    elif reduction == 'sum':
         return output.sum()
     return output
 
@@ -476,12 +476,12 @@
         input_fn=lambda: torch.rand(15, 10).log(),
         target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, m:
-            nllloss_reference(i, t, size_average=get_size_average(m)),
-        check_no_size_average=True
+            nllloss_reference(i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True
     ),
     dict(
         module_name='NLLLoss',
-        constructor_args=(None, True, 2),
+        constructor_args=(None, None, 2),
         input_fn=lambda: torch.rand(15, 10).log(),
         target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
@@ -498,7 +498,7 @@
     ),
     dict(
         module_name='NLLLoss',
-        constructor_args_fn=lambda: (torch.rand(10), True, 2),
+        constructor_args_fn=lambda: (torch.rand(10), None, 2),
         input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
         target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
         reference_fn=lambda i, t, m:
@@ -507,7 +507,7 @@
     ),
     dict(
         module_name='NLLLoss',
-        constructor_args_fn=lambda: (torch.rand(10), True, -1),
+        constructor_args_fn=lambda: (torch.rand(10), None, -1),
         input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
         target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
         reference_fn=lambda i, t, m:
@@ -519,22 +519,23 @@
         input_fn=lambda: torch.rand(10, 10).log(),
         target_fn=lambda: torch.rand(10, 10),
         reference_fn=lambda i, t, m:
-            kldivloss_reference(i, t, get_size_average(m), reduce=True),
-        check_no_size_average=True,
+            kldivloss_reference(i, t, get_reduction(m)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='MSELoss',
         input_size=(2, 3, 4, 5),
         target_size=(2, 3, 4, 5),
-        reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
-        check_no_size_average=True,
+        reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
+                                      if get_reduction(m) == 'elementwise_mean' else 1)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='BCELoss',
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
         reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
-            (i.numel() if get_size_average(m) else 1),
+            (i.numel() if get_reduction(m) else 1),
         check_gradgrad=False,
     ),
     dict(
@@ -543,7 +544,7 @@
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
         reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
-            (i.numel() if get_size_average(m) else 1),
+            (i.numel() if get_reduction(m) else 1),
         desc='weights',
         check_gradgrad=False,
     ),
@@ -564,8 +565,8 @@
         input_size=(10,),
         target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
         reference_fn=lambda i, t, m:
-            hingeembeddingloss_reference(i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='HingeEmbeddingLoss',
@@ -573,18 +574,18 @@
         input_size=(10,),
         target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
         reference_fn=lambda i, t, m:
-            hingeembeddingloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
+            hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
-        check_no_size_average=True,
+        check_sum_reduction=True,
     ),
     dict(
         module_name='MultiLabelMarginLoss',
         input_size=(10,),
         target_fn=lambda: torch.rand(10).mul(10).floor().long(),
         reference_fn=lambda i, t, m:
-            multilabelmarginloss_reference(i, t, size_average=get_size_average(m)),
+            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
         desc="1d",
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -592,8 +593,8 @@
         input_size=(5, 10),
         target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
         reference_fn=lambda i, t, m:
-            multilabelmarginloss_reference(i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -608,8 +609,8 @@
         input_size=(5, 10),
         target_fn=lambda: torch.rand(5).mul(8).floor().long(),
         reference_fn=lambda i, t, m:
-            multimarginloss_reference(i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            multimarginloss_reference(i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -617,9 +618,9 @@
         input_size=(10,),
         target_fn=lambda: torch.rand(1).mul(8).floor().long(),
         reference_fn=lambda i, t, m:
-            multimarginloss_reference(i, t, size_average=get_size_average(m)),
+            multimarginloss_reference(i, t, reduction=get_reduction(m)),
         desc='1d',
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -628,9 +629,9 @@
         input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.rand(5).mul(8).floor().long(),
         reference_fn=lambda i, t, m:
-            multimarginloss_reference(i, t, p=2, size_average=get_size_average(m)),
+            multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)),
         desc='p',
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -640,9 +641,9 @@
         input_size=(5, 10),
         target_fn=lambda: torch.rand(5).mul(8).floor().long(),
         reference_fn=lambda i, t, m:
-            multimarginloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
+            multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
@@ -652,34 +653,34 @@
         input_size=(5, 10),
         target_fn=lambda: torch.rand(5).mul(8).floor().long(),
         reference_fn=lambda i, t, m:
-            multimarginloss_reference(i, t, weight=get_weight(m), size_average=get_size_average(m)),
+            multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
         desc='weights',
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
     dict(
         module_name='SmoothL1Loss',
         input_size=(5, 10),
         target_size=(5, 10),
-        check_no_size_average=True,
+        check_sum_reduction=True,
         reference_fn=lambda i, t, m:
-            smoothl1loss_reference(i, t, size_average=get_size_average(m)),
+            smoothl1loss_reference(i, t, reduction=get_reduction(m)),
     ),
     dict(
         module_name='SoftMarginLoss',
         input_size=(5, 5),
         target_fn=lambda: torch.randn(5, 5).sign(),
         reference_fn=lambda i, t, m:
-            softmarginloss_reference(i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            softmarginloss_reference(i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='CosineEmbeddingLoss',
         input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
         target_fn=lambda: torch.randn(15).sign(),
         reference_fn=lambda i, t, m:
-            cosineembeddingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='CosineEmbeddingLoss',
@@ -687,17 +688,17 @@
         input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
         target_fn=lambda: torch.randn(15).sign(),
         reference_fn=lambda i, t, m:
-            cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, size_average=get_size_average(m)),
+            cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
         desc='margin',
-        check_no_size_average=True,
+        check_sum_reduction=True,
     ),
     dict(
         module_name='MarginRankingLoss',
         input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
         target_fn=lambda: torch.randn(50).sign(),
         reference_fn=lambda i, t, m:
-            marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
     ),
     dict(
         module_name='MarginRankingLoss',
@@ -705,9 +706,9 @@
         input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
         target_fn=lambda: torch.randn(50).sign(),
         reference_fn=lambda i, t, m:
-            marginrankingloss_reference(i[0], i[1], t, margin=0.5, size_average=get_size_average(m)),
+            marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
         desc='margin',
-        check_no_size_average=True,
+        check_sum_reduction=True,
     ),
 ]
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 8c52bd1..c25863e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -34,7 +34,7 @@
 from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
     TEST_CUDNN_VERSION
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
-    module_tests, criterion_tests, loss_reference_fns, get_size_average, \
+    module_tests, criterion_tests, loss_reference_fns, get_reduction, \
     get_weight, smoothl1loss_reference, kldivloss_reference
 
 
@@ -4270,8 +4270,8 @@
 
         self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
 
-        self.assertEqual(nn.BCEWithLogitsLoss(reduce=False)(output, target),
-                         nn.BCELoss(reduce=False)(sigmoid(output), target))
+        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target),
+                         nn.BCELoss(reduction='none')(sigmoid(output), target))
 
         weight = torch.rand(1, dtype=torch.float)
         self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
@@ -4279,7 +4279,7 @@
     def test_bce_with_logits_has_correct_grad_at_zero(self):
         output = torch.zeros(3, 1, requires_grad=True)
         target = torch.zeros(3, 1)
-        nn.BCEWithLogitsLoss(size_average=False)(output, target).backward()
+        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
         expected_grad = torch.empty(3, 1).fill_(0.5)
         self.assertEqual(output.grad, expected_grad)
 
@@ -4330,10 +4330,9 @@
         output = torch.zeros(3, 1, requires_grad=True)
         target = torch.zeros(3, 1)
         pos_weight = torch.ones(3, 1)
-        nn.BCEWithLogitsLoss(pos_weight=pos_weight, size_average=False)(output, target).backward()
+        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
         expected_grad = torch.empty(3, 1).fill_(0.5)
         grad = output.grad
-        print(grad)
         self.assertEqual(grad, expected_grad)
 
     def test_bce_loss_broadcasts_weights(self):
@@ -4560,36 +4559,37 @@
         input2 = torch.randn(15, 10, requires_grad=True)
         target = torch.randn(15).sign()
         self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
-            x, y, z, reduce=False), (input1, input2, target)))
-        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduce=False),
-                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduce=False))
+            x, y, z, reduction='none'), (input1, input2, target)))
+        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'),
+                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none'))
 
     def test_cosine_embedding_loss_margin_no_reduce(self):
         input1 = torch.randn(15, 10, requires_grad=True)
         input2 = torch.randn(15, 10, requires_grad=True)
         target = torch.randn(15).sign()
         self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
-            x, y, z, margin=0.5, reduce=False), (input1, input2, target)))
-        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduce=False),
-                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, margin=0.5, reduce=False))
+            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
+        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'),
+                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
+                                                                   margin=0.5, reduction='none'))
 
     def test_margin_ranking_loss_no_reduce(self):
         input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
         input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
         target = torch.randn(15).sign()
         self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
-            x, y, z, reduce=False), (input1, input2, target)))
-        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduce=False),
-                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduce=False))
+            x, y, z, reduction='none'), (input1, input2, target)))
+        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'),
+                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none'))
 
     def test_margin_ranking_loss_margin_no_reduce(self):
         input1 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
         input2 = torch.tensor(torch.randn(15).mul(10), requires_grad=True)
         target = torch.randn(15).sign()
         self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
-            x, y, z, margin=0.5, reduce=False), (input1, input2, target)))
-        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduce=False),
-                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduce=False))
+            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
+        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
+                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none'))
 
     def test_triplet_margin_loss(self):
         input1 = torch.randn(5, 10, requires_grad=True)
@@ -4614,39 +4614,18 @@
         input2 = torch.randn(5, 10, requires_grad=True)
         input3 = torch.randn(5, 10, requires_grad=True)
         self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
-            x1, x2, x3, reduce=False), (input1, input2, input3)))
-        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduce=False),
-                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False))
+            x1, x2, x3, reduction='none'), (input1, input2, input3)))
+        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
+                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none'))
 
     def test_triplet_margin_loss_swap_no_reduce(self):
         input1 = torch.randn(5, 10, requires_grad=True)
         input2 = torch.randn(5, 10, requires_grad=True)
         input3 = torch.randn(5, 10, requires_grad=True)
         self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
-            x1, x2, x3, swap=True, reduce=False), (input1, input2, input3)))
-        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduce=False),
-                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduce=False))
-
-    def test_loss_reduction_arg(self):
-        # NB: This is a sanity test to check that the new reduction arg works the same as size_average and reduce
-        # Remove this when size_average and reduce are deprecated and tests are ported to the new arg
-        input1 = torch.randn(5, 10, requires_grad=True)
-        input2 = torch.randn(5, 10, requires_grad=True)
-        input3 = torch.randn(5, 10, requires_grad=True)
-        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
-            x1, x2, x3, reduction='elementwise_mean'), (input1, input2, input3)))
-        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='elementwise_mean'),
-                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3))
-
-        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
-            x1, x2, x3, reduction='sum'), (input1, input2, input3)))
-        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='sum'),
-                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, size_average=False))
-
-        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
-            x1, x2, x3, reduction='none'), (input1, input2, input3)))
-        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
-                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduce=False))
+            x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3)))
+        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
+                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
 
     def test_cosine_similarity(self):
         input1 = torch.randn(4, 4, requires_grad=True)
@@ -5774,8 +5753,8 @@
         input_size=(2, 3, 5, 5),
         target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
         desc='2d'
     ),
     dict(
@@ -5789,7 +5768,7 @@
     ),
     dict(
         module_name='NLLLoss',
-        constructor_args=(None, True, 1),
+        constructor_args=(None, None, 1),
         input_size=(2, 3, 5, 5),
         target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
@@ -5801,8 +5780,8 @@
         input_size=(2, 3, 5, 5, 2, 2),
         target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
         desc='higher_dim'
     ),
     dict(
@@ -5810,8 +5789,8 @@
         input_size=(2, 3, 5),
         target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
         reference_fn=lambda i, t, m:
-            loss_reference_fns['NLLLossNd'](i, t, size_average=get_size_average(m)),
-        check_no_size_average=True,
+            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
+        check_sum_reduction=True,
         desc='dim_is_3'
     ),
     dict(
@@ -5822,7 +5801,7 @@
     ),
     dict(
         module_name='PoissonNLLLoss',
-        constructor_args=(False, True, True),
+        constructor_args=(False,),
         input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
         target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
         desc='full_loss',  # with sterling approx
@@ -5839,16 +5818,17 @@
         input_fn=lambda: torch.rand(()).log(),
         target_fn=lambda: torch.rand(()),
         reference_fn=lambda i, t, m:
-            kldivloss_reference(i, t, get_size_average(m), reduce=True),
-        check_no_size_average=True,
+            kldivloss_reference(i, t, get_reduction(m)),
+        check_sum_reduction=True,
         desc='scalar',
     ),
     dict(
         module_name='MSELoss',
         input_size=(),
         target_size=(),
-        reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
-        check_no_size_average=True,
+        reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
+                                      (i.numel() if get_reduction(m) == 'elementwise_mean' else 1)),
+        check_sum_reduction=True,
         desc='scalar'
     ),
     dict(
@@ -5857,7 +5837,7 @@
         input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.rand(()).gt(0).double(),
         reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
-            (i.numel() if get_size_average(m) else 1),
+            (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
         desc='scalar_weights',
         check_gradgrad=False,
     ),
@@ -5867,15 +5847,15 @@
         input_size=(),
         target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1),
         desc='scalar_margin',
-        check_no_size_average=True,
+        check_sum_reduction=True,
     ),
     dict(
         module_name='SmoothL1Loss',
         input_size=(),
         target_size=(),
-        check_no_size_average=True,
+        check_sum_reduction=True,
         reference_fn=lambda i, t, m:
-            smoothl1loss_reference(i, t, size_average=get_size_average(m)),
+            smoothl1loss_reference(i, t, reduction=get_reduction(m)),
         desc='scalar',
     ),
     dict(
@@ -5884,9 +5864,9 @@
         input_fn=lambda: torch.randn(5, 10),
         target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
         reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() /
-            (i.numel() if get_size_average(m) else 1),
+            (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
         desc='weights',
-        check_no_size_average=True,
+        check_sum_reduction=True,
         check_gradgrad=False,
     ),
 ]
@@ -5897,7 +5877,7 @@
     return dict(
         fullname='PoissonNLLLLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(10, 10),
         pickle=False)
 
@@ -5907,7 +5887,7 @@
     return dict(
         fullname='BCELoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)),
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
         check_gradgrad=False,
@@ -5919,7 +5899,7 @@
     return dict(
         fullname='BCELoss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduce=False)),
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
         check_gradgrad=False,
@@ -5933,7 +5913,7 @@
         fullname='BCELoss_weights_no_reduce',
         constructor=wrap_functional(
             lambda i: F.binary_cross_entropy(i, t.type_as(i),
-                                             weight=weights.type_as(i), reduce=False)),
+                                             weight=weights.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
         check_gradgrad=False,
@@ -5947,7 +5927,7 @@
         fullname='BCELoss_weights_no_reduce_scalar',
         constructor=wrap_functional(
             lambda i: F.binary_cross_entropy(i, t.type_as(i),
-                                             weight=weights.type_as(i), reduce=False)),
+                                             weight=weights.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
         check_gradgrad=False,
@@ -5960,7 +5940,7 @@
     return dict(
         fullname='BCEWithLogitsLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
         check_gradgrad=False,
@@ -5973,7 +5953,7 @@
     return dict(
         fullname='BCEWithLogitsLoss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
         reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
         check_gradgrad=False,
@@ -5985,10 +5965,10 @@
     return dict(
         fullname='KLDivLoss_with_target_no_reduce',
         constructor=wrap_functional(
-            lambda t: F.kl_div(i.type_as(t), t, reduce=False)),
+            lambda t: F.kl_div(i.type_as(t), t, reduction='none')),
         input_fn=lambda: torch.rand(10, 10),
         reference_fn=lambda t, _:
-            loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduce=False),
+            loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduction='none'),
         pickle=False)
 
 
@@ -5997,10 +5977,10 @@
     return dict(
         fullname='KLDivLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.kl_div(i, t.type_as(i), reduce=False)),
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(10, 10).log(),
         reference_fn=lambda i, _:
-            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False),
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
         pickle=False)
 
 
@@ -6009,10 +5989,10 @@
     return dict(
         fullname='KLDivLoss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.kl_div(i, t.type_as(i), reduce=False)),
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.rand(()).log(),
         reference_fn=lambda i, _:
-            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduce=False),
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
         pickle=False)
 
 
@@ -6021,7 +6001,7 @@
     return dict(
         fullname='L1Loss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.l1_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(2, 3, 4),
         reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
         pickle=False)
@@ -6032,7 +6012,7 @@
     return dict(
         fullname='L1Loss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.l1_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(()),
         reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
         pickle=False)
@@ -6044,7 +6024,7 @@
     return dict(
         fullname='MSELoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.mse_loss(i, target.type_as(i), reduce=False)),
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
         input_size=input_size,
         reference_fn=lambda i, m: (i - target).pow(2),
         pickle=False)
@@ -6056,7 +6036,7 @@
     return dict(
         fullname='MSELoss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.mse_loss(i, target.type_as(i), reduce=False)),
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
         input_size=input_size,
         reference_fn=lambda i, m: (i - target).pow(2),
         pickle=False)
@@ -6064,7 +6044,7 @@
 
 def nllloss_no_reduce_test():
     t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
-    kwargs = {'reduce': False}
+    kwargs = {'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce',
         constructor=wrap_functional(
@@ -6077,7 +6057,7 @@
 
 def nllloss_no_reduce_ignore_index_test():
     t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
-    kwargs = {'ignore_index': 2, 'reduce': False}
+    kwargs = {'ignore_index': 2, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -6093,7 +6073,7 @@
     weight = torch.rand(10)
 
     def kwargs(i):
-        return {'weight': weight.type_as(i), 'reduce': False}
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
 
     return dict(
         fullname='NLLLoss_no_reduce_weights',
@@ -6110,7 +6090,7 @@
     weight = torch.rand(10)
 
     def kwargs(i):
-        return {'weight': weight.type_as(i), 'reduce': False,
+        return {'weight': weight.type_as(i), 'reduction': 'none',
                 'ignore_index': 2}
 
     return dict(
@@ -6128,7 +6108,7 @@
     weight = torch.rand(10)
 
     def kwargs(i):
-        return {'weight': weight.type_as(i), 'reduce': False,
+        return {'weight': weight.type_as(i), 'reduction': 'none',
                 'ignore_index': -1}
 
     return dict(
@@ -6143,7 +6123,7 @@
 
 def nllloss2d_no_reduce_test():
     t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
-    kwargs = {'reduce': False}
+    kwargs = {'reduction': 'none'}
     return dict(
         fullname='NLLLoss2d_no_reduce',
         constructor=wrap_functional(
@@ -6156,7 +6136,7 @@
 
 def nllloss2d_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
-    kwargs = {'ignore_index': 1, 'reduce': False}
+    kwargs = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss2d_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -6172,7 +6152,7 @@
     weight = torch.rand(3)
 
     def kwargs(i):
-        return {'weight': weight.type_as(i), 'reduce': False}
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
 
     return dict(
         fullname='NLLLoss2d_no_reduce_weights',
@@ -6186,7 +6166,7 @@
 
 def nlllossNd_no_reduce_test():
     t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
-    kwargs = {'reduce': False}
+    kwargs = {'reduction': 'none'}
     return dict(
         fullname='NLLLossNd_no_reduce',
         constructor=wrap_functional(
@@ -6199,7 +6179,7 @@
 
 def nlllossNd_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
-    kwargs = {'ignore_index': 1, 'reduce': False}
+    kwargs = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLossNd_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -6215,7 +6195,7 @@
     weight = torch.rand(3)
 
     def kwargs(i):
-        return {'weight': weight.type_as(i), 'reduce': False}
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
 
     return dict(
         fullname='NLLLossNd_no_reduce_weights',
@@ -6232,10 +6212,10 @@
     return dict(
         fullname='SmoothL1Loss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(2, 3, 4),
         reference_fn=lambda i, _:
-            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False),
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
         pickle=False)
 
 
@@ -6244,10 +6224,10 @@
     return dict(
         fullname='SmoothL1Loss_no_reduce_scalar',
         constructor=wrap_functional(
-            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(()),
         reference_fn=lambda i, _:
-            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduce=False),
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
         pickle=False)
 
 
@@ -6256,11 +6236,11 @@
     return dict(
         fullname='MultiLabelMarginLoss_1d_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
         input_fn=lambda: torch.randn(10),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6270,11 +6250,11 @@
     return dict(
         fullname='MultiLabelMarginLoss_index_neg',
         constructor=wrap_functional(
-            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6284,11 +6264,11 @@
     return dict(
         fullname='MultiLabelMarginLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)),
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6298,11 +6278,11 @@
     return dict(
         fullname='HingeEmbeddingLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(10),
         reference_fn=lambda i, _:
-            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
+        check_sum_reduction=True,
         pickle=False)
 
 
@@ -6311,11 +6291,11 @@
     return dict(
         fullname='HingeEmbeddingLoss_margin_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduce=False)),
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
         input_fn=lambda: torch.randn(10),
         reference_fn=lambda i, _:
-            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
+        check_sum_reduction=True,
         pickle=False)
 
 
@@ -6324,11 +6304,10 @@
     return dict(
         fullname='SoftMarginLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.soft_margin_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(5, 5),
         reference_fn=lambda i, _:
-            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
         pickle=False)
 
 
@@ -6337,11 +6316,9 @@
     return dict(
         fullname='MultiLabelSoftMarginLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduce=False)),
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
-        reference_fn=lambda i, m: (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) /
-                                   (i.numel() if get_size_average(m) else 1)),
-        check_no_size_average=True,
+        reference_fn=lambda i, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()),
         check_gradgrad=False,
         pickle=False)
 
@@ -6353,11 +6330,10 @@
         fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
         constructor=wrap_functional(
             lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
-                                                    weight=weights.type_as(i), reduce=False)),
+                                                    weight=weights.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
-        reference_fn=lambda i, m: (-((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights) /
-                                   (i.numel() if get_size_average(m) else 1)),
-        check_no_size_average=True,
+        reference_fn=lambda i, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6367,11 +6343,11 @@
     return dict(
         fullname='MultiMarginLoss_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)),
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6381,11 +6357,11 @@
     return dict(
         fullname='MultiMarginLoss_1d_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduce=False)),
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
         input_fn=lambda: torch.randn(10),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6395,11 +6371,11 @@
     return dict(
         fullname='MultiMarginLoss_p_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduce=False)),
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
         input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
         reference_fn=lambda i, _:
-            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduce=False),
-        check_no_size_average=True,
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6409,12 +6385,12 @@
     return dict(
         fullname='MultiMarginLoss_margin_no_reduce',
         constructor=wrap_functional(
-            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduce=False)),
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
         reference_fn=lambda i, _:
             loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
-                                                  margin=0.5, reduce=False),
-        check_no_size_average=True,
+                                                  margin=0.5, reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -6426,12 +6402,12 @@
         fullname='MultiMarginLoss_weights_no_reduce',
         constructor=wrap_functional(
             lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
-                                          reduce=False)),
+                                          reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
         reference_fn=lambda i, _:
             loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
-                                                  weight=weights, reduce=False),
-        check_no_size_average=True,
+                                                  weight=weights, reduction='none'),
+        check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
 
@@ -7734,18 +7710,18 @@
     test = NewCriterionTest(**test_params)
     decorator = test_params.pop('decorator', None)
     add_test(test, decorator)
-    if 'check_no_size_average' in test_params:
+    if 'check_sum_reduction' in test_params:
         desc = test_params.get('desc', None)
-        test_params['desc'] = 'no_size_average' if desc is None else desc + '_no_size_average'
+        test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction'
 
-        def gen_no_size_average_constructor(constructor):
-            def no_size_average_constructor(*args, **kwargs):
-                cons = constructor(*args, size_average=False, **kwargs)
+        def gen_sum_reduction_constructor(constructor):
+            def sum_reduction_constructor(*args, **kwargs):
+                cons = constructor(*args, reduction='sum', **kwargs)
                 return cons
-            no_size_average_constructor.__name__ = constructor.__name__
-            return no_size_average_constructor
+            sum_reduction_constructor.__name__ = constructor.__name__
+            return sum_reduction_constructor
 
-        test_params['constructor'] = gen_no_size_average_constructor(test_params['constructor'])
+        test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
         test = NewCriterionTest(**test_params)
         add_test(test, decorator)
 
diff --git a/torch/legacy/nn/AbsCriterion.py b/torch/legacy/nn/AbsCriterion.py
index 4a2ea68..66f7615 100644
--- a/torch/legacy/nn/AbsCriterion.py
+++ b/torch/legacy/nn/AbsCriterion.py
@@ -18,7 +18,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -31,6 +31,6 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/ClassNLLCriterion.py b/torch/legacy/nn/ClassNLLCriterion.py
index 50ddcfd..33c28e5 100644
--- a/torch/legacy/nn/ClassNLLCriterion.py
+++ b/torch/legacy/nn/ClassNLLCriterion.py
@@ -25,7 +25,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.weights,
             self.total_weight_tensor,
             self.ignore_index,
@@ -44,7 +44,7 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.weights,
             self.total_weight_tensor,
             self.ignore_index,
diff --git a/torch/legacy/nn/ClassSimplexCriterion.py b/torch/legacy/nn/ClassSimplexCriterion.py
index b28ce67..1de5851 100644
--- a/torch/legacy/nn/ClassSimplexCriterion.py
+++ b/torch/legacy/nn/ClassSimplexCriterion.py
@@ -81,7 +81,7 @@
             input,
             self._target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -95,7 +95,7 @@
             self._target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
 
diff --git a/torch/legacy/nn/DistKLDivCriterion.py b/torch/legacy/nn/DistKLDivCriterion.py
index 8c18cf1..5aa1756 100644
--- a/torch/legacy/nn/DistKLDivCriterion.py
+++ b/torch/legacy/nn/DistKLDivCriterion.py
@@ -19,7 +19,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -33,6 +33,6 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/MSECriterion.py b/torch/legacy/nn/MSECriterion.py
index 2422e07..2079d36 100644
--- a/torch/legacy/nn/MSECriterion.py
+++ b/torch/legacy/nn/MSECriterion.py
@@ -18,7 +18,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -32,6 +32,6 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/MultiLabelMarginCriterion.py b/torch/legacy/nn/MultiLabelMarginCriterion.py
index 1de12bf..9ca2a23 100644
--- a/torch/legacy/nn/MultiLabelMarginCriterion.py
+++ b/torch/legacy/nn/MultiLabelMarginCriterion.py
@@ -21,7 +21,7 @@
             target,
             self.output_tensor,
             self.isTarget,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -36,6 +36,6 @@
             implicit_gradOutput,
             self.gradInput,
             self.isTarget,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/MultiMarginCriterion.py b/torch/legacy/nn/MultiMarginCriterion.py
index 26b9cff..cc9835c 100644
--- a/torch/legacy/nn/MultiMarginCriterion.py
+++ b/torch/legacy/nn/MultiMarginCriterion.py
@@ -26,7 +26,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.p,
             self.weights,
             self.margin,
@@ -43,7 +43,7 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.p,
             self.weights,
             self.margin,
diff --git a/torch/legacy/nn/SmoothL1Criterion.py b/torch/legacy/nn/SmoothL1Criterion.py
index c02e7a2..714d0b6 100644
--- a/torch/legacy/nn/SmoothL1Criterion.py
+++ b/torch/legacy/nn/SmoothL1Criterion.py
@@ -18,7 +18,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -31,6 +31,6 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/SoftMarginCriterion.py b/torch/legacy/nn/SoftMarginCriterion.py
index e56d871..4bfa371 100644
--- a/torch/legacy/nn/SoftMarginCriterion.py
+++ b/torch/legacy/nn/SoftMarginCriterion.py
@@ -18,7 +18,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -31,6 +31,6 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/legacy/nn/SpatialClassNLLCriterion.py b/torch/legacy/nn/SpatialClassNLLCriterion.py
index 382cfea..8a7e15c 100644
--- a/torch/legacy/nn/SpatialClassNLLCriterion.py
+++ b/torch/legacy/nn/SpatialClassNLLCriterion.py
@@ -23,7 +23,7 @@
             input,
             target,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.weights,
             self.total_weight_tensor,
             self.ignore_index,
@@ -40,7 +40,7 @@
             target,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
             self.weights,
             self.total_weight_tensor,
             self.ignore_index,
diff --git a/torch/legacy/nn/WeightedMSECriterion.py b/torch/legacy/nn/WeightedMSECriterion.py
index 4e03439..2f0da29 100644
--- a/torch/legacy/nn/WeightedMSECriterion.py
+++ b/torch/legacy/nn/WeightedMSECriterion.py
@@ -29,7 +29,7 @@
             input,
             self.buffer,
             self.output_tensor,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         self.output = self.output_tensor[0].item()
         return self.output
@@ -50,6 +50,6 @@
             self.buffer,
             implicit_gradOutput,
             self.gradInput,
-            _Reduction.legacy_get_enum(self.sizeAverage, True),
+            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
         )
         return self.gradInput
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 379cc11..a6c00b1 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -31,8 +31,10 @@
 
     # In order to support previous versions, accept boolean size_average and reduce
     # and convert them into the new constants for now
+
+    # We use these functions in torch/legacy as well, in which case we'll silence the warning
     @staticmethod
-    def legacy_get_string(size_average, reduce):
+    def legacy_get_string(size_average, reduce, emit_warning=True):
         warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
 
         if size_average is None:
@@ -41,18 +43,18 @@
             reduce = True
 
         if size_average and reduce:
-            warnings.warn(warning.format('elementwise_mean'))
-            return 'elementwise_mean'
+            ret = 'elementwise_mean'
         elif reduce:
-            warnings.warn(warning.format('sum'))
-            return 'sum'
+            ret = 'sum'
         else:
-            warnings.warn(warning.format('none'))
-            return 'none'
+            ret = 'none'
+        if emit_warning:
+            warnings.warn(warning.format(ret))
+        return ret
 
     @staticmethod
-    def legacy_get_enum(size_average, reduce):
-        return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce))
+    def legacy_get_enum(size_average, reduce, emit_warning=True):
+        return _Reduction.get_enum(_Reduction.legacy_get_string(size_average, reduce, emit_warning))
 
 
 conv1d = _add_docstr(torch.conv1d, r"""