accelerate `binary_cross_entropy_with_logits` by using the `log_sigmoid` operator (#115539)
While reimplementing `BCEWithLogitsLoss`, I found that the `log_sigmoid` operator can speed up the function.
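The rewrite relies on the identity `log_sigmoid(x) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))`, so the new expression is mathematically the same as the old log-sum-exp formulation. A quick numerical check (illustrative only, not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1000) * 10  # logits over a wide range to exercise stability
t = torch.rand(1000)        # soft targets in [0, 1]

# Old formulation: explicit max + log-sum-exp for numerical stability.
max_val = (-x).clamp_min(0)
old = (1 - t) * x + max_val + ((-max_val).exp() + (-x - max_val).exp()).log()

# New formulation: the same quantity expressed through log_sigmoid.
new = (1 - t) * x - F.logsigmoid(x)

print(torch.allclose(old, new, atol=1e-5))  # True
```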
A simple benchmark on an AMD 3600 CPU running Ubuntu 22.04:
|avg time (ms)|with `pos_weight`|no `pos_weight`|
|-|-|-|
|original|1986|1658|
|this PR|1295|995|
That is roughly 35-40% faster, most likely thanks to the vectorized `log_sigmoid` kernel.
A CUDA benchmark was not collected, but I believe CUDA should also benefit from the reduced number of kernel launches, as https://github.com/pytorch/pytorch/pull/11054#issuecomment-442233714 and https://github.com/pytorch/pytorch/pull/78267#issue-1248398454 mention.
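One way to check the launch-count claim on a CUDA build would be the profiler; a minimal sketch (illustrative shapes, not a measurement from this PR):

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

# Requires a CUDA build; shapes are illustrative.
x = torch.randn(4096, 4096, device="cuda")
t = torch.rand(4096, 4096, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    F.binary_cross_entropy_with_logits(x, t)
    torch.cuda.synchronize()

# The clamp/exp/log chain collapses into a single log_sigmoid call,
# so fewer distinct CUDA kernels should show up after this change.
print(prof.key_averages().table(sort_by="cuda_time_total"))
```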
The simple benchmark C++ file:
[demo.txt](https://github.com/pytorch/pytorch/files/13635355/demo.txt)
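For readers without a C++ setup, a rough Python equivalent of the benchmark (shapes and repeat count are illustrative, not those of the attached demo):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

x = torch.randn(8, 1024, 1024)
t = torch.rand(8, 1024, 1024)
pw = torch.rand(1024)  # pos_weight broadcast along the last dimension

for pos_weight in (None, pw):
    timer = benchmark.Timer(
        stmt="F.binary_cross_entropy_with_logits(x, t, pos_weight=pw)",
        globals={"F": F, "x": x, "t": t, "pw": pos_weight},
    )
    label = "with pos_weight" if pos_weight is not None else "no pos_weight"
    print(label, timer.timeit(50))
```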
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115539
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 0eafdf2..1eedceb 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -30,6 +30,7 @@
#include <ATen/ops/kl_div_native.h>
#include <ATen/ops/l1_loss_native.h>
#include <ATen/ops/log.h>
+#include <ATen/ops/log_sigmoid.h>
#include <ATen/ops/margin_ranking_loss_native.h>
#include <ATen/ops/mean.h>
#include <ATen/ops/min.h>
@@ -358,21 +359,20 @@
c10::MaybeOwned<Tensor> pos_weight_maybe_owned = at::borrow_from_optional_tensor(pos_weight_opt);
const Tensor& pos_weight = *pos_weight_maybe_owned;
- Tensor loss;
- auto max_val = (-input).clamp_min_(0);
- if (pos_weight.defined()) {
- // pos_weight need to be broadcasted, thus mul(target) is not inplace.
- auto log_weight = (pos_weight - 1).mul(target).add_(1);
- loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
- } else {
- loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
- }
+ Tensor loss;
+ if (pos_weight.defined()) {
+ // pos_weight need to be broadcasted, thus mul(target) is not inplace.
+ auto log_weight = (pos_weight - 1).mul(target).add_(1);
+ loss = (1 - target).mul_(input).sub_(log_weight.mul_(at::log_sigmoid(input)));
+ } else {
+ loss = (1 - target).mul_(input).sub_(at::log_sigmoid(input));
+ }
- if (weight.defined()) {
- loss.mul_(weight);
- }
+ if (weight.defined()) {
+ loss.mul_(weight);
+ }
- return apply_loss_reduction(loss, reduction);
+ return apply_loss_reduction(loss, reduction);
}
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index f9348f8..a492730 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -1147,26 +1147,26 @@
aten::mul.Tensor 1 (INPUT), 3 (INPUT) -> 4 (INPUT)
aten::mul.Tensor 1 (INPUT), 5 (INPUT) -> 6 (INPUT)
aten::cat 4 (INPUT), 6 (INPUT) -> 7 (INPUT)
- aten::binary_cross_entropy_with_logits 7 (INPUT), 2 (INPUT) -> 13 (INPUT)
+ aten::binary_cross_entropy_with_logits 7 (INPUT), 2 (INPUT) -> 11 (INPUT)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 13 (INPUT) -> 16 (INPUT)
- aten::sigmoid 7 (INPUT) -> 17 (TEMPORARY)
- aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY)
- aten::mul.Tensor 18 (TEMPORARY), 16 (INPUT) -> 19 (AUTOGRAD_DETAIL)
- aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
+ aten::ones_like 11 (INPUT) -> 14 (INPUT)
+ aten::sigmoid 7 (INPUT) -> 15 (TEMPORARY)
+ aten::sub.Tensor 15 (TEMPORARY), 2 (INPUT) -> 16 (TEMPORARY)
+ aten::mul.Tensor 16 (TEMPORARY), 14 (INPUT) -> 17 (AUTOGRAD_DETAIL)
+ aten::div_.Scalar 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
+ aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
+ aten::view 21 (GRADIENT) -> 21 (GRADIENT)
+ aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
+ aten::detach 21 (GRADIENT) -> ???
+ aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
- aten::detach 23 (GRADIENT) -> ???
- aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL)
- aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT)
- aten::view 25 (GRADIENT) -> 25 (GRADIENT)
- aten::detach 25 (GRADIENT) -> 25 (GRADIENT)
- aten::detach 25 (GRADIENT) -> ???""",
+ aten::detach 23 (GRADIENT) -> ???""",
)
def test_categories_e2e_simple_fwd_bwd_step(self) -> None:
@@ -1199,30 +1199,30 @@
aten::mul.Tensor 1 (INPUT), 3 (PARAMETER) -> 4 (ACTIVATION)
aten::mul.Tensor 1 (INPUT), 5 (PARAMETER) -> 6 (ACTIVATION)
aten::cat 4 (ACTIVATION), 6 (ACTIVATION) -> 7 (ACTIVATION)
- aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 13 (ACTIVATION)
+ aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 11 (ACTIVATION)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION)
- aten::sigmoid 7 (ACTIVATION) -> 17 (TEMPORARY)
- aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY)
- aten::mul.Tensor 18 (TEMPORARY), 16 (ACTIVATION) -> 19 (AUTOGRAD_DETAIL)
- aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL)
- aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
+ aten::ones_like 11 (ACTIVATION) -> 14 (ACTIVATION)
+ aten::sigmoid 7 (ACTIVATION) -> 15 (TEMPORARY)
+ aten::sub.Tensor 15 (TEMPORARY), 2 (INPUT) -> 16 (TEMPORARY)
+ aten::mul.Tensor 16 (TEMPORARY), 14 (ACTIVATION) -> 17 (AUTOGRAD_DETAIL)
+ aten::div_.Scalar 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::slice.Tensor 17 (AUTOGRAD_DETAIL) -> 17 (AUTOGRAD_DETAIL)
+ aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
+ aten::sum.dim_IntList 20 (AUTOGRAD_DETAIL) -> 21 (GRADIENT)
+ aten::view 21 (GRADIENT) -> 21 (GRADIENT)
+ aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
+ aten::detach 21 (GRADIENT) -> 21 (GRADIENT)
+ aten::mul.Tensor 17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT)
aten::view 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
aten::detach 23 (GRADIENT) -> 23 (GRADIENT)
- aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL)
- aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT)
- aten::view 25 (GRADIENT) -> 25 (GRADIENT)
- aten::detach 25 (GRADIENT) -> 25 (GRADIENT)
- aten::detach 25 (GRADIENT) -> 25 (GRADIENT)
-- Optimizer --------------------------------------------------------------------------------------------
- aten::add_.Tensor 3 (PARAMETER), 25 (GRADIENT) -> 3 (PARAMETER)
- aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)""",
+ aten::add_.Tensor 3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
+ aten::add_.Tensor 5 (PARAMETER), 21 (GRADIENT) -> 5 (PARAMETER)""",
)
def test_categories_e2e_simple_module_fwd(self) -> None:
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 54a6fe8..4e98f47 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -3821,18 +3821,11 @@
def binary_cross_entropy_with_logits(
self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
):
- max_val = (-self).clamp_min(0)
if pos_weight is not None:
log_weight = (pos_weight - 1) * target + 1
- loss = (1 - target) * self + log_weight * (
- ((-max_val).exp() + (-self - max_val).exp()).log() + max_val
- )
+ loss = (1 - target) * self - (log_weight * F.logsigmoid(self))
else:
- loss = (
- (1 - target) * self
- + max_val
- + ((-max_val).exp() + (-self - max_val).exp()).log()
- )
+ loss = (1 - target) * self - F.logsigmoid(self)
if weight is not None:
loss = loss * weight