[SR] Add StorageGroup abstraction (#68279)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68279

While reworking the liveness analysis, I noticed that using `std::pair<size_t, std::vector<Tensor*>>` to represent storage groups made things quite unreadable.

Add a simple `StorageGroup` class that wraps a `std::vector<at::Tensor*>` and stores the group's size attribute.
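
For illustration, here is a rough standalone sketch (not the code in this PR; `at::Tensor` is only forward-declared and the helper names mirror the new class) of how the call sites read before and after the change:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

namespace at { class Tensor; } // forward declaration, enough for pointer use

// Before: the size and the tensor list travel as an anonymous pair.
using OldStorageGroup = std::pair<size_t, std::vector<at::Tensor*>>;

// After: the same data behind named accessors.
class StorageGroup {
 public:
  explicit StorageGroup(at::Tensor* tensor) : group_{tensor} {}
  void addTensor(at::Tensor* tensor) { group_.push_back(tensor); }
  size_t maxTensorSize() const { return max_tensor_size_; }
  void setMaxTensorSize(size_t new_size) { max_tensor_size_ = new_size; }
  size_t numManagedTensors() const { return group_.size(); }

 private:
  size_t max_tensor_size_ = 0;        // filled in later by the planner
  std::vector<at::Tensor*> group_{};  // tensors sharing this storage
};

void example(std::vector<OldStorageGroup>& old_groups,
             std::vector<StorageGroup>& new_groups,
             at::Tensor* tensor) {
  // Before: `.first` / `.second` give no hint of what they mean.
  old_groups[0].second.emplace_back(tensor);
  size_t old_size = old_groups[0].first;

  // After: the intent is spelled out at the call site.
  new_groups[0].addTensor(tensor);
  size_t new_size = new_groups[0].maxTensorSize();
  (void)old_size; (void)new_size;
}
```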

Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`

Also ran inline_cvr benchmarks; did not see any errors.

Reviewed By: swolchok

Differential Revision: D32369447

fbshipit-source-id: e0b562aa7eefd738b1a34f1f37eb7bc95d71a257
diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp
index d654472..643cc21 100644
--- a/torch/csrc/jit/runtime/static/memory_planner.cpp
+++ b/torch/csrc/jit/runtime/static/memory_planner.cpp
@@ -11,7 +11,7 @@
     const FastSet<const Value*>& managed_tensor_values,
     const FastMap<const Value*, std::vector<const Value*>>&
         value_to_same_storage_values,
-    std::vector<std::pair<size_t, std::vector<at::Tensor*>>>& managed_tensors) {
+    std::vector<StorageGroup>& managed_tensors) {
   // map Value to index to managed_storage, where multiple values can
   // map to the same index (i.e., sharing the same storage)
   FastMap<const Value*, size_t> value_to_storage_idx;
@@ -27,11 +27,9 @@
         auto f = value_to_storage_idx.find(val);
         if (f != value_to_storage_idx.end()) {
           auto storage_idx = f->second;
-          managed_tensors[storage_idx].second.emplace_back(tensor);
+          managed_tensors[storage_idx].addTensor(tensor);
         } else {
-          auto p =
-              std::make_pair<size_t, std::vector<at::Tensor*>>(0, {tensor});
-          managed_tensors.emplace_back(std::move(p));
+          managed_tensors.emplace_back(tensor);
           // first of a group, update the value_to_storage_idx map with the
           // index
           auto f = value_to_same_storage_values.find(val);
@@ -193,7 +191,7 @@
 
   num_managed_tensors_ = 0;
   for (const auto& ms : managed_tensors_) {
-    num_managed_tensors_ += ms.second.size();
+    num_managed_tensors_ += ms.numManagedTensors();
   }
 }
 
@@ -234,13 +232,13 @@
     void* src = static_cast<void*>(start + offset);
 
 #ifndef NDEBUG
-    DCHECK_EQ(tensor_size, managed_tensors_[group_idx].first);
-    for (auto* tensor : managed_tensors_[group_idx].second) {
+    DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize());
+    for (auto* tensor : managed_tensors_[group_idx].group()) {
       DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl());
     }
 #endif
-    DCHECK_NE(managed_tensors_[group_idx].second.size(), 0);
-    reused_tensors_ += managed_tensors_[group_idx].second.size() - 1;
+    DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0);
+    reused_tensors_ += managed_tensors_[group_idx].numManagedTensors() - 1;
     storageImpl->set_data_ptr_noswap(
         at::DataPtr(src, src, nullptr, c10::Device(c10::DeviceType::CPU)));
     storageImpl->set_nbytes(tensor_size);
@@ -314,8 +312,8 @@
     managed_tensor_storage_impls_.reserve(managed_tensors_.size());
   }
   for (auto& ms : managed_tensors_) {
-    const auto& tensors = ms.second;
-    size_t max = ms.first;
+    const auto& tensors = ms.group();
+    size_t max = ms.maxTensorSize();
     auto tensor_idx = 0;
     for (auto& tensor : tensors) {
       const auto& storage = tensor->storage();
@@ -371,7 +369,8 @@
     // run (following C2 tradition), exploiting the fact that tensor storage
     // size does not have to match that of real tensor size. The following logic
     // records the tensor storage size for the next run.
-    managed_tensor_storage_impls_[group_idx++].first = ms.first = max;
+    managed_tensor_storage_impls_[group_idx++].first = max;
+    ms.setMaxTensorSize(max);
     managed_bytes_ += max;
   }
 
diff --git a/torch/csrc/jit/runtime/static/memory_planner.h b/torch/csrc/jit/runtime/static/memory_planner.h
index 3d3c9cd..6717c35 100644
--- a/torch/csrc/jit/runtime/static/memory_planner.h
+++ b/torch/csrc/jit/runtime/static/memory_planner.h
@@ -5,6 +5,40 @@
 namespace torch {
 namespace jit {
 
+// A StorageGroup represents a collection of tensors that share backing storage.
+class StorageGroup {
+ public:
+  // Every storage group must contain at least one tensor.
+  explicit StorageGroup(at::Tensor* tensor) : group_{tensor} {}
+
+  void addTensor(at::Tensor* tensor) {
+    group_.push_back(tensor);
+  }
+
+  const std::vector<at::Tensor*>& group() const {
+    return group_;
+  }
+
+  size_t maxTensorSize() const {
+    return max_tensor_size_;
+  }
+
+  void setMaxTensorSize(size_t new_size) {
+    max_tensor_size_ = new_size;
+  }
+
+  size_t numManagedTensors() const {
+    return group_.size();
+  }
+
+ private:
+  // The size attribute represents the amount of memory that will be
+  // allocated for all tensors in this storage group. Initially it
+  // is zero; it is eventually updated by the MemoryPlanner.
+  size_t max_tensor_size_ = 0;
+  std::vector<at::Tensor*> group_{};
+};
+
 /// There are three types of ops in a processed graph in Static Runtime:
 ///   1. op with _out variant
 ///   2. view producing op
@@ -154,7 +188,7 @@
   // We don't have any guarantee that the model doesn't change the
   // Storage for managed tensors out from under us during execution,
   // so we have to check the StorageImpls each time we deallocate.
-  std::vector<std::pair<size_t, std::vector<at::Tensor*>>> managed_tensors_;
+  std::vector<StorageGroup> managed_tensors_{};
   at::DataPtr buffer_; // allocated each time we call Run()
   uint8_t* buffer_start_{nullptr};
   uint8_t* buffer_end_{nullptr};