add reduce=True arg to HingeEmbeddingLoss (#5130)

* add reduce=True arg to HingeEmbeddingLoss

* pass arg to super constructor in HingeEmbeddingLoss

* make HingeEmbeddingLoss reference fn work with legacy tests
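
As a quick sketch of the resulting behavior (tensor values here are
arbitrary and names illustrative; the API itself is the one added by this
patch):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    input = Variable(torch.randn(10))
    # hinge embedding loss expects targets of 1 or -1
    target = Variable(torch.randn(10).gt(0).float().mul_(2).sub(1))

    # default (reduce=True, size_average=True): scalar, mean over all elements
    loss_mean = F.hinge_embedding_loss(input, target)

    # reduce=True, size_average=False: scalar, sum over all elements
    loss_sum = F.hinge_embedding_loss(input, target, size_average=False)

    # reduce=False: per-element losses, same shape as input
    # (size_average is ignored in this case)
    loss_none = F.hinge_embedding_loss(input, target, reduce=False)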
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index a4204e6..f1027b1 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -4,15 +4,17 @@
 
 namespace at { namespace native {
 
-Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, bool size_average) {
+Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, bool size_average, bool reduce) {
   auto zeros = at::zeros_like(self);
   auto margin_clamp = (margin - self).clamp_min_(0);
   auto output_margin = at::where(target != 1, margin_clamp, zeros);
   auto output_self = at::where(target != -1, self, zeros);
-  auto output = (output_margin + output_self).sum();
+  auto output = output_margin + output_self;
 
-  if (size_average) {
-    output = output / self.numel();
+  if (reduce && size_average) {
+    return output.sum() / self.numel();
+  } else if (reduce) {
+    return output.sum();
   }
   return output;
 }
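
Note on the ATen change: for valid targets (every entry 1 or -1) the two
`at::where` selections are complementary, so their sum is the usual
piecewise hinge embedding loss, and `reduce=False` simply skips the final
sum. A minimal Python mirror of the unreduced computation (illustrative
only; the function name is ours, and ±1 targets are assumed):

    import torch

    def hinge_embedding_unreduced(input, target, margin=1.0):
        # margin term where target != 1, identity term where target != -1;
        # the two masks are complementary when every target is 1 or -1
        zeros = torch.zeros_like(input)
        margin_clamp = (margin - input).clamp(min=0)
        output_margin = torch.where(target != 1, margin_clamp, zeros)
        output_self = torch.where(target != -1, input, zeros)
        return output_margin + output_self

For ±1 targets this equals `torch.where(target == 1, input, margin_clamp)`,
which is the single-where form the test reference function below uses.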
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 861ab25..9c4efe6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -183,7 +183,7 @@
 - func: expand_as(Tensor self, Tensor other) -> Tensor
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
-- func: hinge_embedding_loss(Tensor self, Tensor target, double margin, bool size_average) -> Tensor
+- func: hinge_embedding_loss(Tensor self, Tensor target, double margin, bool size_average, bool reduce) -> Tensor
   variants: function
 
 - func: ger(Tensor self, Tensor vec2) -> Tensor
diff --git a/test/common_nn.py b/test/common_nn.py
index 015835e..a1542de 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -482,12 +482,29 @@
     return output / dim
 
 
+def hingeembeddingloss_reference(input, target, margin=1.0, size_average=True, reduce=True):
+    # legacy tests pass plain tensors; wrap them so torch.where works
+    if not isinstance(input, Variable):
+        input = Variable(input)
+        target = Variable(target)
+
+    margin_clamp = (margin - input).clamp(min=0).type_as(input)
+    output = torch.where(target == 1, input, margin_clamp)
+
+    if reduce and size_average:
+        return output.mean()
+    elif reduce:
+        return output.sum()
+    return output
+
+
 loss_reference_fns = {
     'KLDivLoss': kldivloss_reference,
     'NLLLoss': nllloss_reference,
     'NLLLossNd': nlllossNd_reference,
     'SmoothL1Loss': smoothl1loss_reference,
     'MultiLabelMarginLoss': multilabelmarginloss_reference,
+    'HingeEmbeddingLoss': hingeembeddingloss_reference,
 }
 
 
@@ -645,12 +662,17 @@
         module_name='HingeEmbeddingLoss',
         input_size=(10,),
         target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
+        reference_fn=lambda i, t, m:
+            hingeembeddingloss_reference(i, t, size_average=get_size_average(m)),
+        check_no_size_average=True,
     ),
     dict(
         module_name='HingeEmbeddingLoss',
         constructor_args=(0.5,),
         input_size=(10,),
         target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
+        reference_fn=lambda i, t, m:
+            hingeembeddingloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
         desc='margin',
         check_no_size_average=True,
     ),
diff --git a/test/test_nn.py b/test/test_nn.py
index 7ce609e..7b5f43f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4811,6 +4811,32 @@
         pickle=False)
 
 
+def hingeembeddingloss_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduce=False)),
+        input_fn=lambda: torch.randn(10),
+        reference_fn=lambda i, _:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduce=False),
+        check_no_size_average=True,
+        pickle=False)
+
+
+def hingeembeddingloss_margin_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduce=False)),
+        input_fn=lambda: torch.randn(10),
+        reference_fn=lambda i, _:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduce=False),
+        check_no_size_average=True,
+        pickle=False)
+
+
 new_module_tests = [
     poissonnllloss_no_reduce_test(),
     bceloss_no_reduce_test(),
@@ -4841,6 +4867,8 @@
     multilabelmarginloss_1d_no_reduce_test(),
     multilabelmarginloss_index_neg_test(),
     multilabelmarginloss_no_reduce_test(),
+    hingeembeddingloss_no_reduce_test(),
+    hingeembeddingloss_margin_no_reduce_test(),
     dict(
         module_name='BatchNorm1d',
         constructor_args=(10,),
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 212c614..4a1d09a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1512,12 +1512,12 @@
     return _functions.loss.MarginRankingLoss.apply(input1, input2, target, margin, size_average)
 
 
-def hinge_embedding_loss(input, target, margin=1.0, size_average=True):
-    """hinge_embedding_loss(input, target, margin=1.0, size_average=True) -> Variable
+def hinge_embedding_loss(input, target, margin=1.0, size_average=True, reduce=True):
+    """hinge_embedding_loss(input, target, margin=1.0, size_average=True, reduce=True) -> Variable
 
     See :class:`~torch.nn.HingeEmbeddingLoss` for details.
     """
-    return torch._C._VariableFunctions.hinge_embedding_loss(input, target, margin, size_average)
+    return torch._C._VariableFunctions.hinge_embedding_loss(input, target, margin, size_average, reduce)
 
 
 multilabel_margin_loss = _add_docstr(torch._C._nn.multilabel_margin_loss, r"""
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index f82599c..1d91e92 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -530,22 +530,31 @@
 
     where :math:`L = \{l_1,\dots,l_N\}^\top`.
 
-    `x` and `y` can be of arbitrary shapes with a total of `n` elements each.
-    The sum operation operates over all the elements.
+    Args:
+        margin (float, optional): Has a default value of `1`.
+        size_average (bool, optional): By default, the losses are averaged over
+            observations for each minibatch. However, if the field :attr:`size_average`
+            is set to ``False``, the losses are instead summed for each minibatch.
+            Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+            observations for each minibatch depending on :attr:`size_average`. When
+            :attr:`reduce` is ``False``, returns a loss per batch element instead and
+            ignores :attr:`size_average`. Default: ``True``
 
-    The division by `n` can be avoided if one sets the internal
-    variable `size_average=False`.
-
-    The `margin` has a default value of `1`, or can be set in the constructor.
+    Shape:
+        - Input: Tensor of arbitrary shape. The sum operation operates over all the elements.
+        - Target: Same shape as input.
+        - Output: scalar. If :attr:`reduce` is ``False``, then same shape as the input.
     """
 
-    def __init__(self, margin=1.0, size_average=True):
-        super(HingeEmbeddingLoss, self).__init__()
+    def __init__(self, margin=1.0, size_average=True, reduce=True):
+        super(HingeEmbeddingLoss, self).__init__(size_average)
         self.margin = margin
-        self.size_average = size_average
+        self.reduce = reduce
 
     def forward(self, input, target):
-        return F.hinge_embedding_loss(input, target, self.margin, self.size_average)
+        return F.hinge_embedding_loss(input, target, self.margin, self.size_average,
+                                      self.reduce)
 
 
 class MultiLabelMarginLoss(_Loss):