Add a bunch of extra functionality to SymFloat (#86046)

- SymInt to SymFloat conversion
- All the basic arithmetic operators on c10::SymFloat

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86046
Approved by: https://github.com/wconstab
diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp
index 19321bd..0ba980a 100644
--- a/c10/core/SymFloat.cpp
+++ b/c10/core/SymFloat.cpp
@@ -9,6 +9,60 @@
   return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned());
 }
 
+static std::array<SymFloatNode, 2> normalize_symfloats(
+    SymFloat a_,
+    SymFloat b_) {
+  SymFloatNode a, b;
+  if (a_.is_symbolic())
+    a = a_.toSymFloatNodeImpl();
+  if (b_.is_symbolic())
+    b = b_.toSymFloatNodeImpl();
+
+  SymFloatNodeImpl* common = a ? a.get() : b.get();
+  // TODO: technically we need to check that the classes match
+  if (!a) {
+    a = common->wrap(a_.as_float_unchecked());
+    a_.toSymFloat(a); //
+  }
+  if (!b) {
+    b = common->wrap(b_.as_float_unchecked());
+    b_.toSymFloat(b);
+  }
+  return {a, b};
+}
+
+SymFloat SymFloat::operator+(SymFloat sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymFloat(data_ + sci.data_);
+  }
+  auto res = normalize_symfloats(*this, sci);
+  return SymFloat::toSymFloat(res[0]->add(res[1]));
+}
+
+SymFloat SymFloat::operator-(SymFloat sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymFloat(data_ - sci.data_);
+  }
+  auto res = normalize_symfloats(*this, sci);
+  return SymFloat::toSymFloat(res[0]->sub(res[1]));
+}
+
+SymFloat SymFloat::operator*(SymFloat sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymFloat(data_ * sci.data_);
+  }
+  auto res = normalize_symfloats(*this, sci);
+  return SymFloat::toSymFloat(res[0]->mul(res[1]));
+}
+
+SymFloat SymFloat::operator/(SymFloat sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return SymFloat(data_ / sci.data_);
+  }
+  auto res = normalize_symfloats(*this, sci);
+  return SymFloat::toSymFloat(res[0]->truediv(res[1]));
+}
+
 c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) {
   return c10::SymFloat(std::move(sin_sp));
 }
diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h
index 0e32458..02c64eb 100644
--- a/c10/core/SymFloat.h
+++ b/c10/core/SymFloat.h
@@ -34,6 +34,11 @@
     return data_;
   }
 
+  SymFloat operator+(SymFloat) const;
+  SymFloat operator-(SymFloat) const;
+  SymFloat operator*(SymFloat) const;
+  SymFloat operator/(SymFloat) const;
+
   // N.B. It's important to keep this definition in the header
   // as we expect if checks to be folded for mobile builds
   // where `is_symbolic` is always false
diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp
index 095a331..3d20363 100644
--- a/c10/core/SymInt.cpp
+++ b/c10/core/SymInt.cpp
@@ -1,3 +1,4 @@
+#include <c10/core/SymFloat.h>
 #include <c10/core/SymInt.h>
 #include <c10/core/SymIntNodeImpl.h>
 #include <array>
@@ -60,6 +61,13 @@
   return a->guard_int(file, line);
 }
 
+SymInt::operator SymFloat() const {
+  if (!is_symbolic()) {
+    return SymFloat(double(data_));
+  }
+  return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float());
+}
+
 SymInt SymInt::operator+(SymInt sci) const {
   if (!is_symbolic() && !sci.is_symbolic()) {
     return SymInt(data_ + sci.data_);
diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h
index de4866f..00c51c8 100644
--- a/c10/core/SymInt.h
+++ b/c10/core/SymInt.h
@@ -10,6 +10,8 @@
 
 namespace c10 {
 
+class SymFloat;
+
 // `SymInt` is a C++ wrapper class around int64_t data_ which  and is used to
 // represent concrete dimension values.
 //
@@ -188,6 +190,8 @@
   bool operator>(int64_t sci) const;
   bool operator>=(int64_t sci) const;
 
+  operator SymFloat() const;
+
   int64_t as_int_unchecked() const {
     return data_;
   }
diff --git a/c10/core/SymIntNodeImpl.h b/c10/core/SymIntNodeImpl.h
index 1e9aa4c..03a5068 100644
--- a/c10/core/SymIntNodeImpl.h
+++ b/c10/core/SymIntNodeImpl.h
@@ -64,6 +64,9 @@
   virtual SymIntNode clone() {
     TORCH_CHECK(false, "NYI");
   };
+  virtual SymFloatNode sym_float() {
+    TORCH_CHECK(false, "NYI");
+  }
   virtual SymIntNode wrap(int64_t num) {
     TORCH_CHECK(false, "NYI");
   };
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 113ff1d..d0d810c 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -177,8 +177,7 @@
     return getPyObj().attr("__int__")().cast<int64_t>();
   }
 
-  // TODO: virtualize
-  SymFloat sym_float();
+  SymFloatNode sym_float() override;
 
   virtual std::string str() override {
     py::gil_scoped_acquire acquire;
@@ -299,11 +298,10 @@
   return c10::make_intrusive<PythonSymFloatNodeImpl>(r);
 }
 
-SymFloat PythonSymIntNodeImpl::sym_float() {
+SymFloatNode PythonSymIntNodeImpl::sym_float() {
   py::gil_scoped_acquire acquire;
   return c10::make_intrusive<PythonSymFloatNodeImpl>(
-             getPyObj().attr("__sym_float__")())
-      ->toSymFloat();
+      getPyObj().attr("__sym_float__")());
 }
 
 namespace {