[PyTorch] Fix many Tuple::elements() callsites (#64065)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64065

It is only safe to mutate Tuple elements if you are the sole owner
of the tuple. The most efficient way to take ownership of the
element vector, then, is
`std::move(*std::move(tupleIValue).toTuple()).elements()` (the
innermost move allows `IValue::toTuple()` to avoid a refcount bump and
the outermost move allows the element vector to be moved out of the
tuple), but many callsites write simply
`tupleIValue.toTuple().elements()`, which incurs many extra refcount
bumps.

ghstack-source-id: 139468088

Test Plan: CI

Reviewed By: ezyang

Differential Revision: D30592621

fbshipit-source-id: e8312de866de09b9ea2a62e5128cbf403ee16f09
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index a308536..afd9e4f 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -278,13 +278,22 @@
     return elements_;
   }
 
-  std::vector<IValue>& elements() & {
-    return elements_;
-  }
-
-  std::vector<IValue>&& elements() && {
+  std::vector<IValue> elements() && {
     return std::move(elements_);
   }
+
+  void setElements(std::vector<IValue>&& elements) {
+    elements_ = std::move(elements);
+  }
+
+  void unsafeSetElement(size_t idx, const IValue& element) {
+    elements_[idx] = element;
+  }
+
+  void unsafeSetElement(size_t idx, IValue&& element) {
+    elements_[idx] = std::move(element);
+  }
+
   std::shared_ptr<TupleType> type() const;
 
   static size_t hash(const Tuple& t) {
@@ -1306,7 +1315,7 @@
             guts::negation<std::is_constructible<IValue, Args>>...>::value,
         std::nullptr_t> = nullptr>
 std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) {
-  auto vals = ivalue.toTuple()->elements();
+  const auto& vals = ivalue.toTuple()->elements();
   TORCH_CHECK(vals.size() == sizeof...(Args));
   return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
 }
diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
index 549da6e2..d2ad042 100644
--- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h
+++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
@@ -80,7 +80,7 @@
   // determine the version based on IValue contents
   int version = -1;
   if (v.isTuple()) {
-    auto elements = v.toTuple()->elements();
+    const auto& elements = v.toTuple()->elements();
     if (elements.size() > 0) {
       auto firstElement = elements[0];
       if (firstElement.isTensor()) {
@@ -105,7 +105,7 @@
   if (version == 1) {
     // version 1 - convert to version 3 manually
 
-    auto elements = v.toTuple()->elements();
+    const auto& elements = v.toTuple()->elements();
 
     at::Tensor weight = elements[0].toTensor();
     c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
@@ -149,7 +149,7 @@
     return std::tie(version, config_vals, tensors);
   } else if (version == 2) {
     // version 2
-    auto elements = v.toTuple()->elements();
+    const auto& elements = v.toTuple()->elements();
     std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
     std::vector<c10::optional<at::Tensor>> optional;
 
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
index 13f6614..f8c8f1c 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
@@ -55,28 +55,25 @@
       payload_size,
       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
       message.tensors());
-  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+  const auto& tupleElements = tuple.toTuple()->elements();
 
   // Build PropagateGradientsReq.
   TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3);
 
   // Retrieve retainGraph.
   bool retainGraph = tupleElements.back().toBool();
-  tupleElements.pop_back();
 
   // Build AutogradMetadata.
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   int64_t autogradContextId, autogradMessageId;
-  autogradMessageId = tupleElements.back().toInt();
-  tupleElements.pop_back();
-  autogradContextId = tupleElements.back().toInt();
-  tupleElements.pop_back();
+  autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
+  autogradContextId = tupleElements[tupleElements.size() - 3].toInt();
 
   AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);
 
   // Retrieve the gradient tensors.
-  std::vector<Variable> grads(tupleElements.size());
-  for(const auto i : c10::irange(tupleElements.size())) {
+  std::vector<Variable> grads(tupleElements.size() - 3);
+  for(const auto i : c10::irange(tupleElements.size() - 3)) {
     grads[i] = tupleElements[i].toTensor();
   }
 
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
index aaa06ba..46ca618 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
@@ -46,7 +46,7 @@
       payload_size,
       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
       message.tensors());
-  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+  const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements();
 
   // Build RRefBackwardReq.
   TORCH_INTERNAL_ASSERT(tupleElements.size() == 3);
diff --git a/torch/csrc/distributed/rpc/rref_proto.cpp b/torch/csrc/distributed/rpc/rref_proto.cpp
index 49e3287..9e44630 100644
--- a/torch/csrc/distributed/rpc/rref_proto.cpp
+++ b/torch/csrc/distributed/rpc/rref_proto.cpp
@@ -25,7 +25,7 @@
       payload_size,
       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
       message.tensors());
-  return value.toTuple()->elements();
+  return std::move(*std::move(value).toTuple()).elements();
 }
 
 c10::intrusive_ptr<Message> fromIValues(
diff --git a/torch/csrc/distributed/rpc/types.cpp b/torch/csrc/distributed/rpc/types.cpp
index 3430fef..8c50e0f 100644
--- a/torch/csrc/distributed/rpc/types.cpp
+++ b/torch/csrc/distributed/rpc/types.cpp
@@ -60,7 +60,7 @@
   TORCH_INTERNAL_ASSERT(
       ivalue.isTuple(),
       "GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
-  auto ivalues = ivalue.toTuple()->elements();
+  const auto& ivalues = ivalue.toTuple()->elements();
   TORCH_CHECK(
       ivalues.size() == 2,
       "Constructing GloballyUniqueId from ivalue "
diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp
index 22b8959..d0e5edd 100644
--- a/torch/csrc/jit/frontend/tracer.cpp
+++ b/torch/csrc/jit/frontend/tracer.cpp
@@ -352,7 +352,8 @@
     AT_ASSERT(
         elem_values.size() == num_elems && elem_types.size() == num_elems);
     for (const auto i : c10::irange(num_elems)) {
-      elems[i] = addInput(state, elems.at(i), elem_types[i], elem_values[i]);
+      tuple->unsafeSetElement(
+          i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
     }
     return tuple;
   } else if (auto dict_type = type->cast<DictType>()) {
@@ -546,7 +547,7 @@
       setValue(outputs.get(i), unpack_node->outputs()[i]);
     }
   } else if (v.isTuple()) {
-    auto outputs = v.toTuple()->elements();
+    const auto& outputs = v.toTuple()->elements();
     Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
     for (const auto i : c10::irange(outputs.size())) {
       setValue(outputs[i], unpack_node->outputs()[i]);
diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp
index a0e6544..2cb2d94 100644
--- a/torch/csrc/jit/mobile/backport_manager.cpp
+++ b/torch/csrc/jit/mobile/backport_manager.cpp
@@ -189,7 +189,8 @@
     const int64_t to_version) {
   PyTorchStreamReader reader_bytecode(&input_model);
   std::vector<IValue> constants_values =
-      readArchive(kArchiveNameConstants, reader_bytecode).toTuple()->elements();
+      std::move(*readArchive(kArchiveNameConstants, reader_bytecode).toTuple())
+          .elements();
 
   std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
   std::unordered_set<std::string> excluded_files{
@@ -310,7 +311,8 @@
   PyTorchStreamReader reader(&input_model_stream);
   std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
   std::vector<IValue> constants_values =
-      readArchive(kArchiveNameConstants, reader).toTuple()->elements();
+      std::move(*readArchive(kArchiveNameConstants, reader).toTuple())
+          .elements();
 
   // 2) Copy everything to new output, except some specific files and dirs
   // (usually version, bytecode.pkl and bytecode folder are skipped)
@@ -474,7 +476,8 @@
       std::make_shared<IStreamAdapter>(&input_model_stream);
   auto reader = std::make_shared<PyTorchStreamReader>(rai);
   std::vector<IValue> constants_values =
-      readArchive(kArchiveNameConstants, *reader.get()).toTuple()->elements();
+      std::move(*readArchive(kArchiveNameConstants, *reader.get()).toTuple())
+          .elements();
 
   // If there are debug info files in the original model file, it should also
   // show up in the backported model
diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp
index ba54365..9697c1b 100644
--- a/torch/csrc/jit/mobile/import.cpp
+++ b/torch/csrc/jit/mobile/import.cpp
@@ -427,7 +427,8 @@
   bool has_debug_handles{false};
   if (reader_->hasRecord("mobile_debug_handles.pkl")) {
     debug_handles =
-        readArchive("mobile_debug_handles", mcu).toTuple()->elements();
+        std::move(*readArchive("mobile_debug_handles", mcu).toTuple())
+            .elements();
     has_debug_handles = true;
   }
   parseMethods(std::move(bvals), std::move(debug_handles), *mcu);
diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp
index 3ee165b..10046cc 100644
--- a/torch/csrc/jit/mobile/model_compatibility.cpp
+++ b/torch/csrc/jit/mobile/model_compatibility.cpp
@@ -52,7 +52,8 @@
 
 std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
   std::vector<IValue> bytecode_values;
-  bytecode_values = readArchive("bytecode", reader).toTuple()->elements();
+  bytecode_values =
+      std::move(*readArchive("bytecode", reader).toTuple()).elements();
   return bytecode_values;
 }
 
@@ -154,11 +155,11 @@
   // loop over all the functions in the bytecode
   for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
     // descend to the operators list
-    auto method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
+    const auto& method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
     auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1];
     auto operators = operators_tuple.toTuple()->elements()[1];
     for (auto& op_tuple : operators.toTuple()->elements()) {
-      auto op = op_tuple.toTuple()->elements();
+      const auto& op = op_tuple.toTuple()->elements();
 
       // grab name
       std::string op_name = op.at(0).toStringRef();
diff --git a/torch/csrc/jit/mobile/parse_bytecode.cpp b/torch/csrc/jit/mobile/parse_bytecode.cpp
index 1cb7bfe..2e8c7db 100644
--- a/torch/csrc/jit/mobile/parse_bytecode.cpp
+++ b/torch/csrc/jit/mobile/parse_bytecode.cpp
@@ -79,14 +79,15 @@
         debug_info_function_name == function_name,
         "The function names in the bytecode table and the debug info table do not match.");
     IValue& debug_handles_table = debug_handles_m_tuple[1];
-    debug_handles_list =
-        (expect_field(
-             std::move(debug_handles_table).toTuple()->elements(),
-             "function_debug_handles",
-             BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
-             .toTuple()
-             ->elements())[0]
-            .toIntList();
+    auto debugHandlesElements =
+        std::move(*std::move(debug_handles_table).toTuple()).elements();
+    debug_handles_list = (expect_field(
+                              debugHandlesElements,
+                              "function_debug_handles",
+                              BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
+                              .toTuple()
+                              ->elements())[0]
+                             .toIntList();
     TORCH_CHECK(
         debug_handles_list.size() == ins_list.size(),
         "The numbers of instructions and debug handles strings do not match.");
diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp
index a272332..d7f89e3 100644
--- a/torch/csrc/jit/passes/freeze_module.cpp
+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -308,12 +308,11 @@
       }
     } else if (attr.isTuple()) {
       auto tuple = std::move(attr).toTuple();
-      std::vector<IValue>& elems = tuple->elements();
-      for (auto& elem : elems) {
-        elem = overrideGradient(elem);
+      const auto& elems = tuple->elements();
+      for (const auto idx : c10::irange(elems.size())) {
+        tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
       }
       attr = std::move(tuple);
-
     } else if (attr.isList()) {
       c10::List<IValue> elems = std::move(attr).toList();
       for (const auto i : c10::irange(elems.size())) {
diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp
index a2c8a6c..91969e6 100644
--- a/torch/csrc/jit/passes/lower_tuples.cpp
+++ b/torch/csrc/jit/passes/lower_tuples.cpp
@@ -157,7 +157,8 @@
   }
 
   auto g = n->owningGraph();
-  auto tuple_elements = toIValue(n->output()).value().toTuple()->elements();
+  auto tuple = toIValue(n->output()).value().toTuple();
+  const auto& tuple_elements = tuple->elements();
   WithInsertPoint insert(n);
   std::vector<Value*> elements;
   for (const auto& elem : tuple_elements) {
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 05bd5b7..8ab012c 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1295,7 +1295,7 @@
       tup_type->elements().at(0), tup_type->elements().at(1));
   dict.reserve(list.size());
   for (IValue input : list) {
-    const auto tup = input.toTuple()->elements();
+    const auto& tup = input.toTuple()->elements();
     dict.insert_or_assign(tup[0], tup[1]);
   }
   push(stack, dict);
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index 3fe5514..e9df698 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -364,7 +364,7 @@
   auto iv = pop(stack);
   if (iv.isTuple()) {
     auto tuple = iv.toTuple();
-    auto elems = tuple->elements();
+    const auto& elems = tuple->elements();
     std::vector<IValue> output_elems;
     output_elems.reserve(elems.size());
     for (const auto& elem : elems) {
diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
index 93da38a..2500a84 100644
--- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
+++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
@@ -137,7 +137,7 @@
     return it->second;
   }
 
-  auto tup_elems = tup->elements();
+  const auto& tup_elems = tup->elements();
   TORCH_INTERNAL_ASSERT(tup_elems.size() == 4);
   // {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
   // function name}
@@ -186,7 +186,7 @@
   if (it != cached_module_instance_info_.end()) {
     return it->second;
   }
-  auto tup_elems = iv.toTuple()->elements();
+  const auto& tup_elems = iv.toTuple()->elements();
   TORCH_CHECK(tup_elems.size() == 2);
   std::string type_name = tup_elems[0].toString()->string();
   std::string instance_name = tup_elems[1].toString()->string();
@@ -221,9 +221,10 @@
         const std::shared_ptr<CompilationUnit>& cu) {
   auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
   ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
-  auto& ivalues = ival.toTuple()->elements();
+  std::vector<IValue> ivalues =
+      std::move(*std::move(ival).toTuple()).elements();
   for (auto& val : ivalues) {
-    const auto tup_elems = val.toTuple()->elements();
+    const auto& tup_elems = val.toTuple()->elements();
     TORCH_CHECK(
         tup_elems.size() == 4,
         "Pickled map must have four elements: "
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index fbd2546..859d1e4 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -843,7 +843,7 @@
         row->elements().at(0).toStringRef());
     const auto& ops_list = row->elements().at(1).toTuple()->elements();
     for (const auto& op : ops_list) {
-      auto op_item = op.toTuple()->elements();
+      const auto& op_item = op.toTuple()->elements();
       TORCH_CHECK(
           op_item.size() >= 2,
           "There should be either two parts (name and overload name), ",
diff --git a/torch/csrc/jit/serialization/import_legacy.cpp b/torch/csrc/jit/serialization/import_legacy.cpp
index fbb393a..972f50c 100644
--- a/torch/csrc/jit/serialization/import_legacy.cpp
+++ b/torch/csrc/jit/serialization/import_legacy.cpp
@@ -130,7 +130,8 @@
         LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
   } else if (proto_version >= 3) {
     LEGACY_pickled_ivalues_ =
-        LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
+        std::move(*LEGACY_loadPickleArchive("attributes.pkl").toTuple())
+            .elements();
   }
   LEGACY_moduleStack_.emplace_back("__torch__");
   const auto& module_def = model_def.main_module();
diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp
index 1e69cfa..79c9501 100644
--- a/torch/csrc/jit/serialization/source_range_serialization.cpp
+++ b/torch/csrc/jit/serialization/source_range_serialization.cpp
@@ -112,13 +112,13 @@
     return;
   }
 
-  auto ivalues = jit::unpickle(reinterpret_cast<const char*>(data.get()), size)
-                     .toTuple()
-                     ->elements();
+  auto ivaluesTuple =
+      jit::unpickle(reinterpret_cast<const char*>(data.get()), size).toTuple();
+  const auto& ivalues = ivaluesTuple->elements();
 
   unpickled_records = std::make_shared<SourceRangeRecords>();
   for (auto& val : ivalues) {
-    auto tup_elems = val.toTuple()->elements();
+    const auto& tup_elems = val.toTuple()->elements();
     int64_t offset = tup_elems[kByteOffsetIndex].toInt();
     auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
     unpickled_records->emplace_back(offset, std::move(source_range));
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index e0e556e..25aa1c2 100644
--- a/torch/csrc/jit/serialization/unpickler.cpp
+++ b/torch/csrc/jit/serialization/unpickler.cpp
@@ -163,7 +163,7 @@
   }
 }
 
-void restoreContainerTypeTags(IValue& ivalue, const TypePtr& type) {
+void restoreContainerTypeTags(const IValue& ivalue, const TypePtr& type) {
   if (auto dict_type = type->cast<DictType>()) {
     auto dict = ivalue.toGenericDict();
     dict.unsafeSetKeyType(dict_type->getKeyType());
@@ -327,14 +327,14 @@
     case PickleOpCode::TUPLE: {
       size_t start = marks_.back();
       marks_.pop_back();
-      auto tuple = c10::ivalue::Tuple::create({});
-      tuple->elements().reserve(stack_.size() - start);
+      std::vector<IValue> elements;
+      elements.reserve(stack_.size() - start);
       auto start_it = stack_.begin() + start;
       for (auto it = start_it; it != stack_.end(); ++it) {
-        tuple->elements().emplace_back(std::move(*it));
+        elements.emplace_back(std::move(*it));
       }
       stack_.erase(start_it, stack_.end());
-      stack_.emplace_back(std::move(tuple));
+      stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
     } break;
     case PickleOpCode::TUPLE1: {
       stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_, 1)));
@@ -409,7 +409,8 @@
       globals_.at(idx)();
     } break;
     case PickleOpCode::BINPERSID: {
-      auto args = pop(stack_).toTuple()->elements();
+      auto tuple = pop(stack_).toTuple();
+      const auto& args = tuple->elements();
       AT_ASSERT(
           args.at(0).toStringRef() == "storage",
           "unknown PERSID key ",
@@ -512,7 +513,8 @@
       });
     } else if (class_name == "restore_type_tag") {
       globals_.emplace_back([this] {
-        auto data = stack_.back().toTuple()->elements();
+        auto tuple = stack_.back().toTuple();
+        const auto& data = tuple->elements();
         auto type_str = data.at(1).toStringRef();
         stack_.pop_back();
         TypePtr type = nullptr;
@@ -568,7 +570,8 @@
     rebuildSparseTensor();
   } else if (module_name == "builtins" && class_name == "complex") {
     globals_.emplace_back([this] {
-      auto elems = pop(stack_).toTuple()->elements();
+      auto tuple = pop(stack_).toTuple();
+      const auto& elems = tuple->elements();
       AT_ASSERT(elems.size() == 2);
       auto complex =
           c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
@@ -750,7 +753,8 @@
   globals_.emplace_back([this] {
     // It is the same as how rref is unpickled in python,
     // see PyRRef::unpickle
-    auto args = stack_.back().toTuple()->elements();
+    auto tuple = std::move(stack_.back()).toTuple();
+    const auto& args = tuple->elements();
     stack_.pop_back();
     TORCH_INTERNAL_ASSERT(
         args.size() == distributed::rpc::RFD_TUPLE_SIZE,