Add reduce arg to BCELoss (#4231)

* Add reduce arg to BCELoss

* Fix test precision

* Add reduce keyword for BCELoss in derivatives.yaml
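
For context, a minimal sketch of the new keyword's intended behavior (shapes and values illustrative; Variable-era API assumed):

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F

input = Variable(torch.rand(3, 4).clamp_(1e-2, 1 - 1e-2))
target = Variable(torch.rand(3, 4).gt(0.5).float())

# reduce=True (default): a single-element loss, averaged when size_average=True.
loss = F.binary_cross_entropy(input, target)

# reduce=False: the unreduced loss, one entry per input/target element.
loss_elems = F.binary_cross_entropy(input, target, reduce=False)

print(loss_elems.size())                         # (3, 4), same shape as input
print((loss_elems.mean() - loss).abs().data[0])  # ~0: mean of elements matches reduced loss
```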
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 3df33ee..08d87b1 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -1,6 +1,6 @@
 # Loss functions
 
-- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true)
+- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight={}, bool size_average=true, bool reduce=true)
   cname: BCECriterion
 
 - name: kl_div(Tensor self, Tensor target, bool size_average=true, bool reduce=true)
diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu
index 04218dc..ccb4000 100644
--- a/aten/src/THCUNN/BCECriterion.cu
+++ b/aten/src/THCUNN/BCECriterion.cu
@@ -33,6 +33,22 @@
 };
 
 template <typename Dtype, typename Acctype>
+struct bce_updateOutput_no_reduce_functor
+{
+  __forceinline__ __host__ __device__
+  void operator()(
+      const Dtype *input,
+      const Dtype *target,
+      Dtype *output)
+  {
+    assert(*input >= 0. && *input <= 1.);
+    *output = ScalarConvert<Acctype, Dtype>::to(
+        -(*target * THCNumerics<Acctype>::log(*input + eps<Acctype>()) +
+          (Acctype(1) - *target) * THCNumerics<Acctype>::log(Acctype(1) - *input + eps<Acctype>())));
+  }
+};
+
+template <typename Dtype, typename Acctype>
 struct bce_functor_weights
 {
   template <class Tuple>
@@ -43,7 +59,22 @@
     Dtype t = thrust::get<1>(x);
     Dtype w = thrust::get<2>(x);
     assert(input >= 0. && input <= 1.);
-    return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) + (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
+    return - w * (t * THCNumerics<Acctype>::log(input + eps<Acctype>()) +
+        (Acctype(1) - t) * THCNumerics<Acctype>::log(Acctype(1) - input + eps<Acctype>()));
+  }
+};
+
+template <typename Dtype, typename Acctype>
+struct bce_updateGradInput_no_reduce_functor
+{
+  __forceinline__ __host__ __device__
+  void operator()(
+      const Dtype *x,
+      const Dtype *t,
+      Dtype *gradInput)
+  {
+    *gradInput = ScalarConvert<Acctype, Dtype>::to(
+        - (*t - *x) / ((Acctype(1) - *x + eps<Acctype>()) * (*x + eps<Acctype>())));
   }
 };
 
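
As a sanity check on the two functors added above, here is a scalar Python transcription of the pointwise forward value and its input gradient, verified by finite differences (illustration only; `eps` mirrors the kernels' epsilon):

```python
import math

eps = 1e-12
x, t = 0.7, 1.0

# Forward, as in bce_updateOutput_no_reduce_functor:
out = -(t * math.log(x + eps) + (1 - t) * math.log(1 - x + eps))

# Backward, as in bce_updateGradInput_no_reduce_functor:
grad = -(t - x) / ((1 - x + eps) * (x + eps))

# Finite-difference check of the analytic gradient.
h = 1e-6
out_h = -(t * math.log(x + h + eps) + (1 - t) * math.log(1 - (x + h) + eps))
print(abs(grad - (out_h - out) / h) < 1e-4)  # True
```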
diff --git a/aten/src/THCUNN/generic/BCECriterion.cu b/aten/src/THCUNN/generic/BCECriterion.cu
index 4d9988b..e98f1b0 100644
--- a/aten/src/THCUNN/generic/BCECriterion.cu
+++ b/aten/src/THCUNN/generic/BCECriterion.cu
@@ -2,19 +2,32 @@
 #define THC_GENERIC_FILE "generic/BCECriterion.cu"
 #else
 
+#include "THCApply.cuh"
+
 void THNN_(BCECriterion_updateOutput)(
            THCState *state,
            THCTensor *input,
            THCTensor *target,
            THCTensor *output,
            bool sizeAverage,
-           THCTensor *weights)
+           THCTensor *weights,
+           bool reduce)
 {
   THCUNN_check_nElement(state, input, target);
   THCUNN_check_nElement(state, input, weights);
-  THCTensor_(resize1d)(state, output, 1);
   THCUNN_assertSameGPU(state, 3, input, target, weights);
 
+  if (!reduce) {
+    THCTensor_(resizeAs)(state, output, input);
+    THC_pointwiseApply3(state, input, target, output,
+        bce_updateOutput_no_reduce_functor<real, accreal>());
+    if (weights) {
+      THCTensor_(cmul)(state, output, output, weights);
+    }
+    return;
+  }
+
+  THCTensor_(resize1d)(state, output, 1);
   ptrdiff_t size = THCTensor_(nElement)(state, input);
 
   input = THCTensor_(newContiguous)(state, input);
@@ -58,22 +71,37 @@
            THCState *state,
            THCTensor *input,
            THCTensor *target,
+           THCTensor *gradOutput,
            THCTensor *gradInput,
            bool sizeAverage,
-           THCTensor *weights)
+           THCTensor *weights,
+           bool reduce)
 {
   THCUNN_check_nElement(state, input, target);
   THCUNN_check_nElement(state, input, weights);
   THCUNN_assertSameGPU(state, 4, input, target, gradInput, weights);
 
+  THCTensor_(resizeAs)(state, gradInput, input);
+
+  if (!reduce) {
+    THCUNN_check_nElement(state, gradOutput, input);
+    THC_pointwiseApply3(state, input, target, gradInput,
+        bce_updateGradInput_no_reduce_functor<real, accreal>());
+    THCTensor_(cmul)(state, gradInput, gradInput, gradOutput);
+    if (weights) {
+      THCTensor_(cmul)(state, gradInput, gradInput, weights);
+    }
+    return;
+  }
+
+  THCUNN_check_dim_size(state, gradOutput, 1, 0, 1);
+
   ptrdiff_t size = THCTensor_(nElement)(state, input);
-  real norm = ScalarConvert<accreal, real>::to(sizeAverage ? accreal(1)/size : accreal(1));
+  real norm = ScalarConvert<accreal, real>::to((sizeAverage ? accreal(1)/size : accreal(1)) * THCTensor_(get1d)(state, gradOutput, 0));
 
   input = THCTensor_(newContiguous)(state, input);
   target = THCTensor_(newContiguous)(state, target);
 
-  THCTensor_(resizeAs)(state, gradInput, input);
-
   thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
   thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
   thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));
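
In the reduced path, gradOutput is a single element, so it folds into the scalar `norm` before the elementwise pass runs. A rough Python equivalent of that scaling (a sketch, not the actual kernel):

```python
import torch

EPS = 1e-12

def bce_grad_reduced(x, t, grad_output, size_average=True):
    # 1/N (when averaging) and the incoming scalar gradient combine into one factor.
    norm = (1.0 / x.numel() if size_average else 1.0) * grad_output[0]
    return -(t - x) / ((1 - x + EPS) * (x + EPS)) * norm

x = torch.rand(4).clamp_(1e-2, 1 - 1e-2)
t = torch.Tensor([0, 1, 1, 0])
print(bce_grad_reduced(x, t, torch.ones(1)))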
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index e86583b..7f5002f 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -66,15 +66,18 @@
                   THCTensor *target,
                   THCTensor *output,
                   bool sizeAverage,
-                  THCTensor *weights);        // [OPTIONAL]
+                  THCTensor *weights,         // [OPTIONAL]
+                  bool reduce);
 
 TH_API void THNN_(BCECriterion_updateGradInput)(
                   THCState *state,
                   THCTensor *input,
                   THCTensor *target,
+                  THCTensor *gradOutput,
                   THCTensor *gradInput,
                   bool sizeAverage,
-                  THCTensor *weights);        // [OPTIONAL]
+                  THCTensor *weights,         // [OPTIONAL]
+                  bool reduce);
 
 TH_API void THNN_(ClassNLLCriterion_updateOutput)(
                   THCState *state,
diff --git a/aten/src/THNN/generic/BCECriterion.c b/aten/src/THNN/generic/BCECriterion.c
index b668370..1f69c31 100644
--- a/aten/src/THNN/generic/BCECriterion.c
+++ b/aten/src/THNN/generic/BCECriterion.c
@@ -4,16 +4,38 @@
 
 #define EPS 1e-12
 
-void THNN_(BCECriterion_updateOutput)(THNNState *state, THTensor *input,
-				      THTensor *target, THTensor *output,
-				      bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateOutput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *target,
+    THTensor *output,
+    bool sizeAverage,
+    THTensor *weights,
+    bool reduce)
 {
   THNN_CHECK_NELEMENT(input, target);
   THNN_CHECK_NELEMENT(input, weights);
+
+  if (!reduce) {
+    THTensor_(resizeAs)(output, input);
+    TH_TENSOR_APPLY3(real, input, real, target, real, output,
+        real x = *input_data;
+        real y = *target_data;
+        THAssertMsg(x >= 0. && x <= 1.,
+          "input value should be between 0~1, but got %f",
+          (double) x);
+        *output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y));
+    );
+    if (weights) {
+      THTensor_(cmul)(output, output, weights);
+    }
+    return;
+  }
+
 	THTensor_(resize1d)(output, 1);
   real sum = 0;
 
-  if(weights)
+  if (weights) {
     TH_TENSOR_APPLY3(real, input, real, target, real, weights,
       real x = *input_data;
       real y = *target_data;
@@ -22,8 +44,8 @@
         "input value should be between 0~1, but got %f",
 		  (double) x);
       sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w;
-    )
-  else
+    );
+  } else {
     TH_TENSOR_APPLY2(real, input, real, target,
       real x = *input_data;
       real y = *target_data;
@@ -32,6 +54,7 @@
 		  (double) x);
       sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y);
     );
+  }
 
 
   if (sizeAverage)
@@ -40,21 +63,45 @@
   THTensor_(set1d)(output, 0, sum);
 }
 
-void THNN_(BCECriterion_updateGradInput)(THNNState *state, THTensor *input,
-					 THTensor *target, THTensor *gradInput,
-					 bool sizeAverage, THTensor *weights)
+void THNN_(BCECriterion_updateGradInput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *target,
+    THTensor *gradOutput,
+    THTensor *gradInput,
+    bool sizeAverage,
+    THTensor *weights,
+    bool reduce)
 {
   THNN_CHECK_NELEMENT(input, target);
   THNN_CHECK_NELEMENT(input, weights);
-
-  real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
-
   THTensor_(resizeAs)(gradInput, input);
 
+  if (!reduce) {
+    THNN_CHECK_NELEMENT(gradOutput, input);
+    TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
+      real x = *input_data;
+      real y = *target_data;
+      *gradInput_data = -(y - x) / ((1. - x + EPS) * (x + EPS));
+    );
+
+    if (weights) {
+      TH_TENSOR_APPLY3(real, gradInput, real, weights, real, gradOutput,
+        *gradInput_data = *gradInput_data * *weights_data * *gradOutput_data;
+      );
+    } else {
+      THTensor_(cmul)(gradInput, gradInput, gradOutput);
+    }
+    return;
+  }
+
+  THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1);
+  real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
+
   TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
     real x = *input_data;
     real y = *target_data;
-    *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS));
+    *gradInput_data = - norm * (y - x) / ((1. - x + EPS) * (x + EPS)) * THTensor_fastGet1d(gradOutput, 0);
   );
 
   if(weights)
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 32fe59a..93436df 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -34,14 +34,17 @@
           THTensor *target,
           THTensor *output,
           bool sizeAverage,
-          THTensor *weights);          // [OPTIONAL]
+          THTensor *weights,           // [OPTIONAL]
+          bool reduce);
 TH_API void THNN_(BCECriterion_updateGradInput)(
           THNNState *state,
           THTensor *input,
           THTensor *target,
+          THTensor *gradOutput,
           THTensor *gradInput,
           bool sizeAverage,
-          THTensor *weights);          // [OPTIONAL]
+          THTensor *weights,           // [OPTIONAL]
+          bool reduce);
 
 TH_API void THNN_(ClassNLLCriterion_updateOutput)(
           THNNState *state,            // library's state
diff --git a/test/common_nn.py b/test/common_nn.py
index c0a2e1b..3b926ad 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -400,6 +400,8 @@
         module_name='BCELoss',
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
+            (i.numel() if get_size_average(m) else 1),
         check_gradgrad=False,
     ),
     dict(
@@ -407,6 +409,8 @@
         constructor_args_fn=lambda: (torch.rand(10),),
         input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
         target_fn=lambda: torch.randn(15, 10).gt(0).double(),
+        reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
+            (i.numel() if get_size_average(m) else 1),
         desc='weights',
         check_gradgrad=False,
     ),
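
The reference functions above are the closed-form loss; the same comparison can be run standalone (a sketch of the size_average=True case, in double precision to sidestep the test-precision issue fixed in this PR):

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

i = torch.rand(15, 10).double().clamp_(1e-2, 1 - 1e-2)
t = torch.randn(15, 10).gt(0).double()

ref = -(t * i.log() + (1 - t) * (1 - i).log()).sum() / i.numel()
out = F.binary_cross_entropy(Variable(i), Variable(t)).data[0]
print(abs(ref - out) < 1e-6)  # True
```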
diff --git a/test/test_nn.py b/test/test_nn.py
index ef65b18..f12ba8c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3962,6 +3962,32 @@
         pickle=False)
 
 
+def bceloss_no_reduce_test():
+    t = torch.randn(15, 10).gt(0).double()
+    return dict(
+        fullname='BCELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)), reduce=False)),
+        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+        reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
+        check_gradgrad=False,
+        pickle=False)
+
+
+def bceloss_weights_no_reduce_test():
+    t = torch.randn(15, 10).gt(0).double()
+    weights = torch.rand(10)
+    return dict(
+        fullname='BCELoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, Variable(t.type_as(i.data)),
+                                             weight=weights.type_as(i.data), reduce=False)),
+        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
+        reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        check_gradgrad=False,
+        pickle=False)
+
+
 def kldivloss_no_reduce_test():
     t = Variable(torch.randn(10, 10))
     return dict(
@@ -4176,6 +4202,8 @@
 
 new_module_tests = [
     poissonnllloss_no_reduce_test(),
+    bceloss_no_reduce_test(),
+    bceloss_weights_no_reduce_test(),
     kldivloss_no_reduce_test(),
     l1loss_no_reduce_test(),
     mseloss_no_reduce_test(),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 2b9ec23..6bc7d4c 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -655,8 +655,8 @@
 
 # NN
 
-- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight, bool size_average)
-  self: binary_cross_entropy_backward(self, target, weight, size_average).mul_(grad)
+- name: binary_cross_entropy(Tensor self, Tensor target, Tensor weight, bool size_average, bool reduce)
+  self: binary_cross_entropy_backward(grad, self, target, weight, size_average, reduce)
 
 - name: kl_div(Tensor self, Tensor target, bool size_average, bool reduce)
   self: kl_div_backward(grad, self, target, size_average, reduce)
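
Since the derivative entry now threads `grad` through `binary_cross_entropy_backward` directly, a quick way to exercise both reduction modes is gradcheck in double precision (a sketch; inputs kept away from 0 and 1 for stability):

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable, gradcheck

i = Variable(torch.rand(5, 3).double().clamp_(1e-2, 1 - 1e-2), requires_grad=True)
t = Variable(torch.rand(5, 3).gt(0.5).double())

print(gradcheck(lambda x: F.binary_cross_entropy(x, t), (i,)))                # True
print(gradcheck(lambda x: F.binary_cross_entropy(x, t, reduce=False), (i,)))  # True
```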
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a9b5a24..8ba9292 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1212,7 +1212,7 @@
     return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index, reduce)
 
 
-def binary_cross_entropy(input, target, weight=None, size_average=True):
+def binary_cross_entropy(input, target, weight=None, size_average=True, reduce=True):
     r"""Function that measures the Binary Cross Entropy
     between the target and the output.
 
@@ -1227,6 +1227,10 @@
                 over observations for each minibatch. However, if the field
                 sizeAverage is set to False, the losses are instead summed
                 for each minibatch. Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+                observations for each minibatch depending on ``size_average``. When
+                ``reduce`` is ``False``, returns a loss per input/target element
+                instead and ignores ``size_average``. Default: ``True``
 
     Examples::
 
@@ -1248,7 +1252,7 @@
         if torch.is_tensor(weight):
             weight = Variable(weight)
 
-    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average)
+    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average, reduce)
 
 
 def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
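
A typical use of the unreduced loss is a custom weighting or masking before reduction; a hedged sketch:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

input = Variable(torch.rand(4, 5).clamp_(1e-2, 1 - 1e-2), requires_grad=True)
target = Variable(torch.rand(4, 5).gt(0.5).float())

# Per-element losses, then a custom reduction (here: mean over positive targets only).
elems = F.binary_cross_entropy(input, target, reduce=False)
mask = (target > 0).float()
loss = (elems * mask).sum() / mask.sum().clamp(min=1)
loss.backward()
```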
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index b70a5ce..4b3620b 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -366,11 +366,17 @@
             over observations for each minibatch. However, if the field
             size_average is set to ``False``, the losses are instead summed for
             each minibatch. Default: ``True``
+        reduce (bool, optional): By default, the losses are averaged or summed over
+            observations for each minibatch depending on ``size_average``. When
+            ``reduce`` is ``False``, returns a loss per input/target element instead
+            and ignores ``size_average``. Default: ``True``
 
     Shape:
         - Input: :math:`(N, *)` where `*` means, any number of additional
           dimensions
         - Target: :math:`(N, *)`, same shape as the input
+        - Output: scalar. If ``reduce`` is ``False``, then :math:`(N, *)`, same
+          shape as input.
 
     Examples::
 
@@ -381,10 +387,15 @@
         >>> output = loss(m(input), target)
         >>> output.backward()
     """
+    def __init__(self, weight=None, size_average=True, reduce=True):
+        super(BCELoss, self).__init__(weight, size_average)
+        self.reduce = reduce
+
     def forward(self, input, target):
         _assert_no_grad(target)
         return F.binary_cross_entropy(input, target, weight=self.weight,
-                                      size_average=self.size_average)
+                                      size_average=self.size_average,
+                                      reduce=self.reduce)
 
 
 class BCEWithLogitsLoss(Module):
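
And the module-level counterpart of the new constructor argument (a sketch mirroring the docstring example):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

m = nn.Sigmoid()
loss = nn.BCELoss(reduce=False)
input = Variable(torch.randn(3, 2), requires_grad=True)
target = Variable(torch.rand(3, 2).gt(0.5).float())
output = loss(m(input), target)    # shape (3, 2), same as input
output.backward(torch.ones(3, 2))  # unreduced output needs an explicit gradient
```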