Add a bunch of extra functionality to SymFloat (#86046)
- SymInt to SymFloat conversion
- All the basic arithmetic operators on c10::SymFloat
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86046
Approved by: https://github.com/wconstab
diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp
index 19321bd..0ba980a 100644
--- a/c10/core/SymFloat.cpp
+++ b/c10/core/SymFloat.cpp
@@ -9,6 +9,60 @@
return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
}
+static std::array<SymFloatNode, 2> normalize_symfloats(
+ SymFloat a_,
+ SymFloat b_) {
+ SymFloatNode a, b;
+ if (a_.is_symbolic())
+ a = a_.toSymFloatNodeImpl();
+ if (b_.is_symbolic())
+ b = b_.toSymFloatNodeImpl();
+
+ SymFloatNodeImpl* common = a ? a.get() : b.get();
+ // TODO: technically we need to check that the classes match
+ if (!a) {
+ a = common->wrap(a_.as_float_unchecked());
+ a_.toSymFloat(a); //
+ }
+ if (!b) {
+ b = common->wrap(b_.as_float_unchecked());
+ b_.toSymFloat(b);
+ }
+ return {a, b};
+}
+
+SymFloat SymFloat::operator+(SymFloat sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return SymFloat(data_ + sci.data_);
+ }
+ auto res = normalize_symfloats(*this, sci);
+ return SymFloat::toSymFloat(res[0]->add(res[1]));
+}
+
+SymFloat SymFloat::operator-(SymFloat sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return SymFloat(data_ - sci.data_);
+ }
+ auto res = normalize_symfloats(*this, sci);
+ return SymFloat::toSymFloat(res[0]->sub(res[1]));
+}
+
+SymFloat SymFloat::operator*(SymFloat sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return SymFloat(data_ * sci.data_);
+ }
+ auto res = normalize_symfloats(*this, sci);
+ return SymFloat::toSymFloat(res[0]->mul(res[1]));
+}
+
+SymFloat SymFloat::operator/(SymFloat sci) const {
+ if (!is_symbolic() && !sci.is_symbolic()) {
+ return SymFloat(data_ / sci.data_);
+ }
+ auto res = normalize_symfloats(*this, sci);
+ return SymFloat::toSymFloat(res[0]->truediv(res[1]));
+}
+
c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
return c10::SymFloat(std::move(sin_sp));
}
diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h
index 0e32458..02c64eb 100644
--- a/c10/core/SymFloat.h
+++ b/c10/core/SymFloat.h
@@ -34,6 +34,11 @@
return data_;
}
+ SymFloat operator+(SymFloat) const;
+ SymFloat operator-(SymFloat) const;
+ SymFloat operator*(SymFloat) const;
+ SymFloat operator/(SymFloat) const;
+
// N.B. It's important to keep this definition in the header
// as we expect if checks to be folded for mobile builds
// where `is_symbolic` is always false
diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp
index 095a331..3d20363 100644
--- a/c10/core/SymInt.cpp
+++ b/c10/core/SymInt.cpp
@@ -1,3 +1,4 @@
+#include <c10/core/SymFloat.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntNodeImpl.h>
#include <array>
@@ -60,6 +61,13 @@
return a->guard_int(file, line);
}
+SymInt::operator SymFloat() const {
+ if (!is_symbolic()) {
+ return SymFloat(double(data_));
+ }
+ return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
+}
+
SymInt SymInt::operator+(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymInt(data_ + sci.data_);
diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h
index de4866f..00c51c8 100644
--- a/c10/core/SymInt.h
+++ b/c10/core/SymInt.h
@@ -10,6 +10,8 @@
namespace c10 {
+class SymFloat;
+
// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to
// represent concrete dimension values.
//
@@ -188,6 +190,8 @@
bool operator>(int64_t sci) const;
bool operator>=(int64_t sci) const;
+ operator SymFloat() const;
+
int64_t as_int_unchecked() const {
return data_;
}
diff --git a/c10/core/SymIntNodeImpl.h b/c10/core/SymIntNodeImpl.h
index 1e9aa4c..03a5068 100644
--- a/c10/core/SymIntNodeImpl.h
+++ b/c10/core/SymIntNodeImpl.h
@@ -64,6 +64,9 @@
virtual SymIntNode clone() {
TORCH_CHECK(false, "NYI");
};
+ virtual SymFloatNode sym_float() {
+ TORCH_CHECK(false, "NYI");
+ }
virtual SymIntNode wrap(int64_t num) {
TORCH_CHECK(false, "NYI");
};
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 113ff1d..d0d810c 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -177,8 +177,7 @@
return getPyObj().attr("__int__")().cast<int64_t>();
}
- // TODO: virtualize
- SymFloat sym_float();
+ SymFloatNode sym_float() override;
virtual std::string str() override {
py::gil_scoped_acquire acquire;
@@ -299,11 +298,10 @@
return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
}
-SymFloat PythonSymIntNodeImpl::sym_float() {
+SymFloatNode PythonSymIntNodeImpl::sym_float() {
py::gil_scoped_acquire acquire;
return c10::make_intrusive<PythonSymFloatNodeImpl>(
- getPyObj().attr("__sym_float__")())
- ->toSymFloat();
+ getPyObj().attr("__sym_float__")());
}
namespace {