Simplify operator `sign` using the helper.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25592
Test Plan: Imported from OSS
Differential Revision: D17552470
Pulled By: VitalyFedyunin
fbshipit-source-id: 6c8cc4f46dd390c231b2d0aac664ad2a6ac8876e
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index a8a5c19..586f5d5 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -81,6 +81,10 @@
Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); }
Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); }
+Tensor& sign_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sign_stub); }
+Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); }
+Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); }
+
Tensor& trunc_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, trunc_stub); }
Tensor trunc(const Tensor& self) { return unary_op_impl(self, at::trunc_out); }
Tensor& trunc_(Tensor& self) { return unary_op_impl_(self, at::trunc_out); }
@@ -193,23 +197,6 @@
return result;
}
-Tensor sign(const Tensor& self) {
- Tensor result = at::empty({0}, self.options());
- return at::sign_out(result, self);
-}
-
-Tensor& sign_(Tensor& self) {
- return at::sign_out(self, self);
-}
-
-Tensor& sign_out(Tensor& result, const Tensor& self) {
- checkBackend("sign", result, self.type().backend());
- auto iter = TensorIterator::unary_op(result, self,
- /*check_internal_overlap=*/true);
- sign_stub(iter.device_type(), iter);
- return result;
-}
-
Tensor mvlgamma(const Tensor& self, int64_t p) {
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"mvlgamma is not implemented for ", self.type());