[PyTorch] Fix many Tuple::elements() callsites (#64065)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64065
It is only safe to mutate Tuple elements if you are the sole owner
of the tuple. In that case, the most efficient way to take the
elements is `std::move(*std::move(tupleIValue).toTuple()).elements()`
(the innermost move allows `IValue::toTuple()` to avoid a refcount
bump, and the outermost move allows the element vector to be moved
out of the tuple). Many callsites instead write
`tupleIValue.toTuple()->elements()`, which copies the element vector
and incurs an extra refcount bump for every element.
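
For illustration only (not part of this patch), a minimal sketch
contrasting the access patterns; `accessPatterns` and `tupleIValue`
are placeholder names, not identifiers from the codebase:

```
#include <ATen/core/ivalue.h>

#include <utility>
#include <vector>

using c10::IValue;

void accessPatterns(IValue tupleIValue) {
  // Read-only access: bind a const reference; no copy, no refcount bumps.
  const auto& elems = tupleIValue.toTuple()->elements();
  (void)elems;

  // Sole-owner access: the inner move lets IValue::toTuple() steal the
  // intrusive_ptr (no refcount bump) and the outer move selects the
  // rvalue elements() overload, moving the vector out of the Tuple.
  std::vector<IValue> owned =
      std::move(*std::move(tupleIValue).toTuple()).elements();
  (void)owned;

  // The pattern this patch removes: copies the whole element vector and
  // bumps the refcount of every element.
  // auto copied = tupleIValue.toTuple()->elements();
}
```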
ghstack-source-id: 139468088
Test Plan: CI
Reviewed By: ezyang
Differential Revision: D30592621
fbshipit-source-id: e8312de866de09b9ea2a62e5128cbf403ee16f09
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index a308536..afd9e4f 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -278,13 +278,22 @@
return elements_;
}
- std::vector<IValue>& elements() & {
- return elements_;
- }
-
- std::vector<IValue>&& elements() && {
+ std::vector<IValue> elements() && {
return std::move(elements_);
}
+
+ void setElements(std::vector<IValue>&& elements) {
+ elements_ = std::move(elements);
+ }
+
+ void unsafeSetElement(size_t idx, const IValue& element) {
+ elements_[idx] = element;
+ }
+
+ void unsafeSetElement(size_t idx, IValue&& element) {
+ elements_[idx] = std::move(element);
+ }
+
std::shared_ptr<TupleType> type() const;
static size_t hash(const Tuple& t) {
@@ -1306,7 +1315,7 @@
guts::negation<std::is_constructible<IValue, Args>>...>::value,
std::nullptr_t> = nullptr>
std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) {
- auto vals = ivalue.toTuple()->elements();
+ const auto& vals = ivalue.toTuple()->elements();
TORCH_CHECK(vals.size() == sizeof...(Args));
return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
}
diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
index 549da6e2..d2ad042 100644
--- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h
+++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h
@@ -80,7 +80,7 @@
// determine the version based on IValue contents
int version = -1;
if (v.isTuple()) {
- auto elements = v.toTuple()->elements();
+ const auto& elements = v.toTuple()->elements();
if (elements.size() > 0) {
auto firstElement = elements[0];
if (firstElement.isTensor()) {
@@ -105,7 +105,7 @@
if (version == 1) {
// version 1 - convert to version 3 manually
- auto elements = v.toTuple()->elements();
+ const auto& elements = v.toTuple()->elements();
at::Tensor weight = elements[0].toTensor();
c10::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
@@ -149,7 +149,7 @@
return std::tie(version, config_vals, tensors);
} else if (version == 2) {
// version 2
- auto elements = v.toTuple()->elements();
+ const auto& elements = v.toTuple()->elements();
std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
std::vector<c10::optional<at::Tensor>> optional;
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
index 13f6614..f8c8f1c 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
@@ -55,28 +55,25 @@
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
- std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+ const auto& tupleElements = tuple.toTuple()->elements();
// Build PropagateGradientsReq.
TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3);
// Retrieve retainGraph.
bool retainGraph = tupleElements.back().toBool();
- tupleElements.pop_back();
// Build AutogradMetadata.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t autogradContextId, autogradMessageId;
- autogradMessageId = tupleElements.back().toInt();
- tupleElements.pop_back();
- autogradContextId = tupleElements.back().toInt();
- tupleElements.pop_back();
+ autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
+ autogradContextId = tupleElements[tupleElements.size() - 3].toInt();
AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);
// Retrieve the gradient tensors.
- std::vector<Variable> grads(tupleElements.size());
- for(const auto i : c10::irange(tupleElements.size())) {
+ std::vector<Variable> grads(tupleElements.size() - 3);
+ for(const auto i : c10::irange(tupleElements.size() - 3)) {
grads[i] = tupleElements[i].toTensor();
}
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
index aaa06ba..46ca618 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.cpp
@@ -46,7 +46,7 @@
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
- std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+ const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements();
// Build RRefBackwardReq.
TORCH_INTERNAL_ASSERT(tupleElements.size() == 3);
diff --git a/torch/csrc/distributed/rpc/rref_proto.cpp b/torch/csrc/distributed/rpc/rref_proto.cpp
index 49e3287..9e44630 100644
--- a/torch/csrc/distributed/rpc/rref_proto.cpp
+++ b/torch/csrc/distributed/rpc/rref_proto.cpp
@@ -25,7 +25,7 @@
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
- return value.toTuple()->elements();
+ return std::move(*std::move(value).toTuple()).elements();
}
c10::intrusive_ptr<Message> fromIValues(
diff --git a/torch/csrc/distributed/rpc/types.cpp b/torch/csrc/distributed/rpc/types.cpp
index 3430fef..8c50e0f 100644
--- a/torch/csrc/distributed/rpc/types.cpp
+++ b/torch/csrc/distributed/rpc/types.cpp
@@ -60,7 +60,7 @@
TORCH_INTERNAL_ASSERT(
ivalue.isTuple(),
"GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
- auto ivalues = ivalue.toTuple()->elements();
+ const auto& ivalues = ivalue.toTuple()->elements();
TORCH_CHECK(
ivalues.size() == 2,
"Constructing GloballyUniqueId from ivalue "
diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp
index 22b8959..d0e5edd 100644
--- a/torch/csrc/jit/frontend/tracer.cpp
+++ b/torch/csrc/jit/frontend/tracer.cpp
@@ -352,7 +352,8 @@
AT_ASSERT(
elem_values.size() == num_elems && elem_types.size() == num_elems);
for (const auto i : c10::irange(num_elems)) {
- elems[i] = addInput(state, elems.at(i), elem_types[i], elem_values[i]);
+ tuple->unsafeSetElement(
+ i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
}
return tuple;
} else if (auto dict_type = type->cast<DictType>()) {
@@ -546,7 +547,7 @@
setValue(outputs.get(i), unpack_node->outputs()[i]);
}
} else if (v.isTuple()) {
- auto outputs = v.toTuple()->elements();
+ const auto& outputs = v.toTuple()->elements();
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
for (const auto i : c10::irange(outputs.size())) {
setValue(outputs[i], unpack_node->outputs()[i]);
diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp
index a0e6544..2cb2d94 100644
--- a/torch/csrc/jit/mobile/backport_manager.cpp
+++ b/torch/csrc/jit/mobile/backport_manager.cpp
@@ -189,7 +189,8 @@
const int64_t to_version) {
PyTorchStreamReader reader_bytecode(&input_model);
std::vector<IValue> constants_values =
- readArchive(kArchiveNameConstants, reader_bytecode).toTuple()->elements();
+ std::move(*readArchive(kArchiveNameConstants, reader_bytecode).toTuple())
+ .elements();
std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
std::unordered_set<std::string> excluded_files{
@@ -310,7 +311,8 @@
PyTorchStreamReader reader(&input_model_stream);
std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
std::vector<IValue> constants_values =
- readArchive(kArchiveNameConstants, reader).toTuple()->elements();
+ std::move(*readArchive(kArchiveNameConstants, reader).toTuple())
+ .elements();
// 2) Copy everything to new output, except some specific files and dirs
// (usually version, bytecode.pkl and bytecode folder are skipped)
@@ -474,7 +476,8 @@
std::make_shared<IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
std::vector<IValue> constants_values =
- readArchive(kArchiveNameConstants, *reader.get()).toTuple()->elements();
+ std::move(*readArchive(kArchiveNameConstants, *reader.get()).toTuple())
+ .elements();
// If there are debug info files in the original model file, it should also
// show up in the backported model
diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp
index ba54365..9697c1b 100644
--- a/torch/csrc/jit/mobile/import.cpp
+++ b/torch/csrc/jit/mobile/import.cpp
@@ -427,7 +427,8 @@
bool has_debug_handles{false};
if (reader_->hasRecord("mobile_debug_handles.pkl")) {
debug_handles =
- readArchive("mobile_debug_handles", mcu).toTuple()->elements();
+ std::move(*readArchive("mobile_debug_handles", mcu).toTuple())
+ .elements();
has_debug_handles = true;
}
parseMethods(std::move(bvals), std::move(debug_handles), *mcu);
diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp
index 3ee165b..10046cc 100644
--- a/torch/csrc/jit/mobile/model_compatibility.cpp
+++ b/torch/csrc/jit/mobile/model_compatibility.cpp
@@ -52,7 +52,8 @@
std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
std::vector<IValue> bytecode_values;
- bytecode_values = readArchive("bytecode", reader).toTuple()->elements();
+ bytecode_values =
+ std::move(*readArchive("bytecode", reader).toTuple()).elements();
return bytecode_values;
}
@@ -154,11 +155,11 @@
// loop over all the functions in the bytecode
for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
// descend to the operators list
- auto method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
+ const auto& method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1];
auto operators = operators_tuple.toTuple()->elements()[1];
for (auto& op_tuple : operators.toTuple()->elements()) {
- auto op = op_tuple.toTuple()->elements();
+ const auto& op = op_tuple.toTuple()->elements();
// grab name
std::string op_name = op.at(0).toStringRef();
diff --git a/torch/csrc/jit/mobile/parse_bytecode.cpp b/torch/csrc/jit/mobile/parse_bytecode.cpp
index 1cb7bfe..2e8c7db 100644
--- a/torch/csrc/jit/mobile/parse_bytecode.cpp
+++ b/torch/csrc/jit/mobile/parse_bytecode.cpp
@@ -79,14 +79,15 @@
debug_info_function_name == function_name,
"The function names in the bytecode table and the debug info table do not match.");
IValue& debug_handles_table = debug_handles_m_tuple[1];
- debug_handles_list =
- (expect_field(
- std::move(debug_handles_table).toTuple()->elements(),
- "function_debug_handles",
- BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
- .toTuple()
- ->elements())[0]
- .toIntList();
+ auto debugHandlesElements =
+ std::move(*std::move(debug_handles_table).toTuple()).elements();
+ debug_handles_list = (expect_field(
+ debugHandlesElements,
+ "function_debug_handles",
+ BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
+ .toTuple()
+ ->elements())[0]
+ .toIntList();
TORCH_CHECK(
debug_handles_list.size() == ins_list.size(),
"The numbers of instructions and debug handles strings do not match.");
diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp
index a272332..d7f89e3 100644
--- a/torch/csrc/jit/passes/freeze_module.cpp
+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -308,12 +308,11 @@
}
} else if (attr.isTuple()) {
auto tuple = std::move(attr).toTuple();
- std::vector<IValue>& elems = tuple->elements();
- for (auto& elem : elems) {
- elem = overrideGradient(elem);
+ const auto& elems = tuple->elements();
+ for (const auto idx : c10::irange(elems.size())) {
+ tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
}
attr = std::move(tuple);
-
} else if (attr.isList()) {
c10::List<IValue> elems = std::move(attr).toList();
for (const auto i : c10::irange(elems.size())) {
diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp
index a2c8a6c..91969e6 100644
--- a/torch/csrc/jit/passes/lower_tuples.cpp
+++ b/torch/csrc/jit/passes/lower_tuples.cpp
@@ -157,7 +157,8 @@
}
auto g = n->owningGraph();
- auto tuple_elements = toIValue(n->output()).value().toTuple()->elements();
+ auto tuple = toIValue(n->output()).value().toTuple();
+ const auto& tuple_elements = tuple->elements();
WithInsertPoint insert(n);
std::vector<Value*> elements;
for (const auto& elem : tuple_elements) {
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index 05bd5b7..8ab012c 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1295,7 +1295,7 @@
tup_type->elements().at(0), tup_type->elements().at(1));
dict.reserve(list.size());
for (IValue input : list) {
- const auto tup = input.toTuple()->elements();
+ const auto& tup = input.toTuple()->elements();
dict.insert_or_assign(tup[0], tup[1]);
}
push(stack, dict);
diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp
index 3fe5514..e9df698 100644
--- a/torch/csrc/jit/runtime/vararg_functions.cpp
+++ b/torch/csrc/jit/runtime/vararg_functions.cpp
@@ -364,7 +364,7 @@
auto iv = pop(stack);
if (iv.isTuple()) {
auto tuple = iv.toTuple();
- auto elems = tuple->elements();
+ const auto& elems = tuple->elements();
std::vector<IValue> output_elems;
output_elems.reserve(elems.size());
for (const auto& elem : elems) {
diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
index 93da38a..2500a84 100644
--- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
+++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp
@@ -137,7 +137,7 @@
return it->second;
}
- auto tup_elems = tup->elements();
+ const auto& tup_elems = tup->elements();
TORCH_INTERNAL_ASSERT(tup_elems.size() == 4);
// {IValue(module_instance_info), source_range_tag, IValue(InlinedCallStack),
// function name}
@@ -186,7 +186,7 @@
if (it != cached_module_instance_info_.end()) {
return it->second;
}
- auto tup_elems = iv.toTuple()->elements();
+ const auto& tup_elems = iv.toTuple()->elements();
TORCH_CHECK(tup_elems.size() == 2);
std::string type_name = tup_elems[0].toString()->string();
std::string instance_name = tup_elems[1].toString()->string();
@@ -221,9 +221,10 @@
const std::shared_ptr<CompilationUnit>& cu) {
auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
- auto& ivalues = ival.toTuple()->elements();
+ std::vector<IValue> ivalues =
+ std::move(*std::move(ival).toTuple()).elements();
for (auto& val : ivalues) {
- const auto tup_elems = val.toTuple()->elements();
+ const auto& tup_elems = val.toTuple()->elements();
TORCH_CHECK(
tup_elems.size() == 4,
"Pickled map must have four elements: "
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index fbd2546..859d1e4 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -843,7 +843,7 @@
row->elements().at(0).toStringRef());
const auto& ops_list = row->elements().at(1).toTuple()->elements();
for (const auto& op : ops_list) {
- auto op_item = op.toTuple()->elements();
+ const auto& op_item = op.toTuple()->elements();
TORCH_CHECK(
op_item.size() >= 2,
"There should be either two parts (name and overload name), ",
diff --git a/torch/csrc/jit/serialization/import_legacy.cpp b/torch/csrc/jit/serialization/import_legacy.cpp
index fbb393a..972f50c 100644
--- a/torch/csrc/jit/serialization/import_legacy.cpp
+++ b/torch/csrc/jit/serialization/import_legacy.cpp
@@ -130,7 +130,8 @@
LEGACY_pickled_ivalues_.end(), list.begin(), list.end());
} else if (proto_version >= 3) {
LEGACY_pickled_ivalues_ =
- LEGACY_loadPickleArchive("attributes.pkl").toTuple()->elements();
+ std::move(*LEGACY_loadPickleArchive("attributes.pkl").toTuple())
+ .elements();
}
LEGACY_moduleStack_.emplace_back("__torch__");
const auto& module_def = model_def.main_module();
diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp
index 1e69cfa..79c9501 100644
--- a/torch/csrc/jit/serialization/source_range_serialization.cpp
+++ b/torch/csrc/jit/serialization/source_range_serialization.cpp
@@ -112,13 +112,13 @@
return;
}
- auto ivalues = jit::unpickle(reinterpret_cast<const char*>(data.get()), size)
- .toTuple()
- ->elements();
+ auto ivaluesTuple =
+ jit::unpickle(reinterpret_cast<const char*>(data.get()), size).toTuple();
+ const auto& ivalues = ivaluesTuple->elements();
unpickled_records = std::make_shared<SourceRangeRecords>();
for (auto& val : ivalues) {
- auto tup_elems = val.toTuple()->elements();
+ const auto& tup_elems = val.toTuple()->elements();
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
unpickled_records->emplace_back(offset, std::move(source_range));
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index e0e556e..25aa1c2 100644
--- a/torch/csrc/jit/serialization/unpickler.cpp
+++ b/torch/csrc/jit/serialization/unpickler.cpp
@@ -163,7 +163,7 @@
}
}
-void restoreContainerTypeTags(IValue& ivalue, const TypePtr& type) {
+void restoreContainerTypeTags(const IValue& ivalue, const TypePtr& type) {
if (auto dict_type = type->cast<DictType>()) {
auto dict = ivalue.toGenericDict();
dict.unsafeSetKeyType(dict_type->getKeyType());
@@ -327,14 +327,14 @@
case PickleOpCode::TUPLE: {
size_t start = marks_.back();
marks_.pop_back();
- auto tuple = c10::ivalue::Tuple::create({});
- tuple->elements().reserve(stack_.size() - start);
+ std::vector<IValue> elements;
+ elements.reserve(stack_.size() - start);
auto start_it = stack_.begin() + start;
for (auto it = start_it; it != stack_.end(); ++it) {
- tuple->elements().emplace_back(std::move(*it));
+ elements.emplace_back(std::move(*it));
}
stack_.erase(start_it, stack_.end());
- stack_.emplace_back(std::move(tuple));
+ stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
} break;
case PickleOpCode::TUPLE1: {
stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_, 1)));
@@ -409,7 +409,8 @@
globals_.at(idx)();
} break;
case PickleOpCode::BINPERSID: {
- auto args = pop(stack_).toTuple()->elements();
+ auto tuple = pop(stack_).toTuple();
+ const auto& args = tuple->elements();
AT_ASSERT(
args.at(0).toStringRef() == "storage",
"unknown PERSID key ",
@@ -512,7 +513,8 @@
});
} else if (class_name == "restore_type_tag") {
globals_.emplace_back([this] {
- auto data = stack_.back().toTuple()->elements();
+ auto tuple = stack_.back().toTuple();
+ const auto& data = tuple->elements();
auto type_str = data.at(1).toStringRef();
stack_.pop_back();
TypePtr type = nullptr;
@@ -568,7 +570,8 @@
rebuildSparseTensor();
} else if (module_name == "builtins" && class_name == "complex") {
globals_.emplace_back([this] {
- auto elems = pop(stack_).toTuple()->elements();
+ auto tuple = pop(stack_).toTuple();
+ const auto& elems = tuple->elements();
AT_ASSERT(elems.size() == 2);
auto complex =
c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
@@ -750,7 +753,8 @@
globals_.emplace_back([this] {
// It is the same as how rref is unpickled in python,
// see PyRRef::unpickle
- auto args = stack_.back().toTuple()->elements();
+ auto tuple = std::move(stack_.back()).toTuple();
+ const auto& args = tuple->elements();
stack_.pop_back();
TORCH_INTERNAL_ASSERT(
args.size() == distributed::rpc::RFD_TUPLE_SIZE,