add reduce=True arg to HingeEmbeddingLoss (#5130)
* add reduce=True arg to HingeEmbeddingLoss
* pass size_average to the super constructor in HingeEmbeddingLoss
* make the HingeEmbeddingLoss reference fn work with the legacy tests
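With this change, the unreduced loss is available from both the functional and the module interface. A minimal usage sketch (illustrative only, not part of the patch; it assumes the Variable-based API of this release, and margin=0.5 below is just an example value):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    input = Variable(torch.randn(10))
    # hinge embedding loss expects targets of +1 or -1
    target = Variable(torch.randn(10).gt(0).float().mul_(2).sub(1))

    # default behaviour is unchanged: reduce=True, size_average=True -> scalar mean
    mean_loss = F.hinge_embedding_loss(input, target)

    # reduce=False returns one loss value per element, same shape as the input
    per_element = F.hinge_embedding_loss(input, target, reduce=False)
    assert per_element.size() == input.size()

    # the module form accepts the same keyword argument
    per_element_mod = nn.HingeEmbeddingLoss(margin=0.5, reduce=False)(input, target)
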
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index a4204e6..f1027b1 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -4,15 +4,17 @@
namespace at { namespace native {
-Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, bool size_average) {
+Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, bool size_average, bool reduce) {
auto zeros = at::zeros_like(self);
auto margin_clamp = (margin - self).clamp_min_(0);
auto output_margin = at::where(target != 1, margin_clamp, zeros);
auto output_self = at::where(target != -1, self, zeros);
- auto output = (output_margin + output_self).sum();
- if (size_average) {
- output = output / self.numel();
+ auto output = output_margin + output_self;
+ if (reduce && size_average) {
+ return output.sum() / self.numel();
+ } else if (reduce) {
+ return output.sum();
}
return output;
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 861ab25..9c4efe6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -183,7 +183,7 @@
- func: expand_as(Tensor self, Tensor other) -> Tensor
variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
-- func: hinge_embedding_loss(Tensor self, Tensor target, double margin, bool size_average) -> Tensor
+- func: hinge_embedding_loss(Tensor self, Tensor target, double margin, bool size_average, bool reduce) -> Tensor
variants: function
- func: ger(Tensor self, Tensor vec2) -> Tensor
diff --git a/test/common_nn.py b/test/common_nn.py
index 015835e..a1542de 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -482,12 +482,29 @@
return output / dim
+def hingeembeddingloss_reference(input, target, margin=1.0, size_average=True, reduce=True):
+    # needed for legacy tests
+    if not isinstance(input, Variable):
+        input = Variable(input)
+        target = Variable(target)
+
+    margin_clamp = (margin - input).clamp(min=0).type_as(input)
+    output = torch.where(target == 1, input, margin_clamp)
+
+    if reduce and size_average:
+        return output.mean()
+    elif reduce:
+        return output.sum()
+    return output
+
+
loss_reference_fns = {
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
'NLLLossNd': nlllossNd_reference,
'SmoothL1Loss': smoothl1loss_reference,
'MultiLabelMarginLoss': multilabelmarginloss_reference,
+ 'HingeEmbeddingLoss': hingeembeddingloss_reference,
}
@@ -645,12 +662,17 @@
module_name='HingeEmbeddingLoss',
input_size=(10,),
target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
+ reference_fn=lambda i, t, m:
+ hingeembeddingloss_reference(i, t, size_average=get_size_average(m)),
+ check_no_size_average=True,
),
dict(
module_name='HingeEmbeddingLoss',
constructor_args=(0.5,),
input_size=(10,),
target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
+ reference_fn=lambda i, t, m:
+ hingeembeddingloss_reference(i, t, margin=0.5, size_average=get_size_average(m)),
desc='margin',
check_no_size_average=True,
),
diff --git a/test/test_nn.py b/test/test_nn.py
index 7ce609e..7b5f43f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4811,6 +4811,32 @@
pickle=False)
+def hingeembeddingloss_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduce=False)),
+        input_fn=lambda: torch.randn(10),
+        reference_fn=lambda i, _:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduce=False),
+        check_no_size_average=True,
+        pickle=False)
+
+
+def hingeembeddingloss_margin_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduce=False)),
+        input_fn=lambda: torch.randn(10),
+        reference_fn=lambda i, _:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduce=False),
+        check_no_size_average=True,
+        pickle=False)
+
+
new_module_tests = [
poissonnllloss_no_reduce_test(),
bceloss_no_reduce_test(),
@@ -4841,6 +4867,8 @@
multilabelmarginloss_1d_no_reduce_test(),
multilabelmarginloss_index_neg_test(),
multilabelmarginloss_no_reduce_test(),
+ hingeembeddingloss_no_reduce_test(),
+ hingeembeddingloss_margin_no_reduce_test(),
dict(
module_name='BatchNorm1d',
constructor_args=(10,),
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 212c614..4a1d09a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1512,12 +1512,12 @@
return _functions.loss.MarginRankingLoss.apply(input1, input2, target, margin, size_average)
-def hinge_embedding_loss(input, target, margin=1.0, size_average=True):
- """hinge_embedding_loss(input, target, margin=1.0, size_average=True) -> Variable
+def hinge_embedding_loss(input, target, margin=1.0, size_average=True, reduce=True):
+ """hinge_embedding_loss(input, target, margin=1.0, size_average=True, reduce=True) -> Variable
See :class:`~torch.nn.HingeEmbeddingLoss` for details.
"""
- return torch._C._VariableFunctions.hinge_embedding_loss(input, target, margin, size_average)
+ return torch._C._VariableFunctions.hinge_embedding_loss(input, target, margin, size_average, reduce)
multilabel_margin_loss = _add_docstr(torch._C._nn.multilabel_margin_loss, r"""
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index f82599c..1d91e92 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -530,22 +530,31 @@
where :math:`L = \{l_1,\dots,l_N\}^\top`.
- `x` and `y` can be of arbitrary shapes with a total of `n` elements each.
- The sum operation operates over all the elements.
+ Args:
+ margin (float, optional): Has a default value of `1`.
+ size_average (bool, optional): By default, the losses are averaged over
+ observations for each minibatch. However, if the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch.
+ Default: ``True``
+ reduce (bool, optional): By default, the losses are averaged or summed over
+ observations for each minibatch depending on :attr:`size_average`. When
+ :attr:`reduce` is ``False``, returns a loss per batch element instead and
+ ignores :attr:`size_average`. Default: ``True``
- The division by `n` can be avoided if one sets the internal
- variable `size_average=False`.
-
- The `margin` has a default value of `1`, or can be set in the constructor.
+ Shape:
+ - Input: Tensor of arbitrary shape. The sum operation operates over all the elements.
+ - Target: Same shape as input.
+ - Output: scalar. If :attr:`reduce` is ``False``, then same shape as the input.
"""
- def __init__(self, margin=1.0, size_average=True):
- super(HingeEmbeddingLoss, self).__init__()
+ def __init__(self, margin=1.0, size_average=True, reduce=True):
+ super(HingeEmbeddingLoss, self).__init__(size_average)
self.margin = margin
- self.size_average = size_average
+ self.reduce = reduce
def forward(self, input, target):
- return F.hinge_embedding_loss(input, target, self.margin, self.size_average)
+ return F.hinge_embedding_loss(input, target, self.margin, self.size_average,
+ self.reduce)
class MultiLabelMarginLoss(_Loss):
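As a quick sanity check of the unreduced semantics (again a sketch, not part of the patch; it mirrors hingeembeddingloss_reference above):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    x = Variable(torch.randn(10))
    y = Variable(torch.randn(10).gt(0).float().mul_(2).sub(1))  # targets in {-1, +1}

    loss = F.hinge_embedding_loss(x, y, reduce=False)

    # per element: x_n where y_n == 1, max(0, margin - x_n) where y_n == -1
    manual = torch.where(y == 1, x, (1.0 - x).clamp(min=0))
    print((loss - manual).abs().max())  # expected to be numerically zero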