[torchbind] Improve IValue custom class API and remove most Capsule stuff (#34848)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34848
Test Plan: Imported from OSS
Differential Revision: D20480514
Pulled By: jamesr66a
fbshipit-source-id: 1c595faf34e00aab0a6202a8902426bd310551c3
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 7816df1..20d6ce3 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -8,6 +8,19 @@
namespace c10 {
namespace ivalue {
+// This is in ivalue.cpp because we need to access Type::python_str, which
+// is declared in jit_type.h
+void checkCustomClassType(TypePtr expected_type, TypePtr actual_type) {
+ // NB: doing pointer comparison here
+ // If in the future there ever arises a need to call operator== on custom class
+ // Type's, this needs to be changed!
+ TORCH_CHECK(actual_type == expected_type,
+ "Tried to convert an IValue of type ",
+ actual_type->python_str(),
+ " to custom class type ",
+ expected_type->python_str());
+}
+
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
std::string str_) {
return c10::make_intrusive<ConstantString>(std::move(str_));
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 19358809..32ca9d7 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -207,14 +207,25 @@
/// @private [doxygen private]
c10::intrusive_ptr<caffe2::Blob> toBlob() const &;
- // Capsule
- IValue(intrusive_ptr<torch::CustomClassHolder> blob);
+ // Capsule. Capsule is an internal implementation detail
+ // of custom C++ classes. No new callsites of these APIs should
+ // be introduced.
+ static inline IValue make_capsule(intrusive_ptr<torch::CustomClassHolder> blob);
bool isCapsule() const {
return Tag::Capsule == tag;
}
c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const &;
+ // Custom C++ classes
+ template <typename T, std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int> = 0>
+ IValue(intrusive_ptr<T> custom_class);
+ bool isCustomClass() const;
+ template <typename T>
+ c10::intrusive_ptr<T> toCustomClass() &&;
+ template <typename T>
+ c10::intrusive_ptr<T> toCustomClass() const &;
+
// Tuple
IValue(c10::intrusive_ptr<ivalue::Tuple> v);
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 680e3e2..3937150 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -17,6 +17,7 @@
struct Function;
struct CompilationUnit;
} // namespace jit
+TORCH_API bool isCustomClass(const c10::IValue& v);
} // namespace torch
namespace c10 {
struct IValue;
@@ -129,6 +130,8 @@
namespace ivalue {
+void CAFFE2_API checkCustomClassType(TypePtr expected_type, TypePtr actual_type);
+
template <typename T>
using Shared = c10::intrusive_ptr<T>;
@@ -522,13 +525,41 @@
}
template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() && {
+ static_assert(std::is_base_of<torch::CustomClassHolder, T>::value == true,
+ "toCustomClass requires that template parameter T must inherit "
+ "from torch::CustomClassHolder");
+ auto obj = toObject();
+ TORCH_CHECK(obj->slots().size() == 1,
+ "Tried to cast IValue to custom class but it did "
+ "not contain a custom class!");
+ auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+ ivalue::checkCustomClassType(expected_type, type());
+ auto userObj = c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+ return userObj;
+}
+
+template <typename T>
+c10::intrusive_ptr<T> IValue::toCustomClass() const & {
+ static_assert(std::is_base_of<torch::CustomClassHolder, T>::value == true,
+ "toCustomClass requires that template parameter T must inherit "
+ "from torch::CustomClassHolder");
+ auto obj = toObject();
+ TORCH_CHECK(obj->slots().size() == 1,
+ "Tried to cast IValue to custom class but it did "
+ "not contain a custom class!");
+ auto expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+ ivalue::checkCustomClassType(expected_type, type());
+ auto userObj = c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
+ return userObj;
+}
+
+template <typename T>
T generic_to(
IValue ivalue,
_fake_type<T>) {
using ElemType = typename std::remove_pointer<T>::type::element_type;
- auto obj = std::move(ivalue).toObject();
- auto capsule = obj->getSlot(0);
- return c10::static_intrusive_pointer_cast<ElemType>(capsule.toCapsule());
+ return std::move(ivalue).toCustomClass<ElemType>();
}
template <typename T>
@@ -776,10 +807,30 @@
: tag(Tag::PyObject), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
-inline IValue::IValue(c10::intrusive_ptr<torch::CustomClassHolder> v)
-: tag(Tag::Capsule), is_intrusive_ptr(true) {
- payload.as_intrusive_ptr = v.release();
+inline IValue IValue::make_capsule(intrusive_ptr<torch::CustomClassHolder> blob) {
+ IValue iv;
+ iv.tag = Tag::Capsule;
+ iv.is_intrusive_ptr = true;
+ iv.payload.as_intrusive_ptr = blob.release();
+ return iv;
}
+
+template <typename T, std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
+IValue::IValue(c10::intrusive_ptr<T> custom_class) {
+ if (!c10::isCustomClassRegistered<c10::intrusive_ptr<T>>()) {
+ throw c10::Error(
+ "Trying to instantiate a class that isn't a registered custom class.",
+ "");
+ }
+ auto classType = c10::getCustomClassType<c10::intrusive_ptr<T>>();
+ auto ivalue_obj = c10::ivalue::Object::create(
+ c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
+ ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
+ payload.as_intrusive_ptr = ivalue_obj.release();
+ tag = Tag::Object;
+ is_intrusive_ptr = true;
+}
+
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
: tag(Tag::Future), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
@@ -804,6 +855,10 @@
return this->to<T>();
}
+inline bool IValue::isCustomClass() const {
+ return torch::isCustomClass(*this);
+}
+
inline bool IValue::isSameIdentity(const IValue& rhs) const {
// We choose to not use memcmp for payload check due to potential random padding characters on union type
@@ -865,17 +920,7 @@
if (!isCustomClassRegistered<inputType>()) {
throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
}
- auto res = getCustomClassType<inputType>();
- auto retObject = ivalue::Object::create(
- StrongTypePtr(
- std::shared_ptr<torch::jit::CompilationUnit>(),
- std::move(res)),
- 1);
- auto objPtr = c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(std::move(x));
-
- retObject->setSlot(0, IValue(std::move(objPtr)));
- auto resIVal = IValue(std::move(retObject));
- return resIVal;
+ return IValue(x);
}
template <typename T>
IValue from_(T x, std::false_type) {
diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp
index 47e2136..497a8e1 100644
--- a/test/cpp/jit/test_custom_class.cpp
+++ b/test/cpp/jit/test_custom_class.cpp
@@ -156,13 +156,20 @@
AT_ASSERT(tup->elements().size() == 2);
auto str = tup->elements()[0].toStringRef();
auto other_obj =
- tup->elements()[1].to<c10::intrusive_ptr<MyStackClass<std::string>>>();
+ tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
AT_ASSERT(str == expected);
- auto ref_obj = obj.to<c10::intrusive_ptr<MyStackClass<std::string>>>();
+ auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
AT_ASSERT(other_obj.get() == ref_obj.get());
};
test_with_obj(custom_class_obj, "bar");
+
+ // test IValue() API
+ auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
+ std::vector<std::string>{"baz", "boo"});
+ auto new_stack_ivalue = c10::IValue(my_new_stack);
+
+ test_with_obj(new_stack_ivalue, "boo");
}
} // namespace jit
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index 3f9ab85..a8d8324 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -586,7 +586,8 @@
return c10::ivalue::ConcretePyObjectHolder::create(obj.cast<py::object>());
case TypeKind::CapsuleType: {
- return py::cast<c10::intrusive_ptr<CustomClassHolder>>(obj);
+ return IValue::make_capsule(
+ py::cast<c10::intrusive_ptr<CustomClassHolder>>(obj));
} break;
case TypeKind::AnyType:
return toTypeInferredIValue(obj);
diff --git a/torch/custom_class.h b/torch/custom_class.h
index 70d7cb4..35ebf71 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -74,10 +74,8 @@
// torch::init<...>()
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
auto classObj = c10::make_intrusive<CurClass>(args...);
- auto genericPtr = c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(std::move(classObj));
- auto capsule = c10::IValue(std::move(genericPtr));
- auto object = std::move(self.ivalue).toObject();
- object->setSlot(0, std::move(capsule));
+ auto object = self.ivalue.toObject();
+ object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
};
defineMethod("__init__", std::move(func));
@@ -113,12 +111,8 @@
SetStateArg&& arg) {
c10::intrusive_ptr<CurClass> classObj =
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
- auto genericPtr =
- c10::static_intrusive_pointer_cast<torch::CustomClassHolder>(
- classObj);
- auto capsule = c10::IValue(genericPtr);
auto object = self.ivalue.toObject();
- object->setSlot(0, capsule);
+ object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod(
"__setstate__",
@@ -192,16 +186,8 @@
"Trying to instantiate a class that isn't a registered custom class.",
"");
}
- auto classType = c10::getCustomClassType<c10::intrusive_ptr<CurClass>>();
- auto ivalue_obj = c10::ivalue::Object::create(
- c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1);
- auto userClassInstance =
- c10::make_intrusive<CurClass>(std::forward<CtorArgs...>(args)...);
- ivalue_obj->setAttr(
- "capsule",
- c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
- userClassInstance));
- return ivalue_obj;
+ auto userClassInstance = c10::make_intrusive<CurClass>(std::forward<CtorArgs>(args)...);
+ return c10::IValue(std::move(userClassInstance));
}
// jit namespace for backward-compatibility