Optimize SiLU (Swish) op in PyTorch (#42976)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42976

Optimize SiLU (Swish) op in PyTorch.
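
For reference, SiLU is defined as silu(x) = x * sigmoid(x), and its derivative has the closed form sigmoid(x) * (1 + x * (1 - sigmoid(x))). The fused CPU/CUDA kernels in this diff evaluate these expressions directly instead of composing separate mul/sigmoid ops. A minimal Python reference for the math (illustration only, not the fused implementation):

import torch

def silu_ref(x):
    # forward: silu(x) = x * sigmoid(x)
    return x * torch.sigmoid(x)

def silu_backward_ref(dy, x):
    # backward: d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = torch.sigmoid(x)
    return dy * s * (1 + x * (1 - s))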

Some benchmark results (times shown as before -> after):

input = torch.rand(1024, 32768, dtype=torch.float, device="cpu")
forward: 221ms -> 133ms
backward: 600ms -> 170ms

input = torch.rand(1024, 32768, dtype=torch.double, device="cpu")
forward: 479ms -> 297ms
backward: 1438ms -> 387ms

input = torch.rand(8192, 32768, dtype=torch.float, device="cuda")
forward: 24.34ms -> 9.83ms
backward: 97.05ms -> 29.03ms

input = torch.rand(4096, 32768, dtype=torch.double, device="cuda")
forward: 44.24ms -> 30.15ms
backward: 126.21ms -> 49.68ms
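
The timing harness is not part of this diff; the sketch below is an assumed script showing how numbers like the above can be reproduced (with a warm-up pass and CUDA synchronization; the second loop times forward + backward together):

import time
import torch
import torch.nn.functional as F

def bench_silu(x, iters=100):
    x = x.clone().requires_grad_(True)
    grad = torch.ones_like(x)
    F.silu(x).backward(grad)  # warm-up so lazy init does not skew timings
    if x.is_cuda:
        torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        y = F.silu(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(iters):
        x.grad = None
        F.silu(x).backward(grad)
    if x.is_cuda:
        torch.cuda.synchronize()
    t2 = time.time()
    print("forward: {:.2f}ms  forward+backward: {:.2f}ms".format(
        1e3 * (t1 - t0) / iters, 1e3 * (t2 - t1) / iters))

bench_silu(torch.rand(1024, 32768, dtype=torch.float, device="cpu"))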

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "SiLU"
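
In an OSS checkout, a roughly equivalent check (assumed invocation; it covers the eager-mode test added in test/test_torch.py in the diff below) is:

python -m pytest test/test_torch.py -k test_silu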

Reviewed By: houseroad

Differential Revision: D23093593

fbshipit-source-id: 1ba7b95d5926c4527216ed211a5ff1cefa3d3bfd
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 0b12568..0e301df 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -30,6 +30,8 @@
 DEFINE_DISPATCH(shrink_backward_stub);
 DEFINE_DISPATCH(leaky_relu_stub);
 DEFINE_DISPATCH(leaky_relu_backward_stub);
+DEFINE_DISPATCH(silu_stub);
+DEFINE_DISPATCH(silu_backward_stub);
 
 Tensor hardtanh(const Tensor& self, Scalar min, Scalar max) {
   return at::clamp(self, min, max);
@@ -195,20 +197,31 @@
 }
 
 Tensor silu(const Tensor& self) {
-  return self * at::sigmoid(self);
+  Tensor result = at::empty({0}, self.options());
+  at::silu_out(result, self);
+  return result;
 }
 
 Tensor& silu_(Tensor& self) {
-  return self.mul_(at::sigmoid(self));
+  return at::silu_out(self, self);
 }
 
 Tensor& silu_out(Tensor& result, const Tensor& self) {
-  return at::mul_out(result, self, at::sigmoid(self));
+  TORCH_CHECK(
+      result.dtype() == self.dtype(),
+      "Output Tensor should have the same type as in Input Tensor.")
+  auto iter = TensorIterator::unary_op(result, self);
+  silu_stub(iter.device_type(), iter);
+  return result;
 }
 
-Tensor silu_backward(const Tensor& grad, const Tensor& self) {
-  auto self_sigmoid = at::sigmoid(self);
-  return grad * (self_sigmoid * (1 + self * (1 - self_sigmoid)));
+Tensor silu_backward(
+    const Tensor& grad_output,
+    const Tensor& input) {
+  Tensor grad_input = at::empty({0}, input.options());
+  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
+  silu_backward_stub(iter.device_type(), iter);
+  return grad_input;
 }
 
 template <typename scalar_t>
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index c557835..bebfa67 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -48,6 +48,8 @@
 DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
 DECLARE_DISPATCH(activation_fn, glu_stub);
 DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
+DECLARE_DISPATCH(activation_fn, silu_stub);
+DECLARE_DISPATCH(activation_backward_fn, silu_backward_stub);
 
 } // namespace native
 
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index c9a0faa..79de581 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -4,7 +4,8 @@
 
 #include <ATen/native/Activation.h>
 
-#include <math.h>
+#include <cmath>
+#include <functional>
 
 #include <ATen/ATen.h>
 #include <ATen/Config.h>
@@ -589,6 +590,40 @@
   });
 }
 
+void silu_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
+      kBFloat16, iter.dtype(), "silu_cpu", [&]() {
+        const Vec256<scalar_t> kOneVec(scalar_t(1));
+        cpu_kernel_vec(
+            iter,
+            [](scalar_t x) {
+              return x / (scalar_t(1) + std::exp(-x));
+            },
+            [kOneVec](Vec256<scalar_t> x_vec) {
+              return x_vec / (kOneVec + x_vec.neg().exp());
+            });
+      });
+}
+
+void silu_backward_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
+      kBFloat16, iter.dtype(), "silu_backward_cpu", [&]() {
+        const Vec256<scalar_t> kOneVec(scalar_t(1));
+        cpu_kernel_vec(
+            iter,
+            [](scalar_t dy, scalar_t x) {
+              const scalar_t sigmoid =
+                  scalar_t(1) / (scalar_t(1) + std::exp(-x));
+              return dy * sigmoid * (scalar_t(1) + x * (scalar_t(1) - sigmoid));
+            },
+            [kOneVec](Vec256<scalar_t> dy_vec, Vec256<scalar_t> x_vec) {
+              const Vec256<scalar_t> sigmoid =
+                  kOneVec / (kOneVec + x_vec.neg().exp());
+              return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid));
+            });
+      });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
@@ -612,6 +647,8 @@
 REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
 REGISTER_DISPATCH(glu_stub, &glu_kernel);
 REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel);
+REGISTER_DISPATCH(silu_stub, &silu_kernel);
+REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 209bf29..359fd5b 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -2,7 +2,9 @@
 
 #include <ATen/native/Activation.h>
 
-#include <math.h>
+#include <cmath>
+
+#include <thrust/tuple.h>
 
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
@@ -15,10 +17,8 @@
 #include <ATen/native/cuda/Loops.cuh>
 #include <c10/cuda/CUDAMathCompat.h>
 
-#include <thrust/tuple.h>
-
-
-namespace at { namespace native {
+namespace at {
+namespace native {
 
 // -----------------------------------
 // prelu forward
@@ -478,6 +478,43 @@
   });
 }
 
+void silu_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "silu_cuda",
+      [&]() {
+        gpu_kernel(
+            iter,
+            [] GPU_LAMBDA(scalar_t x) -> scalar_t {
+              using T_ACC = acc_type<scalar_t, true>;
+              const T_ACC x_acc = static_cast<T_ACC>(x);
+              return x_acc / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
+            });
+      });
+}
+
+void silu_backward_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "silu_backward_cuda",
+      [&]() {
+        gpu_kernel(
+            iter,
+            [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+              using T_ACC = acc_type<scalar_t, true>;
+              const T_ACC dy_acc = static_cast<T_ACC>(dy);
+              const T_ACC x_acc = static_cast<T_ACC>(x);
+              const T_ACC s_acc =
+                  T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
+              return dy_acc * s_acc * (T_ACC(1) + x_acc * (T_ACC(1) - s_acc));
+            });
+      });
+}
+
 } // namespace
 
 Tensor gelu_cuda(const Tensor& self) {
@@ -540,5 +577,8 @@
 REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
 REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
 REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
+REGISTER_DISPATCH(silu_stub, &silu_kernel);
+REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
 
-}}  // namespace at::native
+} // namespace native
+} // namespace at
diff --git a/docs/source/scripts/build_activation_images.py b/docs/source/scripts/build_activation_images.py
index 1fa9fd2..7274d5c 100644
--- a/docs/source/scripts/build_activation_images.py
+++ b/docs/source/scripts/build_activation_images.py
@@ -36,6 +36,7 @@
     'ReLU6',
     'RReLU',
     'SELU',
+    'SiLU',
     'CELU',
     'GELU',
     'Sigmoid',
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index c5a9f0b..b7ec61a 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -57,6 +57,7 @@
 torch::nn::SELU|Yes|No
 torch::nn::CELU|Yes|No
 torch::nn::GELU|Yes|No
+torch::nn::SiLU|Yes|No
 torch::nn::Sigmoid|Yes|No
 torch::nn::Softplus|Yes|No
 torch::nn::Softshrink|Yes|No
@@ -183,6 +184,7 @@
 F::rrelu|Yes|No
 F::glu|Yes|No
 F::gelu|Yes|No
+F::silu|Yes|No
 F::logsigmoid|Yes|No
 F::hardshrink|Yes|No
 F::tanhshrink|Yes|No
diff --git a/test/test_torch.py b/test/test_torch.py
index 6d97c4a..3fd0db4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16795,22 +16795,35 @@
                          torch.tensor(expectedOutput, dtype=dtype, device=device),
                          atol=precision_4dps, rtol=0)
 
+    @skipIfNoSciPy
     @dtypes(torch.float, torch.double)
     def test_silu(self, device, dtype):
-        inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000]
-        expectedOutput = [0.0000, -0.2689, 0, 0.3112, 0.7312, 1.7616, 1000]
-        precision_4dps = 0.0002
+        input_np = np.random.randn(5, 8)
+        special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
+        input_np = np.concatenate((input_np, special_input), axis=0).astype(
+            torch_to_numpy_dtype_dict[dtype])
+        expected_output_np = input_np * scipy.special.expit(input_np)
 
-        input_tensor = torch.tensor(inputValues, dtype=dtype, device=device)
-        expected_output_tensor = torch.tensor(expectedOutput, dtype=dtype, device=device)
+        expected_output = torch.from_numpy(expected_output_np).to(device)
+        expected_output_noncontig = expected_output.transpose(0, 1)
 
-        self.assertEqual(torch.nn.functional.silu(input_tensor),
-                         expected_output_tensor,
-                         atol=precision_4dps, rtol=0)
+        atol = 1e-6
+        rtol = 1e-6
 
-        self.assertEqual(torch.nn.functional.silu(input_tensor, inplace=True),
-                         expected_output_tensor,
-                         atol=precision_4dps, rtol=0)
+        input = torch.from_numpy(input_np).clone().contiguous().to(device)
+        self.assertEqual(torch.nn.functional.silu(input), expected_output,
+                         atol=atol, rtol=rtol)
+        self.assertEqual(torch.nn.functional.silu(input, inplace=True),
+                         expected_output, atol=atol, rtol=rtol)
+
+        input = torch.from_numpy(input_np).clone().to(device)
+        input_noncontig = input.transpose(0, 1)
+        self.assertEqual(torch.nn.functional.silu(input_noncontig),
+                         expected_output_noncontig, atol=atol, rtol=rtol)
+        self.assertEqual(torch.nn.functional.silu(
+            input_noncontig, inplace=True), expected_output_noncontig,
+            atol=atol, rtol=rtol)
+
 
     @onlyCPU
     @dtypes(torch.float)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index cbc9189..1d2c12a 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1220,10 +1220,7 @@
   self: threshold_backward(grad, result, 0)
 
 - name: silu(Tensor self) -> Tensor
-  self: silu_backward(grad, self)
-
-- name: silu_(Tensor(a!) self) -> Tensor(a!)
-  self: not_implemented("silu_ cannot compute gradient of inplace version, use silu instead")
+  self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
 
 - name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
   self: elu_backward(grad, alpha, scale, input_scale, result)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index d8bd226..683ed76 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -959,6 +959,13 @@
   return cdf.addcmul_(self, pdf, kAlpha).mul_(grad);
 }
 
+Tensor infinitely_differentiable_silu_backward(
+    const Tensor& grad_output,
+    const Tensor& input) {
+  const Tensor sigmoid = input.sigmoid();
+  return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid));
+}
+
 Tensor infinitely_differentiable_logit_backward(
     const Tensor& grad,
     const Tensor& self,
diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h
index 80b498f..098d52b 100644
--- a/torch/csrc/api/include/torch/enum.h
+++ b/torch/csrc/api/include/torch/enum.h
@@ -103,6 +103,7 @@
 TORCH_ENUM_DECLARE(Tanh)
 TORCH_ENUM_DECLARE(ReLU)
 TORCH_ENUM_DECLARE(GELU)
+TORCH_ENUM_DECLARE(SiLU)
 TORCH_ENUM_DECLARE(LeakyReLU)
 TORCH_ENUM_DECLARE(FanIn)
 TORCH_ENUM_DECLARE(FanOut)
@@ -143,6 +144,7 @@
   TORCH_ENUM_PRETTY_PRINT(Tanh)
   TORCH_ENUM_PRETTY_PRINT(ReLU)
   TORCH_ENUM_PRETTY_PRINT(GELU)
+  TORCH_ENUM_PRETTY_PRINT(SiLU)
   TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
   TORCH_ENUM_PRETTY_PRINT(FanIn)
   TORCH_ENUM_PRETTY_PRINT(FanOut)
diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h
index c69d10b..a13945f 100644
--- a/torch/csrc/api/include/torch/nn/functional/activation.h
+++ b/torch/csrc/api/include/torch/nn/functional/activation.h
@@ -342,6 +342,12 @@
 
 // ============================================================================
 
+inline Tensor silu(const Tensor& input) {
+  return torch::silu(input);
+}
+
+// ============================================================================
+
 inline Tensor prelu(const Tensor& input, const Tensor& weight) {
   return torch::prelu(input, weight);
 }
diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h
index 97542b5..5a4c434 100644
--- a/torch/csrc/api/include/torch/nn/modules/activation.h
+++ b/torch/csrc/api/include/torch/nn/modules/activation.h
@@ -567,6 +567,27 @@
 /// module storage semantics.
 TORCH_MODULE(GELU);
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies silu over a given input.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.SiLU to learn
+/// about the exact behavior of this module.
+class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
+ public:
+  Tensor forward(const Tensor& input);
+
+  void reset() override;
+
+  /// Pretty prints the `SiLU` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+};
+
+/// A `ModuleHolder` subclass for `SiLUImpl`.
+/// See the documentation for `SiLUImpl` class to learn what methods it
+/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(SiLU);
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// Applies sigmoid over a given input.
diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp
index 19db63b..6522dc5 100644
--- a/torch/csrc/api/src/nn/modules/activation.cpp
+++ b/torch/csrc/api/src/nn/modules/activation.cpp
@@ -294,6 +294,18 @@
 
 // ============================================================================
 
+Tensor SiLUImpl::forward(const Tensor& input) {
+  return F::silu(input);
+}
+
+void SiLUImpl::reset() {}
+
+void SiLUImpl::pretty_print(std::ostream& stream) const {
+  stream << "torch::nn::SiLU()";
+}
+
+// ============================================================================
+
 Tensor SigmoidImpl::forward(const Tensor& input) {
   return torch::sigmoid(input);
 }
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index f13e6d8..85faea1 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -3063,6 +3063,17 @@
         reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
     ),
     dict(
+        module_name='SiLU',
+        input_size=(),
+        desc='scalar',
+        reference_fn=lambda x, *_: x * torch.sigmoid(x),
+    ),
+    dict(
+        module_name='SiLU',
+        input_size=(5, 6, 7),
+        reference_fn=lambda x, *_: x * torch.sigmoid(x),
+    ),
+    dict(
         constructor=wrap_functional(F.softmax, dim=-1),
         cpp_options_args='F::SoftmaxFuncOptions(-1)',
         input_size=(2, 128),  # trigger the last-dim algo in CUDA