Skip dispatch for is_signed (#53847)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53847
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D26994937
Pulled By: carolineechen
fbshipit-source-id: 8af25ecdade0b31d29fac27de6ee5f704353af10
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index f277f67..b483d68 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -23,7 +23,7 @@
}
bool is_signed(const Tensor &self) {
- return at::isSignedType(self.scalar_type());
+ return self.is_signed();
}
bool is_sparse(const Tensor& self) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a05b4e1..a2205e3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2092,6 +2092,7 @@
- func: is_signed(Tensor self) -> bool
variants: function, method
device_guard: False
+ manual_cpp_binding: True
- func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
dispatch:
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index d38a524..0c82fdf 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -147,7 +147,11 @@
}
inline bool is_floating_point(const Tensor& tensor) {
- return tensor.is_floating_point();
+ return tensor.is_floating_point();
+}
+
+inline bool is_signed(const Tensor& tensor) {
+ return tensor.is_signed();
}
}
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index c6efc90..9bec71e 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -149,7 +149,11 @@
}
bool is_floating_point() const {
- return at::isFloatingType(this->scalar_type());
+ return at::isFloatingType(this->scalar_type());
+ }
+
+ bool is_signed() const {
+ return at::isSignedType(this->scalar_type());
}
int64_t size(int64_t dim) const {