[PyTorch] Remove ArrayRefTensor::numel_ (#124516)
ArrayRefTensor::numel_ is redundant: it duplicates the size already stored in the contained MiniArrayRef. Removing the field outright would shrink the struct and break ABI compatibility, so the slot is kept as a reserved field instead — leaving 4-8 bytes (depending on alignment padding) available for future expansion.
Differential Revision: [D56366829](https://our.internmc.facebook.com/intern/diff/D56366829/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D56366829/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124516
Approved by: https://github.com/chenyang78, https://github.com/desertfire
diff --git a/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h b/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
index a864dbf..436ed3f 100644
--- a/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
+++ b/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
@@ -154,6 +154,10 @@
using MiniIntArrayRef = MiniArrayRef<int64_t>;
+static_assert(
+ sizeof(MiniIntArrayRef) == sizeof(void*) + sizeof(size_t),
+ "changing the size of MiniArrayRef breaks ABI compatibility!");
+
inline bool is_contiguous_strides_for_shape(
int64_t ndim,
const int64_t* strides_ptr,
@@ -189,8 +193,7 @@
sizes_(sizes),
strides_(strides),
device_type_(device_type),
- device_idx_(device_idx),
- numel_(arr.size()) {
+ device_idx_(device_idx) {
assert(sizes.size() == strides.size());
assert(is_contiguous_strides_for_shape(
sizes.size(), strides.data(), sizes.data()));
@@ -242,7 +245,7 @@
}
auto numel() const {
- return numel_;
+ return arrayRef_.size();
}
void set_arrayref(MiniArrayRef<T> new_arrayref) {
@@ -257,9 +260,17 @@
MiniArrayRef<const int64_t> strides_;
int32_t device_type_ = 0;
int32_t device_idx_ = 0;
- int32_t numel_ = 0;
+ // We continue to zero-initialize this field in case we repurpose
+ // the space later; having predictable contents can only help.
+ int32_t unusedDoNotRemoveForABICompatibility_ = 0;
};
+static_assert(
+ sizeof(ArrayRefTensor<int>) ==
+ 3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) +
+ (alignof(ArrayRefTensor<int>) > 4 ? sizeof(int32_t) : 0),
+ "changing the size of ArrayRefTensor breaks ABI compatibility!");
+
inline AtenTensorHandle reinterpret_tensor_wrapper(
AtenTensorHandle self,
int64_t ndim,