Port Softplus activation to ATen (CPU + CUDA) (#30504)
Summary:
VitalyFedyunin, this PR ports the Softplus activation to ATen:
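For reference, the new kernels compute the forward pass as `log1p(exp(beta * x)) / beta`, falling back to the identity once `beta * x` exceeds `threshold`, and the backward pass as `grad_output * (z - 1) / z` with `z = exp(beta * output)` computed from the saved forward output. A minimal PyTorch sketch of the same element-wise math (the helper names are mine, not part of the PR):
```
import torch

def softplus_ref(x, beta=1.0, threshold=20.0):
    # forward: log1p(exp(beta * x)) / beta, reverting to the identity
    # once beta * x exceeds the threshold (same branch as the kernel)
    bx = beta * x
    return torch.where(bx > threshold, x, torch.log1p(torch.exp(bx)) / beta)

def softplus_backward_ref(grad_output, output, beta=1.0, threshold=20.0):
    # backward: grad * (z - 1) / z with z = exp(beta * y), where y is the
    # saved forward output; reverts to grad_output past the threshold
    by = beta * output
    z = torch.exp(by)
    return torch.where(by > threshold, grad_output, grad_output * (z - 1.0) / z)
```
Note that `torch.where` evaluates both branches, so this is only a readability sketch of the formulas, not how the dispatched element-wise kernels below are written.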
**Test script:**
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
m = nn.Softplus()
if torch.cuda.is_available():
    device = "cuda"
    m = m.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    for i in range(1000):
        output = m(input)
        output.backward(grad_output)

for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    fwd_t = 0
    bwd_t = 0
    for i in range(10000):
        t1 = _time()
        output = m(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```
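Each configuration is warmed up for 1,000 iterations, then forward and backward are each timed over 10,000 iterations and reported as per-call averages in milliseconds; `_time()` synchronizes CUDA before reading the clock so the GPU numbers are meaningful.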
Test devices: CPU: Intel Xeon Platinum 8180 (SKX), GPU: NVIDIA Tesla P40.
Performance:
Before:
```
GPU:
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.12 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.18 (ms).
CPU:
input size(128, 100) forward time is 1.16 (ms); backward avg time is 0.69 (ms).
input size(128, 10000) forward time is 60.19 (ms); backward avg time is 31.86 (ms).
```
After:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backward avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.17 (ms).
CPU:
input size(128, 100) forward time is 0.43 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 1.65 (ms); backward avg time is 0.83 (ms).
```
Single-thread CPU (`OMP_NUM_THREADS=1`):
```
Before:
input size(128, 100) forward time is 0.53 (ms); backward avg time is 0.28 (ms).
input size(128, 10000) forward time is 51.33 (ms); backward avg time is 25.48 (ms).
After:
input size(128, 100) forward time is 0.44 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 42.05 (ms); backward avg time is 13.97 (ms).
```
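Beyond the benchmark, the ported kernels can be sanity-checked against autograd numerics and the reference formula; a quick sketch (the shapes and assertions below are my own, not part of the PR's test suite):
```
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    x = torch.randn(32, 64, dtype=torch.double, device=device, requires_grad=True)
    # double-precision gradient check through the new backward kernel
    assert gradcheck(lambda t: F.softplus(t, beta=2.0, threshold=20.0), (x,))
    # forward matches log1p(exp(x)) for the default beta=1
    assert torch.allclose(F.softplus(x), torch.log1p(torch.exp(x)))
    # large inputs take the threshold fallback and stay (nearly) linear
    big = torch.full((8,), 30.0, dtype=torch.double, device=device)
    assert torch.allclose(F.softplus(big), big)
```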
Fixes https://github.com/pytorch/pytorch/issues/24633, https://github.com/pytorch/pytorch/issues/24634, https://github.com/pytorch/pytorch/issues/24766, https://github.com/pytorch/pytorch/issues/24767.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30504
Differential Revision: D19274913
Pulled By: ezyang
fbshipit-source-id: 21b29e8459dcba5a040cc68333887b45a858328e
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 9f0b524..95e5f55 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -15,6 +15,8 @@
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);
+DEFINE_DISPATCH(softplus_stub);
+DEFINE_DISPATCH(softplus_backward_stub);
DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardtanh_backward_stub);
DEFINE_DISPATCH(hardshrink_stub);
@@ -251,6 +253,43 @@
return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, generator);
}
+Tensor & softplus_out(Tensor& result, const Tensor& self, Scalar beta, Scalar threshold) {
+ auto iter = TensorIterator::unary_op(result, self);
+ softplus_stub(iter.device_type(), iter, beta, threshold);
+ return result;
+}
+
+Tensor softplus(const Tensor& self, Scalar beta, Scalar threshold) {
+ Tensor result;
+ auto iter = TensorIterator::unary_op(result, self);
+ softplus_stub(iter.device_type(), iter, beta, threshold);
+ return iter.output();
+}
+
+Tensor & softplus_backward_out(
+ Tensor& grad_input,
+ const Tensor& grad_output,
+ const Tensor& self,
+ Scalar beta,
+ Scalar threshold,
+ const Tensor& output) {
+ auto iter = TensorIterator::binary_op(grad_input, grad_output, output);
+ softplus_backward_stub(iter.device_type(), iter, beta, threshold);
+ return grad_input;
+}
+
+Tensor softplus_backward(
+ const Tensor& grad_output,
+ const Tensor& self,
+ Scalar beta,
+ Scalar threshold,
+ const Tensor& output) {
+ Tensor grad_input;
+ auto iter = TensorIterator::binary_op(grad_input, grad_output, output);
+ softplus_backward_stub(iter.device_type(), iter, beta, threshold);
+ return iter.output();
+}
+
// computes `result = self <= threshold ? value : other`
// other is `self` in threshold() and `grad` in threshold_backward()
static Tensor threshold_out(
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index 77e29da..bf407cb 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -12,6 +12,8 @@
using activation_fn = void (*)(TensorIterator&);
using activation_backward_fn = void (*)(TensorIterator&);
+using softplus_fn = void (*)(TensorIterator&, Scalar, Scalar);
+using softplus_backward_fn = void (*)(TensorIterator&, Scalar, Scalar);
using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
using hardtanh_backward_fn = void (*)(TensorIterator&, Scalar, Scalar);
using shrink_fn = void (*)(TensorIterator&, Scalar);
@@ -22,6 +24,8 @@
DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_fn, elu_backward_stub);
+DECLARE_DISPATCH(softplus_fn, softplus_stub);
+DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(activation_fn, GeluKernel);
DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index aa9aade..6262373 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -311,6 +311,28 @@
});
}
+void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_cpu", [&]() {
+ auto beta = beta_.to<scalar_t>();
+ auto threshold = threshold_.to<scalar_t>();
+ cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
+ return (a * beta) > threshold ? a
+ : static_cast<scalar_t>(std::log1p(std::exp(a * beta))) / beta;
+ });
+ });
+}
+
+void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_backward_cpu", [&]() {
+ auto beta = beta_.to<scalar_t>();
+ auto threshold = threshold_.to<scalar_t>();
+ cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
+ scalar_t z = std::exp(b * beta);
+ return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
+ });
+ });
+}
+
} // namespace
REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
@@ -324,6 +346,8 @@
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
+REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
+REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 5a12ffd..f20b83a 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -280,6 +280,27 @@
});
}
+void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_cuda", [&]() {
+ auto beta = beta_.to<scalar_t>();
+ auto threshold = threshold_.to<scalar_t>();
+ gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
+ return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
+ });
+ });
+}
+
+void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_backward_cuda", [&]() {
+ auto beta = beta_.to<scalar_t>();
+ auto threshold = threshold_.to<scalar_t>();
+ gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ scalar_t z = std::exp(b * beta);
+ return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
+ });
+ });
+}
+
template <typename scalar_t>
void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t value) {
gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t {
@@ -418,5 +439,7 @@
REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
+REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
+REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f598b49..e84fea5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5618,29 +5618,20 @@
- func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
- dispatch:
- CPU: legacy::cpu::_thnn_softplus_forward_out
- CUDA: legacy::cuda::_thnn_softplus_forward_out
- func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
use_c10_dispatcher: full
python_module: nn
- dispatch:
- CPU: legacy::cpu::_thnn_softplus_forward
- CUDA: legacy::cuda::_thnn_softplus_forward
- func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
- CPU: legacy::cpu::_thnn_softplus_backward_out
- CUDA: legacy::cuda::_thnn_softplus_backward_out
+ CPU: softplus_backward_out
+ CUDA: softplus_backward_out
- func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor
use_c10_dispatcher: full
python_module: nn
- dispatch:
- CPU: legacy::cpu::_thnn_softplus_backward
- CUDA: legacy::cuda::_thnn_softplus_backward
- func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index e394549..fc9769b 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -47,9 +47,6 @@
cname: RReLU
has_inplace: True
-- name: _thnn_softplus(Tensor self, Scalar beta, Scalar threshold)
- cname: SoftPlus
-
# Convolutions
- name: _thnn_conv2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding)
diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt
index 79b1f38..67e3402 100644
--- a/aten/src/THCUNN/CMakeLists.txt
+++ b/aten/src/THCUNN/CMakeLists.txt
@@ -6,7 +6,6 @@
${CMAKE_CURRENT_SOURCE_DIR}/MultiLabelMarginCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/MultiMarginCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/RReLU.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SoftPlus.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialClassNLLCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu
${CMAKE_CURRENT_SOURCE_DIR}/SpatialDepthwiseConvolution.cu
diff --git a/aten/src/THCUNN/SoftPlus.cu b/aten/src/THCUNN/SoftPlus.cu
deleted file mode 100644
index c8a13bc..0000000
--- a/aten/src/THCUNN/SoftPlus.cu
+++ /dev/null
@@ -1,43 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCApply.cuh>
-
-template <typename T>
-struct softPlusupdateOutput_functor
-{
- const T threshold;
- const T beta;
-
- softPlusupdateOutput_functor(T threshold_, T beta_)
- : threshold(threshold_)
- , beta(beta_)
- {}
-
- __device__ void operator()(T *output, const T *input) const {
- T betain = beta * (*input);
- *output = ((betain) > threshold) ? *input : (1/beta) * static_cast<T>(log1p(exp(betain)));
- }
-};
-
-template <typename T>
-struct softPlusupdateGradInput_functor
-{
- const T threshold;
- const T beta;
-
- softPlusupdateGradInput_functor(T threshold_, T beta_)
- : threshold(threshold_)
- , beta(beta_)
- {}
-
- __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
- {
- T betaout = beta * (*output);
- T exp_bo = exp(betaout);
- *gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
- }
-};
-
-#include <THCUNN/generic/SoftPlus.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/SoftPlus.cu b/aten/src/THCUNN/generic/SoftPlus.cu
deleted file mode 100644
index 913a7a1..0000000
--- a/aten/src/THCUNN/generic/SoftPlus.cu
+++ /dev/null
@@ -1,38 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SoftPlus.cu"
-#else
-
-#include <THCUNN/common.h>
-
-void THNN_(SoftPlus_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- accreal beta_,
- accreal threshold_)
-{
- scalar_t beta = ScalarConvert<accreal, scalar_t>::to(beta_);
- scalar_t threshold = ScalarConvert<accreal, scalar_t>::to(threshold_);
- THCUNN_assertSameGPU(state, 2, input, output);
- THCTensor_(resizeAs)(state, output, input);
- THC_pointwiseApply2<scalar_t, scalar_t>(state, output, input, softPlusupdateOutput_functor<scalar_t>(threshold, beta));
-}
-
-void THNN_(SoftPlus_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- THCTensor *output,
- accreal beta_,
- accreal threshold_)
-{
- scalar_t beta = ScalarConvert<accreal, scalar_t>::to(beta_);
- scalar_t threshold = ScalarConvert<accreal, scalar_t>::to(threshold_);
- THCUNN_check_nElement(state, input, gradOutput);
- THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
- THCTensor_(resizeAs)(state, gradInput, output);
- THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor<scalar_t>(threshold, beta));
-}
-
-#endif
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index 7f3d31b..215f34b 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -217,32 +217,4 @@
double upper,
bool train,
bool inplace);
-
-THC_API void THNN_(SoftPlus_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- accreal beta,
- accreal threshold);
-
-THC_API void THNN_(SoftPlus_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- THCTensor *output,
- accreal beta,
- accreal threshold);
-
-THC_API void THNN_(Tanh_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output);
-
-THC_API void THNN_(Tanh_updateGradInput)(
- THCState *state,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- THCTensor *output);
-
#endif
diff --git a/aten/src/THNN/generic/SoftPlus.c b/aten/src/THNN/generic/SoftPlus.c
deleted file mode 100644
index a880b46..0000000
--- a/aten/src/THNN/generic/SoftPlus.c
+++ /dev/null
@@ -1,49 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SoftPlus.c"
-#else
-
-#include <c10/util/math_compat.h>
-
-void THNN_(SoftPlus_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- accreal beta_,
- accreal threshold_)
-{
- scalar_t beta = TH_CONVERT_ACCREAL_TO_REAL(beta_);
- scalar_t threshold = TH_CONVERT_ACCREAL_TO_REAL(threshold_);
- THTensor_(resizeAs)(output, input);
-
- // f(x) = 1/beta * log(1 + exp(beta * x))
- TH_TENSOR_APPLY2(scalar_t, output, scalar_t, input, \
- *output_data = (*input_data * beta) > threshold ? *input_data : std::log1p(exp(*input_data * beta)) / beta;
- );
-}
-
-void THNN_(SoftPlus_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *output,
- accreal beta_,
- accreal threshold_)
-{
- scalar_t beta = TH_CONVERT_ACCREAL_TO_REAL(beta_);
- scalar_t threshold = TH_CONVERT_ACCREAL_TO_REAL(threshold_);
- THNN_CHECK_NELEMENT(input, gradOutput);
- THTensor_(resizeAs)(gradInput, output);
-
- // d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
- // SINCE
- // y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
- // THEREFORE:
- // d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y)
- TH_TENSOR_APPLY3(scalar_t, gradInput, scalar_t, gradOutput, scalar_t, output,
- scalar_t z = exp(*output_data * beta);
- *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;
- );
-}
-
-#endif
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 825091d..573a713 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -48,19 +48,5 @@
THTensor *gradInput, // [OUT] gradient w.r.t. input
THTensor *buffer); // [BUFFER]
-TH_API void THNN_(SoftPlus_updateOutput)(
- THNNState *state,
- THTensor *input, THTensor *output,
- accreal beta,
- accreal threshold);
-TH_API void THNN_(SoftPlus_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- THTensor *output,
- accreal beta,
- accreal threshold);
-
#endif
#endif
diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp
index 234e1ed..90fb2a7 100644
--- a/aten/src/THNN/init.cpp
+++ b/aten/src/THNN/init.cpp
@@ -69,6 +69,3 @@
#include <THNN/generic/LogSigmoid.c>
#include <TH/THGenerateFloatTypes.h>
-
-#include <THNN/generic/SoftPlus.c>
-#include <TH/THGenerateFloatTypes.h>