Port softplus activation to ATen (CPU+CUDA) (#30504)

Summary:
VitalyFedyunin, this PR ports the Softplus activation to ATen:
**Test script:**
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)
def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
m = nn.Softplus()
if torch.cuda.is_available():
    device = "cuda"
    m = m.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    for i in range(1000):
        output = m(input)
        output.backward(grad_output)

for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n, device=device)
    fwd_t = 0
    bwd_t = 0
    for i in range(10000):
        t1 = _time()
        output = m(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backwad avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```
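For reference, the ported kernels keep the same math as the THNN/THCUNN code removed below: the forward pass computes log1p(exp(beta * x)) / beta and falls back to the identity once beta * x exceeds the threshold, and the backward pass is written in terms of the forward output y as grad * (z - 1) / z with z = exp(beta * y). A minimal PyTorch sketch of those formulas (for illustration only, not part of this PR; the helper names are made up here):
```
import torch

def softplus_ref(x, beta=1.0, threshold=20.0):
    # f(x) = log1p(exp(beta * x)) / beta, identity once beta * x > threshold
    return torch.where(x * beta > threshold, x,
                       torch.log1p(torch.exp(x * beta)) / beta)

def softplus_backward_ref(grad_output, output, beta=1.0, threshold=20.0):
    # d/dx softplus(x) = (z - 1) / z with z = exp(beta * y), y the forward output
    z = torch.exp(output * beta)
    return torch.where(output * beta > threshold, grad_output,
                       grad_output * (z - 1.0) / z)
```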
Test devices: CPU: skx-8180; GPU: Tesla P40.
Performance:
Before:
```
GPU:
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.12 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.18 (ms).
CPU:
input size(128, 100) forward time is 1.16 (ms); backward avg time is 0.69 (ms).
input size(128, 10000) forward time is 60.19 (ms); backward avg time is 31.86 (ms).
```
After:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backward avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backward avg time is 0.17 (ms).
CPU:
input size(128, 100) forward time is 0.43 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 1.65 (ms); backward avg time is 0.83 (ms).
```
With `OMP_NUM_THREADS=1` (CPU):
```
Before:
input size(128, 100) forward time is 0.53 (ms); backward avg time is 0.28 (ms).
input size(128, 10000) forward time is 51.33 (ms); backward avg time is 25.48 (ms).
After:
input size(128, 100) forward time is 0.44 (ms); backward avg time is 0.16 (ms).
input size(128, 10000) forward time is 42.05 (ms); backward avg time is 13.97 (ms).
```
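As an extra sanity check (a sketch, not part of this PR), the new path can also be validated against numerical gradients with `torch.autograd.gradcheck`:
```
import torch
import torch.nn.functional as F

# gradcheck wants double-precision inputs
x = torch.randn(32, 16, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: F.softplus(t, beta=2.0), (x,))
```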

Fixes https://github.com/pytorch/pytorch/issues/24633, https://github.com/pytorch/pytorch/issues/24634, https://github.com/pytorch/pytorch/issues/24766, and https://github.com/pytorch/pytorch/issues/24767.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30504

Differential Revision: D19274913

Pulled By: ezyang

fbshipit-source-id: 21b29e8459dcba5a040cc68333887b45a858328e
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 9f0b524..95e5f55 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -15,6 +15,8 @@
 
 DEFINE_DISPATCH(elu_stub);
 DEFINE_DISPATCH(elu_backward_stub);
+DEFINE_DISPATCH(softplus_stub);
+DEFINE_DISPATCH(softplus_backward_stub);
 DEFINE_DISPATCH(threshold_stub);
 DEFINE_DISPATCH(hardtanh_backward_stub);
 DEFINE_DISPATCH(hardshrink_stub);
@@ -251,6 +253,43 @@
   return at::rrelu_with_noise_(self, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT), lower, upper, training, generator);
 }
 
+Tensor & softplus_out(Tensor& result, const Tensor& self, Scalar beta, Scalar threshold) {
+  auto iter = TensorIterator::unary_op(result, self);
+  softplus_stub(iter.device_type(), iter, beta, threshold);
+  return result;
+}
+
+Tensor softplus(const Tensor& self, Scalar beta, Scalar threshold) {
+  Tensor result;
+  auto iter = TensorIterator::unary_op(result, self);
+  softplus_stub(iter.device_type(), iter, beta, threshold);
+  return iter.output();
+}
+
+Tensor & softplus_backward_out(
+    Tensor& grad_input,
+    const Tensor& grad_output,
+    const Tensor& self,
+    Scalar beta,
+    Scalar threshold,
+    const Tensor& output) {
+  auto iter = TensorIterator::binary_op(grad_input, grad_output, output);
+  softplus_backward_stub(iter.device_type(), iter, beta, threshold);
+  return grad_input;
+}
+
+Tensor softplus_backward(
+    const Tensor& grad_output,
+    const Tensor& self,
+    Scalar beta,
+    Scalar threshold,
+    const Tensor& output) {
+  Tensor grad_input;
+  auto iter = TensorIterator::binary_op(grad_input, grad_output, output);
+  softplus_backward_stub(iter.device_type(), iter, beta, threshold);
+  return iter.output();
+}
+
 // computes `result = self <= threshold ? value : other`
 // other is `self` in threshold() and `grad` in threshold_backward()
 static Tensor threshold_out(
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index 77e29da..bf407cb 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -12,6 +12,8 @@
 
 using activation_fn = void (*)(TensorIterator&);
 using activation_backward_fn = void (*)(TensorIterator&);
+using softplus_fn = void (*)(TensorIterator&, Scalar, Scalar);
+using softplus_backward_fn = void (*)(TensorIterator&, Scalar, Scalar);
 using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
 using hardtanh_backward_fn = void (*)(TensorIterator&, Scalar, Scalar);
 using shrink_fn = void (*)(TensorIterator&, Scalar);
@@ -22,6 +24,8 @@
 
 DECLARE_DISPATCH(elu_fn, elu_stub);
 DECLARE_DISPATCH(elu_fn, elu_backward_stub);
+DECLARE_DISPATCH(softplus_fn, softplus_stub);
+DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
 DECLARE_DISPATCH(activation_fn, GeluKernel);
 DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index aa9aade..6262373 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -311,6 +311,28 @@
   });
 }
 
+void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_cpu", [&]() {
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
+      return (a * beta) > threshold ? a
+          : static_cast<scalar_t>(std::log1p(std::exp(a * beta))) / beta;
+    });
+  });
+}
+
+void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_backward_cpu", [&]() {
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
+      scalar_t z = std::exp(b * beta);
+      return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
+    });
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
@@ -324,6 +346,8 @@
 REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);
 REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
 REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
+REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
+REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 5a12ffd..f20b83a 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -280,6 +280,27 @@
   });
 }
 
+void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_cuda", [&]() {
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
+    });
+  });
+}
+
+void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "softplus_backward_cuda", [&]() {
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      scalar_t z = std::exp(b * beta);
+      return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
+    });
+  });
+}
+
 template <typename scalar_t>
 void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t value) {
   gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t {
@@ -418,5 +439,7 @@
 REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);
 REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
 REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
+REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
+REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
 
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f598b49..e84fea5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5618,29 +5618,20 @@
 
 - func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
-  dispatch:
-    CPU: legacy::cpu::_thnn_softplus_forward_out
-    CUDA: legacy::cuda::_thnn_softplus_forward_out
 
 - func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
   use_c10_dispatcher: full
   python_module: nn
-  dispatch:
-    CPU: legacy::cpu::_thnn_softplus_forward
-    CUDA: legacy::cuda::_thnn_softplus_forward
 
 - func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
   dispatch:
-    CPU: legacy::cpu::_thnn_softplus_backward_out
-    CUDA: legacy::cuda::_thnn_softplus_backward_out
+    CPU: softplus_backward_out
+    CUDA: softplus_backward_out
 
 - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor
   use_c10_dispatcher: full
   python_module: nn
-  dispatch:
-    CPU: legacy::cpu::_thnn_softplus_backward
-    CUDA: legacy::cuda::_thnn_softplus_backward
 
 - func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index e394549..fc9769b 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -47,9 +47,6 @@
   cname: RReLU
   has_inplace: True
 
-- name: _thnn_softplus(Tensor self, Scalar beta, Scalar threshold)
-  cname: SoftPlus
-
 # Convolutions
 
 - name: _thnn_conv2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding)
diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt
index 79b1f38..67e3402 100644
--- a/aten/src/THCUNN/CMakeLists.txt
+++ b/aten/src/THCUNN/CMakeLists.txt
@@ -6,7 +6,6 @@
 ${CMAKE_CURRENT_SOURCE_DIR}/MultiLabelMarginCriterion.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/MultiMarginCriterion.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/RReLU.cu
-${CMAKE_CURRENT_SOURCE_DIR}/SoftPlus.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialClassNLLCriterion.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/SpatialDepthwiseConvolution.cu
diff --git a/aten/src/THCUNN/SoftPlus.cu b/aten/src/THCUNN/SoftPlus.cu
deleted file mode 100644
index c8a13bc..0000000
--- a/aten/src/THCUNN/SoftPlus.cu
+++ /dev/null
@@ -1,43 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCApply.cuh>
-
-template <typename T>
-struct softPlusupdateOutput_functor
-{
-  const T threshold;
-  const T beta;
-
-  softPlusupdateOutput_functor(T threshold_, T beta_)
-    : threshold(threshold_)
-    , beta(beta_)
-  {}
-
-  __device__ void operator()(T *output, const T *input) const {
-    T betain = beta * (*input);
-    *output = ((betain) > threshold) ? *input : (1/beta) * static_cast<T>(log1p(exp(betain)));
-  }
-};
-
-template <typename T>
-struct softPlusupdateGradInput_functor
-{
-  const T threshold;
-  const T beta;
-
-  softPlusupdateGradInput_functor(T threshold_, T beta_)
-    : threshold(threshold_)
-    , beta(beta_)
-  {}
-
-  __device__ void operator()(T *gradInput, const T *output, const T *gradOutput) const
-  {
-    T betaout = beta * (*output);
-    T exp_bo = exp(betaout);
-    *gradInput = ((betaout) > threshold) ? *gradOutput : *gradOutput * (exp_bo - 1) / exp_bo;
-  }
-};
-
-#include <THCUNN/generic/SoftPlus.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/SoftPlus.cu b/aten/src/THCUNN/generic/SoftPlus.cu
deleted file mode 100644
index 913a7a1..0000000
--- a/aten/src/THCUNN/generic/SoftPlus.cu
+++ /dev/null
@@ -1,38 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/SoftPlus.cu"
-#else
-
-#include <THCUNN/common.h>
-
-void THNN_(SoftPlus_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           accreal beta_,
-           accreal threshold_)
-{
-  scalar_t beta = ScalarConvert<accreal, scalar_t>::to(beta_);
-  scalar_t threshold = ScalarConvert<accreal, scalar_t>::to(threshold_);
-  THCUNN_assertSameGPU(state, 2, input, output);
-  THCTensor_(resizeAs)(state, output, input);
-  THC_pointwiseApply2<scalar_t, scalar_t>(state, output, input, softPlusupdateOutput_functor<scalar_t>(threshold, beta));
-}
-
-void THNN_(SoftPlus_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           THCTensor *output,
-           accreal beta_,
-           accreal threshold_)
-{
-  scalar_t beta = ScalarConvert<accreal, scalar_t>::to(beta_);
-  scalar_t threshold = ScalarConvert<accreal, scalar_t>::to(threshold_);
-  THCUNN_check_nElement(state, input, gradOutput);
-  THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
-  THCTensor_(resizeAs)(state, gradInput, output);
-  THC_pointwiseApply3<scalar_t, scalar_t, scalar_t>(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor<scalar_t>(threshold, beta));
-}
-
-#endif
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index 7f3d31b..215f34b 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -217,32 +217,4 @@
                   double upper,
                   bool train,
                   bool inplace);
-
-THC_API void THNN_(SoftPlus_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  accreal beta,
-                  accreal threshold);
-
-THC_API void THNN_(SoftPlus_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  THCTensor *output,
-                  accreal beta,
-                  accreal threshold);
-
-THC_API void THNN_(Tanh_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output);
-
-THC_API void THNN_(Tanh_updateGradInput)(
-                  THCState *state,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  THCTensor *output);
-
 #endif
diff --git a/aten/src/THNN/generic/SoftPlus.c b/aten/src/THNN/generic/SoftPlus.c
deleted file mode 100644
index a880b46..0000000
--- a/aten/src/THNN/generic/SoftPlus.c
+++ /dev/null
@@ -1,49 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/SoftPlus.c"
-#else
-
-#include <c10/util/math_compat.h>
-
-void THNN_(SoftPlus_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          accreal beta_,
-          accreal threshold_)
-{
-  scalar_t beta = TH_CONVERT_ACCREAL_TO_REAL(beta_);
-  scalar_t threshold = TH_CONVERT_ACCREAL_TO_REAL(threshold_);
-  THTensor_(resizeAs)(output, input);
-
-  // f(x) = 1/beta * log(1 + exp(beta * x))
-  TH_TENSOR_APPLY2(scalar_t, output, scalar_t, input,               \
-    *output_data = (*input_data * beta) > threshold ? *input_data : std::log1p(exp(*input_data * beta)) / beta;
-  );
-}
-
-void THNN_(SoftPlus_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *output,
-          accreal beta_,
-          accreal threshold_)
-{
-  scalar_t beta = TH_CONVERT_ACCREAL_TO_REAL(beta_);
-  scalar_t threshold = TH_CONVERT_ACCREAL_TO_REAL(threshold_);
-  THNN_CHECK_NELEMENT(input, gradOutput);
-  THTensor_(resizeAs)(gradInput, output);
-
-  // d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
-  // SINCE
-  // y = (1/k)*log(1+exp(k*x)) --> x = (1/k)*log(exp(k*y)-1)
-  // THEREFORE:
-  // d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y)
-  TH_TENSOR_APPLY3(scalar_t, gradInput, scalar_t, gradOutput, scalar_t, output,
-    scalar_t z = exp(*output_data * beta);
-    *gradInput_data = (*output_data * beta) > threshold ? *gradOutput_data : *gradOutput_data * (z - 1.)/z;
-  );
-}
-
-#endif
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 825091d..573a713 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -48,19 +48,5 @@
           THTensor *gradInput,         // [OUT] gradient w.r.t. input
           THTensor *buffer);           // [BUFFER]
 
-TH_API void THNN_(SoftPlus_updateOutput)(
-          THNNState *state,
-          THTensor *input, THTensor *output,
-          accreal beta,
-          accreal threshold);
-TH_API void THNN_(SoftPlus_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          THTensor *output,
-          accreal beta,
-          accreal threshold);
-
 #endif
 #endif
diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp
index 234e1ed..90fb2a7 100644
--- a/aten/src/THNN/init.cpp
+++ b/aten/src/THNN/init.cpp
@@ -69,6 +69,3 @@
 
 #include <THNN/generic/LogSigmoid.c>
 #include <TH/THGenerateFloatTypes.h>
-
-#include <THNN/generic/SoftPlus.c>
-#include <TH/THGenerateFloatTypes.h>