Optimize SiLU (Swish) op in PyTorch (#42976)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42976
Optimize SiLU (Swish) op in PyTorch.
Some benchmark results (before -> after):
input = torch.rand(1024, 32768, dtype=torch.float, device="cpu")
forward: 221ms -> 133ms
backward: 600ms -> 170ms
input = torch.rand(1024, 32768, dtype=torch.double, device="cpu")
forward: 479ms -> 297ms
backward: 1438ms -> 387ms
input = torch.rand(8192, 32768, dtype=torch.float, device="cuda")
forward: 24.34ms -> 9.83ms
backward: 97.05ms -> 29.03ms
input = torch.rand(4096, 32768, dtype=torch.double, device="cuda")
forward: 44.24ms -> 30.15ms
backward: 126.21ms -> 49.68ms
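The timing harness itself is not included in this PR, so the sketch below is only one plausible way to reproduce comparable numbers with `torch.utils.benchmark`; the tensor shape and dtype match the first CPU case above, while the iteration strategy and warm-up are assumptions.

```python
# Hypothetical reproduction sketch -- the PR does not ship its benchmark script.
import torch
from torch.utils import benchmark

x = torch.rand(1024, 32768, dtype=torch.float, device="cpu")

# Forward timing.
fwd = benchmark.Timer(
    stmt="torch.nn.functional.silu(x)",
    globals={"torch": torch, "x": x},
)
print(fwd.blocked_autorange())

# Backward timing: build the graph once, then time repeated backward passes.
xg = x.clone().requires_grad_(True)
y = torch.nn.functional.silu(xg)
grad = torch.ones_like(y)
bwd = benchmark.Timer(
    stmt="y.backward(grad, retain_graph=True)",
    globals={"y": y, "grad": grad},
)
print(bwd.blocked_autorange())
```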
Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "SiLU"
Reviewed By: houseroad
Differential Revision: D23093593
fbshipit-source-id: 1ba7b95d5926c4527216ed211a5ff1cefa3d3bfd
diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp
index 0b12568..0e301df 100644
--- a/aten/src/ATen/native/Activation.cpp
+++ b/aten/src/ATen/native/Activation.cpp
@@ -30,6 +30,8 @@
DEFINE_DISPATCH(shrink_backward_stub);
DEFINE_DISPATCH(leaky_relu_stub);
DEFINE_DISPATCH(leaky_relu_backward_stub);
+DEFINE_DISPATCH(silu_stub);
+DEFINE_DISPATCH(silu_backward_stub);
Tensor hardtanh(const Tensor& self, Scalar min, Scalar max) {
return at::clamp(self, min, max);
@@ -195,20 +197,31 @@
}
Tensor silu(const Tensor& self) {
- return self * at::sigmoid(self);
+ Tensor result = at::empty({0}, self.options());
+ at::silu_out(result, self);
+ return result;
}
Tensor& silu_(Tensor& self) {
- return self.mul_(at::sigmoid(self));
+ return at::silu_out(self, self);
}
Tensor& silu_out(Tensor& result, const Tensor& self) {
- return at::mul_out(result, self, at::sigmoid(self));
+ TORCH_CHECK(
+ result.dtype() == self.dtype(),
+ "Output Tensor should have the same type as in Input Tensor.")
+ auto iter = TensorIterator::unary_op(result, self);
+ silu_stub(iter.device_type(), iter);
+ return result;
}
-Tensor silu_backward(const Tensor& grad, const Tensor& self) {
- auto self_sigmoid = at::sigmoid(self);
- return grad * (self_sigmoid * (1 + self * (1 - self_sigmoid)));
+Tensor silu_backward(
+ const Tensor& grad_output,
+ const Tensor& input) {
+ Tensor grad_input = at::empty({0}, input.options());
+ auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
+ silu_backward_stub(iter.device_type(), iter);
+ return grad_input;
}
template <typename scalar_t>
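The new `silu_backward` stub implements the closed-form gradient d/dx[x·σ(x)] = σ(x)·(1 + x·(1 − σ(x))). A quick numerical check of that formula against autograd (not part of the PR):

```python
# Check that the closed-form SiLU gradient used by silu_backward matches autograd.
import torch

x = torch.randn(64, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.silu(x)
(autograd_grad,) = torch.autograd.grad(y.sum(), x)

s = torch.sigmoid(x.detach())
manual_grad = s * (1 + x.detach() * (1 - s))
print(torch.allclose(autograd_grad, manual_grad))  # expected: True
```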
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index c557835..bebfa67 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -48,6 +48,8 @@
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
DECLARE_DISPATCH(activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
+DECLARE_DISPATCH(activation_fn, silu_stub);
+DECLARE_DISPATCH(activation_backward_fn, silu_backward_stub);
} // namespace native
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index c9a0faa..79de581 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -4,7 +4,8 @@
#include <ATen/native/Activation.h>
-#include <math.h>
+#include <cmath>
+#include <functional>
#include <ATen/ATen.h>
#include <ATen/Config.h>
@@ -589,6 +590,40 @@
});
}
+void silu_kernel(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
+ kBFloat16, iter.dtype(), "silu_cpu", [&]() {
+ const Vec256<scalar_t> kOneVec(scalar_t(1));
+ cpu_kernel_vec(
+ iter,
+ [](scalar_t x) {
+ return x / (scalar_t(1) + std::exp(-x));
+ },
+ [kOneVec](Vec256<scalar_t> x_vec) {
+ return x_vec / (kOneVec + x_vec.neg().exp());
+ });
+ });
+}
+
+void silu_backward_kernel(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
+ kBFloat16, iter.dtype(), "silu_backward_cpu", [&]() {
+ const Vec256<scalar_t> kOneVec(scalar_t(1));
+ cpu_kernel_vec(
+ iter,
+ [](scalar_t dy, scalar_t x) {
+ const scalar_t sigmoid =
+ scalar_t(1) / (scalar_t(1) + std::exp(-x));
+ return dy * sigmoid * (scalar_t(1) + x * (scalar_t(1) - sigmoid));
+ },
+ [kOneVec](Vec256<scalar_t> dy_vec, Vec256<scalar_t> x_vec) {
+ const Vec256<scalar_t> sigmoid =
+ kOneVec / (kOneVec + x_vec.neg().exp());
+ return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid));
+ });
+ });
+}
+
} // namespace
REGISTER_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
@@ -612,6 +647,8 @@
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
REGISTER_DISPATCH(glu_stub, &glu_kernel);
REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel);
+REGISTER_DISPATCH(silu_stub, &silu_kernel);
+REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
} // namespace native
} // namespace at
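The CPU dispatch macro above covers BFloat16 (and complex) in addition to float/double. A small sanity sketch, assuming a build that contains this kernel, comparing the bfloat16 path against a float32 reference:

```python
# Sanity sketch (not from the PR): bfloat16 output should track the float32 reference
# up to bfloat16 rounding error.
import torch

x = torch.randn(1024, dtype=torch.float32)
ref = torch.nn.functional.silu(x)
out_bf16 = torch.nn.functional.silu(x.to(torch.bfloat16))
print((out_bf16.float() - ref).abs().max())
```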
diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 209bf29..359fd5b 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -2,7 +2,9 @@
#include <ATen/native/Activation.h>
-#include <math.h>
+#include <cmath>
+
+#include <thrust/tuple.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
@@ -15,10 +17,8 @@
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAMathCompat.h>
-#include <thrust/tuple.h>
-
-
-namespace at { namespace native {
+namespace at {
+namespace native {
// -----------------------------------
// prelu forward
@@ -478,6 +478,43 @@
});
}
+void silu_kernel(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ iter.dtype(),
+ "silu_cuda",
+ [&]() {
+ gpu_kernel(
+ iter,
+ [] GPU_LAMBDA(scalar_t x) -> scalar_t {
+ using T_ACC = acc_type<scalar_t, true>;
+ const T_ACC x_acc = static_cast<T_ACC>(x);
+ return x_acc / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
+ });
+ });
+}
+
+void silu_backward_kernel(TensorIterator& iter) {
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half,
+ at::ScalarType::BFloat16,
+ iter.dtype(),
+ "silu_backward_cuda",
+ [&]() {
+ gpu_kernel(
+ iter,
+ [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+ using T_ACC = acc_type<scalar_t, true>;
+ const T_ACC dy_acc = static_cast<T_ACC>(dy);
+ const T_ACC x_acc = static_cast<T_ACC>(x);
+ const T_ACC s_acc =
+ T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
+ return dy_acc * s_acc * (T_ACC(1) + x_acc * (T_ACC(1) - s_acc));
+ });
+ });
+}
+
} // namespace
Tensor gelu_cuda(const Tensor& self) {
@@ -540,5 +577,8 @@
REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
+REGISTER_DISPATCH(silu_stub, &silu_kernel);
+REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
-}} // namespace at::native
+} // namespace native
+} // namespace at
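The CUDA kernels cast Half/BFloat16 inputs to `acc_type` (float) before computing, so reduced-precision outputs stay close to the float32 result. A minimal sketch (requires a CUDA device; not part of the PR):

```python
# Minimal sketch: the half-precision CUDA path accumulates in float via acc_type.
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, device="cuda", dtype=torch.float32)
    ref = torch.nn.functional.silu(x)
    out_half = torch.nn.functional.silu(x.half())
    print((out_half.float() - ref).abs().max())  # on the order of half-precision eps
```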
diff --git a/docs/source/scripts/build_activation_images.py b/docs/source/scripts/build_activation_images.py
index 1fa9fd2..7274d5c 100644
--- a/docs/source/scripts/build_activation_images.py
+++ b/docs/source/scripts/build_activation_images.py
@@ -36,6 +36,7 @@
'ReLU6',
'RReLU',
'SELU',
+ 'SiLU',
'CELU',
'GELU',
'Sigmoid',
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index c5a9f0b..b7ec61a 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -57,6 +57,7 @@
torch::nn::SELU|Yes|No
torch::nn::CELU|Yes|No
torch::nn::GELU|Yes|No
+torch::nn::SiLU|Yes|No
torch::nn::Sigmoid|Yes|No
torch::nn::Softplus|Yes|No
torch::nn::Softshrink|Yes|No
@@ -183,6 +184,7 @@
F::rrelu|Yes|No
F::glu|Yes|No
F::gelu|Yes|No
+F::silu|Yes|No
F::logsigmoid|Yes|No
F::hardshrink|Yes|No
F::tanhshrink|Yes|No
diff --git a/test/test_torch.py b/test/test_torch.py
index 6d97c4a..3fd0db4 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16795,22 +16795,35 @@
torch.tensor(expectedOutput, dtype=dtype, device=device),
atol=precision_4dps, rtol=0)
+ @skipIfNoSciPy
@dtypes(torch.float, torch.double)
def test_silu(self, device, dtype):
- inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000]
- expectedOutput = [0.0000, -0.2689, 0, 0.3112, 0.7312, 1.7616, 1000]
- precision_4dps = 0.0002
+ input_np = np.random.randn(5, 8)
+ special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]]
+ input_np = np.concatenate((input_np, special_input), axis=0).astype(
+ torch_to_numpy_dtype_dict[dtype])
+ expected_output_np = input_np * scipy.special.expit(input_np)
- input_tensor = torch.tensor(inputValues, dtype=dtype, device=device)
- expected_output_tensor = torch.tensor(expectedOutput, dtype=dtype, device=device)
+ expected_output = torch.from_numpy(expected_output_np).to(device)
+ expected_output_noncontig = expected_output.transpose(0, 1)
- self.assertEqual(torch.nn.functional.silu(input_tensor),
- expected_output_tensor,
- atol=precision_4dps, rtol=0)
+ atol = 1e-6
+ rtol = 1e-6
- self.assertEqual(torch.nn.functional.silu(input_tensor, inplace=True),
- expected_output_tensor,
- atol=precision_4dps, rtol=0)
+ input = torch.from_numpy(input_np).clone().contiguous().to(device)
+ self.assertEqual(torch.nn.functional.silu(input), expected_output,
+ atol=atol, rtol=rtol)
+ self.assertEqual(torch.nn.functional.silu(input, inplace=True),
+ expected_output, atol=atol, rtol=rtol)
+
+ input = torch.from_numpy(input_np).clone().to(device)
+ input_noncontig = input.transpose(0, 1)
+ self.assertEqual(torch.nn.functional.silu(input_noncontig),
+ expected_output_noncontig, atol=atol, rtol=rtol)
+ self.assertEqual(torch.nn.functional.silu(
+ input_noncontig, inplace=True), expected_output_noncontig,
+ atol=atol, rtol=rtol)
+
@onlyCPU
@dtypes(torch.float)
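Outside the test harness, the reference comparison the new test performs boils down to the following standalone sketch (the test's `torch_to_numpy_dtype_dict` helper, device handling, and decorators are omitted here):

```python
# Standalone version of the new test's oracle: silu(x) == x * expit(x).
import numpy as np
import scipy.special
import torch

x_np = np.random.randn(5, 8).astype(np.float64)
expected = torch.from_numpy(x_np * scipy.special.expit(x_np))
actual = torch.nn.functional.silu(torch.from_numpy(x_np))
print(torch.allclose(actual, expected, atol=1e-6, rtol=1e-6))  # expected: True
```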
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index cbc9189..1d2c12a 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1220,10 +1220,7 @@
self: threshold_backward(grad, result, 0)
- name: silu(Tensor self) -> Tensor
- self: silu_backward(grad, self)
-
-- name: silu_(Tensor(a!) self) -> Tensor(a!)
- self: not_implemented("silu_ cannot compute gradient of inplace version, use silu instead")
+ self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
self: elu_backward(grad, alpha, scale, input_scale, result)
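The `GradMode::is_enabled()` branch means the fused `silu_backward` kernel handles plain first-order backward, while the composite `infinitely_differentiable_silu_backward` keeps higher-order gradients working. A short illustration (not from the PR) of the double-backward path this preserves:

```python
# Double backward through silu relies on the differentiable (composite) backward formula.
import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.silu(x).sum()
(g,) = torch.autograd.grad(y, x, create_graph=True)  # first derivative, graph retained
(gg,) = torch.autograd.grad(g.sum(), x)              # second derivative
print(gg)
```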
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index d8bd226..683ed76 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -959,6 +959,13 @@
return cdf.addcmul_(self, pdf, kAlpha).mul_(grad);
}
+Tensor infinitely_differentiable_silu_backward(
+ const Tensor& grad_output,
+ const Tensor& input) {
+ const Tensor sigmoid = input.sigmoid();
+ return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid));
+}
+
Tensor infinitely_differentiable_logit_backward(
const Tensor& grad,
const Tensor& self,
diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h
index 80b498f..098d52b 100644
--- a/torch/csrc/api/include/torch/enum.h
+++ b/torch/csrc/api/include/torch/enum.h
@@ -103,6 +103,7 @@
TORCH_ENUM_DECLARE(Tanh)
TORCH_ENUM_DECLARE(ReLU)
TORCH_ENUM_DECLARE(GELU)
+TORCH_ENUM_DECLARE(SiLU)
TORCH_ENUM_DECLARE(LeakyReLU)
TORCH_ENUM_DECLARE(FanIn)
TORCH_ENUM_DECLARE(FanOut)
@@ -143,6 +144,7 @@
TORCH_ENUM_PRETTY_PRINT(Tanh)
TORCH_ENUM_PRETTY_PRINT(ReLU)
TORCH_ENUM_PRETTY_PRINT(GELU)
+ TORCH_ENUM_PRETTY_PRINT(SiLU)
TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
TORCH_ENUM_PRETTY_PRINT(FanIn)
TORCH_ENUM_PRETTY_PRINT(FanOut)
diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h
index c69d10b..a13945f 100644
--- a/torch/csrc/api/include/torch/nn/functional/activation.h
+++ b/torch/csrc/api/include/torch/nn/functional/activation.h
@@ -342,6 +342,12 @@
// ============================================================================
+inline Tensor silu(const Tensor& input) {
+ return torch::silu(input);
+}
+
+// ============================================================================
+
inline Tensor prelu(const Tensor& input, const Tensor& weight) {
return torch::prelu(input, weight);
}
diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h
index 97542b5..5a4c434 100644
--- a/torch/csrc/api/include/torch/nn/modules/activation.h
+++ b/torch/csrc/api/include/torch/nn/modules/activation.h
@@ -567,6 +567,27 @@
/// module storage semantics.
TORCH_MODULE(GELU);
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Applies the SiLU (Sigmoid Linear Unit) function over a given input.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.SiLU to learn
+/// about the exact behavior of this module.
+class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
+ public:
+ Tensor forward(const Tensor& input);
+
+ void reset() override;
+
+ /// Pretty prints the `SiLU` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+};
+
+/// A `ModuleHolder` subclass for `SiLUImpl`.
+/// See the documentation for `SiLUImpl` class to learn what methods it
+/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(SiLU);
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Applies sigmoid over a given input.
diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp
index 19db63b..6522dc5 100644
--- a/torch/csrc/api/src/nn/modules/activation.cpp
+++ b/torch/csrc/api/src/nn/modules/activation.cpp
@@ -294,6 +294,18 @@
// ============================================================================
+Tensor SiLUImpl::forward(const Tensor& input) {
+ return F::silu(input);
+}
+
+void SiLUImpl::reset() {}
+
+void SiLUImpl::pretty_print(std::ostream& stream) const {
+ stream << "torch::nn::SiLU()";
+}
+
+// ============================================================================
+
Tensor SigmoidImpl::forward(const Tensor& input) {
return torch::sigmoid(input);
}
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index f13e6d8..85faea1 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -3063,6 +3063,17 @@
reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
),
dict(
+ module_name='SiLU',
+ input_size=(),
+ desc='scalar',
+ reference_fn=lambda x, *_: x * torch.sigmoid(x),
+ ),
+ dict(
+ module_name='SiLU',
+ input_size=(5, 6, 7),
+ reference_fn=lambda x, *_: x * torch.sigmoid(x),
+ ),
+ dict(
constructor=wrap_functional(F.softmax, dim=-1),
cpp_options_args='F::SoftmaxFuncOptions(-1)',
input_size=(2, 128), # trigger the last-dim algo in CUDA
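The two new `common_nn.py` entries exercise `nn.SiLU` against the `x * sigmoid(x)` reference. The same check, written out directly (not part of the PR):

```python
# nn.SiLU module output should match the reference_fn used in the test entries above.
import torch

m = torch.nn.SiLU()
x = torch.randn(5, 6, 7)
print(torch.allclose(m(x), x * torch.sigmoid(x)))  # expected: True
```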