[torchbind] Improve IValue custom class API and remove most Capsule stuff (#34848)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34848

Test Plan: Imported from OSS

Differential Revision: D20480514

Pulled By: jamesr66a

fbshipit-source-id: 1c595faf34e00aab0a6202a8902426bd310551c3
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 7816df1..20d6ce3 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -8,6 +8,19 @@
 namespace c10 {
 namespace ivalue {
 
+// This is in ivalue.cpp because we need to access Type::python_str, which
+// is declared in jit_type.h
+void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
+  // NB: doing pointer comparison here
+  // If in the future there ever arises a need to call operator== on custom class
+  // Type's, this needs to be changed!
+  TORCH_CHECK(actual_type == expected_type,
+              "Tried to convert an IValue of type ",
+              actual_type->python_str(),
+              " to custom class type ",
+              expected_type->python_str());
+}
+
 CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
     std::string str_) {
   return c10::make_intrusive<ConstantString>(std::move(str_));
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 19358809..32ca9d7 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -207,14 +207,25 @@
   /// @private [doxygen private]
   c10::intrusive_ptr<caffe2::Blob> toBlob() const &;
 
-  // Capsule
-  IValue(intrusive_ptr<torch::CustomClassHolder> blob);
+  // Capsule. Capsule is an internal implementation detail
+  // of custom C++ classes. No new callsites of these APIs should
+  // be introduced.
+  static inline IValue make_capsule(intrusive_ptr<torch::CustomClassHolder> blob);
   bool isCapsule() const {
     return Tag::Capsule == tag;
   }
   c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
   c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const &;
 
+  // Custom C++ classes
+  template <typename T, std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int> = 0>
+  IValue(intrusive_ptr<T> custom_class);
+  bool isCustomClass() const;
+  template <typename T>
+  c10::intrusive_ptr<T> toCustomClass() &&;
+  template <typename T>
+  c10::intrusive_ptr<T> toCustomClass() const &;
+
   // Tuple
   IValue(c10::intrusive_ptr<ivalue::Tuple> v);
 
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 680e3e2..3937150 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -17,6 +17,7 @@
 struct Function;
 struct CompilationUnit;
 } // namespace jit
+TORCH_API bool isCustomClass(const c10::IValue& v);
 } // namespace torch
 namespace c10 {
 struct IValue;
@@ -129,6 +130,8 @@
 
 namespace ivalue {
 
+void CAFFE2_API checkCustomClassType(TypePtr expected_type, TypePtr actual_type);
+
 template <typename T>
 using Shared = c10::intrusive_ptr<T>;
 
@@ -522,13 +525,41 @@
 }
 
 template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() && {
+  static_assert(std::is_base_of<torch::CustomClassHolder, T>::value == true,
+    "toCustomClass requires that template parameter T must inherit "
+    "from torch::CustomClassHolder");
+  auto obj = toObject();
+  TORCH_CHECK(obj->slots().size() == 1,
+              "Tried to cast IValue to custom class but it did "
+              "not contain a custom class!");
+  auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+  ivalue::checkCustomClassType(expected_type, type());
+  auto userObj = c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+  return userObj;
+}
+
+template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() const & {
+  static_assert(std::is_base_of<torch::CustomClassHolder, T>::value == true,
+    "toCustomClass requires that template parameter T must inherit "
+    "from torch::CustomClassHolder");
+  auto obj = toObject();
+  TORCH_CHECK(obj->slots().size() == 1,
+              "Tried to cast IValue to custom class but it did "
+              "not contain a custom class!");
+  auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+  ivalue::checkCustomClassType(expected_type, type());
+  auto userObj = c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+  return userObj;
+}
+
+template <typename T>
 T generic_to(
     IValue ivalue,
     _fake_type<T>) {
     using ElemType = typename std::remove_pointer<T>::type::element_type;
-    auto obj = std::move(ivalue).toObject();
-    auto capsule = obj->getSlot(0);
-    return c10::static_intrusive_pointer_cast<ElemType>(capsule.toCapsule());
+    return std::move(ivalue).toCustomClass<ElemType>();
 }
 
 template <typename T>
@@ -776,10 +807,30 @@
 : tag(Tag::PyObject), is_intrusive_ptr(true) {
   payload.as_intrusive_ptr = v.release();
 }
-inline IValue::IValue(c10::intrusive_ptr<torch::CustomClassHolder> v)
-: tag(Tag::Capsule), is_intrusive_ptr(true) {
-  payload.as_intrusive_ptr = v.release();
+inline IValue IValue::make_capsule(intrusive_ptr<torch::CustomClassHolder> blob) {
+  IValue iv;
+  iv.tag = Tag::Capsule;
+  iv.is_intrusive_ptr = true;
+  iv.payload.as_intrusive_ptr = blob.release();
+  return iv;
 }
+
+template <typename T, std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
+IValue::IValue(c10::intrusive_ptr<T> custom_class) {
+  if (!c10::isCustomClassRegistered<c10::intrusive_ptr<T>>()) {
+    throw c10::Error(
+        "Trying to instantiate a class that isn't a registered custom class.",
+        "");
+  }
+  auto classType = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+  auto ivalue_obj = c10::ivalue::Object::create(
+      c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
+  ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
+  payload.as_intrusive_ptr = ivalue_obj.release();
+  tag = Tag::Object;
+  is_intrusive_ptr = true;
+}
+
 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
 : tag(Tag::Future), is_intrusive_ptr(true) {
   payload.as_intrusive_ptr = v.release();
@@ -804,6 +855,10 @@
   return this->to<T>();
 }
 
+inline bool IValue::isCustomClass() const {
+  return torch::isCustomClass(*this);
+}
+
 inline bool IValue::isSameIdentity(const IValue& rhs) const {
   // We choose to not use memcmp for payload check due to potential random padding characters on union type
 
@@ -865,17 +920,7 @@
   if (!isCustomClassRegistered<inputType>()) {
     throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
   }
-  auto res = getCustomClassType<inputType>();
-  auto retObject = ivalue::Object::create(
-    StrongTypePtr(
-      std::shared_ptr<torch::jit::CompilationUnit>(),
-      std::move(res)),
-    1);
-  auto objPtr = c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(std::move(x));
-
-  retObject->setSlot(0, IValue(std::move(objPtr)));
-  auto resIVal = IValue(std::move(retObject));
-  return resIVal;
+  return IValue(x);
 }
 template <typename T>
 IValue from_(T x, std::false_type) {
diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp
index 47e2136..497a8e1 100644
--- a/test/cpp/jit/test_custom_class.cpp
+++ b/test/cpp/jit/test_custom_class.cpp
@@ -156,13 +156,20 @@
     AT_ASSERT(tup->elements().size() == 2);
     auto str = tup->elements()[0].toStringRef();
     auto other_obj =
-        tup->elements()[1].to<c10::intrusive_ptr<MyStackClass<std::string>>>();
+        tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
     AT_ASSERT(str == expected);
-    auto ref_obj = obj.to<c10::intrusive_ptr<MyStackClass<std::string>>>();
+    auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
     AT_ASSERT(other_obj.get() == ref_obj.get());
   };
 
   test_with_obj(custom_class_obj, "bar");
+
+  // test IValue() API
+  auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
+      std::vector<std::string>{"baz", "boo"});
+  auto new_stack_ivalue = c10::IValue(my_new_stack);
+
+  test_with_obj(new_stack_ivalue, "boo");
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 3f9ab85..a8d8324 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -586,7 +586,8 @@
       return c10::ivalue::ConcretePyObjectHolder::create(obj.cast<py::object>());
 
     case TypeKind::CapsuleType: {
-      return py::cast<c10::intrusive_ptr<CustomClassHolder>>(obj);
+      return IValue::make_capsule(
+          py::cast<c10::intrusive_ptr<CustomClassHolder>>(obj));
     } break;
     case TypeKind::AnyType:
       return toTypeInferredIValue(obj);
diff --git a/torch/custom_class.h b/torch/custom_class.h
index 70d7cb4..35ebf71 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -74,10 +74,8 @@
                                                // torch::init<...>()
     auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
       auto classObj = c10::make_intrusive<CurClass>(args...);
-      auto genericPtr = c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(std::move(classObj));
-      auto capsule = c10::IValue(std::move(genericPtr));
-      auto object = std::move(self.ivalue).toObject();
-      object->setSlot(0, std::move(capsule));
+      auto object = self.ivalue.toObject();
+      object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
     };
 
     defineMethod("__init__", std::move(func));
@@ -113,12 +111,8 @@
                                 SetStateArg&& arg) {
       c10::intrusive_ptr<CurClass> classObj =
           at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
-      auto genericPtr =
-          c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(
-              classObj);
-      auto capsule = c10::IValue(genericPtr);
       auto object = self.ivalue.toObject();
-      object->setSlot(0, capsule);
+      object->setSlot(0, c10::IValue::make_capsule(classObj));
     };
     defineMethod(
         "__setstate__",
@@ -192,16 +186,8 @@
         "Trying to instantiate a class that isn't a registered custom class.",
         "");
   }
-  auto classType = c10::getCustomClassType<c10::intrusive_ptr<CurClass>>();
-  auto ivalue_obj = c10::ivalue::Object::create(
-      c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
-  auto userClassInstance =
-      c10::make_intrusive<CurClass>(std::forward<CtorArgs...>(args)...);
-  ivalue_obj->setAttr(
-      "capsule",
-      c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
-          userClassInstance));
-  return ivalue_obj;
+  auto userClassInstance = c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
+  return c10::IValue(std::move(userClassInstance));
 }
 
 // jit namespace for backward-compatibility