Allow Future::then to return pre-extracted DataPtrs (#58424)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58424
In CUDA mode, Future must inspect its value and extract DataPtrs. However some types are not supported, for example the C++/JIT custom classes, which include Message, which is widely used in RPC. Hence for these scenarios we allow the user to perform the custom DataPtr extraction on their own, and pass the pre-extracted DataPtrs.
Note that `markCompleted` already allowed users to pass in pre-extracted DataPtrs, hence this PR simply extends this possibility to the `then` method too.
ghstack-source-id: 129567044
Test Plan: Used in next PR.
Reviewed By: mrshenli
Differential Revision: D28474880
fbshipit-source-id: 91a0dde5e29d1afac55650c5dfb306873188d785
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index 10d0a7a..48b83c5 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -518,20 +518,36 @@
*/
template <typename T>
c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
+ using IValueWithDataPtrs = std::
+ tuple<IValue, std::vector<std::reference_wrapper<const at::DataPtr>>>;
#if __cpp_lib_is_invocable >= 201703
static_assert(
- std::is_invocable_r<IValue, T, Future&>::value,
- "The callback must have signature IValue(Future&)");
+ guts::disjunction<
+ std::is_invocable_r<IValue, T, Future&>,
+ std::is_invocable_r<IValueWithDataPtrs, T, Future&>>::value,
+ "The callback must have signature IValue(Future&) or "
+ "std::tuple<IValue, std::vector<std::reference_wrapper<const DataPtr>>>(Future&)");
#endif
auto childFut = createInstance(std::move(type));
- addCallback(
- [childFut, cb = std::move(callback)](Future& parentFut) mutable {
- try {
- childFut->markCompleted(cb(parentFut));
- } catch (std::exception&) {
- childFut->setError(std::current_exception());
- }
- });
+ addCallback([childFut,
+ cb = std::move(callback)](Future& parentFut) mutable {
+ try {
+ guts::if_constexpr<std::is_convertible<
+ typename std::result_of<T && (Future&)>::type,
+ IValueWithDataPtrs>::value>(
+ [&](auto identity) {
+ IValue value;
+ std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs;
+ std::tie(value, dataPtrs) = identity(cb)(parentFut);
+ childFut->markCompleted(std::move(value), std::move(dataPtrs));
+ },
+ [&](auto identity) {
+ childFut->markCompleted(identity(cb)(parentFut));
+ });
+ } catch (std::exception&) {
+ childFut->setError(std::current_exception());
+ }
+ });
return childFut;
}