[SR] Add StorageGroup abstraction (#68279)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68279
While reworking the liveness analysis, I noticed that using `std::pair<size_t, std::vector<at::Tensor*>>` to represent storage groups made things quite unreadable.
Add a simple `StorageGroup` class that wraps a `std::vector<at::Tensor*>` and stores a `size` attribute.
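For illustration, a minimal before/after sketch of the two representations (not part of the diff; the `before`/`after`/`groups`/`tensor` names are made up for the example, and it assumes the internal header is reachable on the include path):

```cpp
// Illustrative sketch only: contrasts the old pair-based bookkeeping with the
// new StorageGroup API defined in memory_planner.h.
#include <torch/csrc/jit/runtime/static/memory_planner.h>

#include <ATen/ATen.h>
#include <cstddef>
#include <utility>
#include <vector>

void before(
    std::vector<std::pair<size_t, std::vector<at::Tensor*>>>& groups,
    at::Tensor* tensor) {
  groups.emplace_back(0, std::vector<at::Tensor*>{tensor}); // start a group
  groups.back().second.emplace_back(tensor);                // grow a group
  size_t bytes = groups.back().first;                       // what is .first?
  (void)bytes;
}

void after(std::vector<torch::jit::StorageGroup>& groups, at::Tensor* tensor) {
  groups.emplace_back(tensor);                  // start a group
  groups.back().addTensor(tensor);              // grow a group
  size_t bytes = groups.back().maxTensorSize(); // self-describing accessor
  (void)bytes;
}
```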
Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`
Also ran inline_cvr benchmarks; did not see any errors.
Reviewed By: swolchok
Differential Revision: D32369447
fbshipit-source-id: e0b562aa7eefd738b1a34f1f37eb7bc95d71a257
diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp
index d654472..643cc21 100644
--- a/torch/csrc/jit/runtime/static/memory_planner.cpp
+++ b/torch/csrc/jit/runtime/static/memory_planner.cpp
@@ -11,7 +11,7 @@
const FastSet<const Value*>& managed_tensor_values,
const FastMap<const Value*, std::vector<const Value*>>&
value_to_same_storage_values,
- std::vector<std::pair<size_t, std::vector<at::Tensor*>>>& managed_tensors) {
+ std::vector<StorageGroup>& managed_tensors) {
// map Value to index to managed_storage, where multiple values can
// map to the same index (i.e., sharing the same storage)
FastMap<const Value*, size_t> value_to_storage_idx;
@@ -27,11 +27,9 @@
auto f = value_to_storage_idx.find(val);
if (f != value_to_storage_idx.end()) {
auto storage_idx = f->second;
- managed_tensors[storage_idx].second.emplace_back(tensor);
+ managed_tensors[storage_idx].addTensor(tensor);
} else {
- auto p =
- std::make_pair<size_t, std::vector<at::Tensor*>>(0, {tensor});
- managed_tensors.emplace_back(std::move(p));
+ managed_tensors.emplace_back(tensor);
// first of a group, update the value_to_storage_idx map with the
// index
auto f = value_to_same_storage_values.find(val);
@@ -193,7 +191,7 @@
num_managed_tensors_ = 0;
for (const auto& ms : managed_tensors_) {
- num_managed_tensors_ += ms.second.size();
+ num_managed_tensors_ += ms.numManagedTensors();
}
}
@@ -234,13 +232,13 @@
void* src = static_cast<void*>(start + offset);
#ifndef NDEBUG
- DCHECK_EQ(tensor_size, managed_tensors_[group_idx].first);
- for (auto* tensor : managed_tensors_[group_idx].second) {
+ DCHECK_EQ(tensor_size, managed_tensors_[group_idx].maxTensorSize());
+ for (auto* tensor : managed_tensors_[group_idx].group()) {
DCHECK_EQ(storageImpl, tensor->storage().unsafeGetStorageImpl());
}
#endif
- DCHECK_NE(managed_tensors_[group_idx].second.size(), 0);
- reused_tensors_ += managed_tensors_[group_idx].second.size() - 1;
+ DCHECK_NE(managed_tensors_[group_idx].numManagedTensors(), 0);
+ reused_tensors_ += managed_tensors_[group_idx].numManagedTensors() - 1;
storageImpl->set_data_ptr_noswap(
at::DataPtr(src, src, nullptr, c10::Device(c10::DeviceType::CPU)));
storageImpl->set_nbytes(tensor_size);
@@ -314,8 +312,8 @@
managed_tensor_storage_impls_.reserve(managed_tensors_.size());
}
for (auto& ms : managed_tensors_) {
- const auto& tensors = ms.second;
- size_t max = ms.first;
+ const auto& tensors = ms.group();
+ size_t max = ms.maxTensorSize();
auto tensor_idx = 0;
for (auto& tensor : tensors) {
const auto& storage = tensor->storage();
@@ -371,7 +369,8 @@
// run (following C2 tradition), exploiting the fact that tensor storage
// size does not have to match that of real tensor size. The following logic
// records the tensor storage size for the next run.
- managed_tensor_storage_impls_[group_idx++].first = ms.first = max;
+ managed_tensor_storage_impls_[group_idx++].first = max;
+ ms.setMaxTensorSize(max);
managed_bytes_ += max;
}
diff --git a/torch/csrc/jit/runtime/static/memory_planner.h b/torch/csrc/jit/runtime/static/memory_planner.h
index 3d3c9cd..6717c35 100644
--- a/torch/csrc/jit/runtime/static/memory_planner.h
+++ b/torch/csrc/jit/runtime/static/memory_planner.h
@@ -5,6 +5,40 @@
namespace torch {
namespace jit {
+// A StorageGroup represents a collection of tensors that share backing storage.
+class StorageGroup {
+ public:
+ // Every storage group must contain at least one tensor.
+ explicit StorageGroup(at::Tensor* tensor) : group_{tensor} {}
+
+ void addTensor(at::Tensor* tensor) {
+ group_.push_back(tensor);
+ }
+
+ const std::vector<at::Tensor*>& group() const {
+ return group_;
+ }
+
+ size_t maxTensorSize() const {
+ return max_tensor_size_;
+ }
+
+ void setMaxTensorSize(size_t new_size) {
+ max_tensor_size_ = new_size;
+ }
+
+ size_t numManagedTensors() const {
+ return group_.size();
+ }
+
+ private:
+ // The size attribute represents the amount of memory that will be
+ // allocated for all tensors in this storage group. It is initially
+ // zero; the MemoryPlanner updates it once the maximum tensor size
+ // for the group is known.
+ size_t max_tensor_size_ = 0;
+ std::vector<at::Tensor*> group_{};
+};
+
/// There are three types of ops in a processed graph in Static Runtime:
/// 1. op with _out variant
/// 2. view producing op
@@ -154,7 +188,7 @@
// We don't have any guarantee that the model doesn't change the
// Storage for managed tensors out from under us during execution,
// so we have to check the StorageImpls each time we deallocate.
- std::vector<std::pair<size_t, std::vector<at::Tensor*>>> managed_tensors_;
+ std::vector<StorageGroup> managed_tensors_{};
at::DataPtr buffer_; // allocated each time we call Run()
uint8_t* buffer_start_{nullptr};
uint8_t* buffer_end_{nullptr};
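For reference, a minimal sketch of the planner-side bookkeeping expressed against the new API (illustrative only; the free-function names are made up and the internal header is assumed to be reachable):

```cpp
// Illustrative sketch, not part of the diff: aggregates the same per-group
// stats the MemoryPlanner tracks, via the StorageGroup accessors.
#include <torch/csrc/jit/runtime/static/memory_planner.h>

#include <cstddef>
#include <vector>

// Mirrors the num_managed_tensors_ computation.
size_t countManagedTensors(const std::vector<torch::jit::StorageGroup>& groups) {
  size_t n = 0;
  for (const auto& g : groups) {
    n += g.numManagedTensors();
  }
  return n;
}

// Mirrors the managed_bytes_ accumulation (sum of each group's max size).
size_t totalManagedBytes(const std::vector<torch::jit::StorageGroup>& groups) {
  size_t bytes = 0;
  for (const auto& g : groups) {
    bytes += g.maxTensorSize();
  }
  return bytes;
}
```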