[SR] Give VarStackNodeWrapper an iterator (#69922)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69922
D32596934 (https://github.com/pytorch/pytorch/commit/65f54bc000c4824a4e999ebfb6a27b252b696b0d) made the serial stack implementation a bit brittle: it introduced a new container type, `VarStackNodeWrapper`, which is passed to the serial stack implementation as a template parameter.
The other type used to instantiate the serial stack implementation is `at::ArrayRef<at::Tensor>`, so ideally the interface of `VarStackNodeWrapper` should match it as closely as possible. However, because the new container type had no iterator, range-based for loops like this one failed to compile:
```cpp
for (const auto& tensor : tensors) {
  // do something
}
```
Introducing this iterator will make the code easier to maintain going forward.
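For reference, all a range-based for loop needs from a container is `begin()`/`end()` returning an iterator that supports `operator++`, `operator*`, and `operator!=`. A minimal sketch of such a container (hypothetical `MinimalContainer` with `int` elements, purely for illustration; the real iterator is in the diff below):
```cpp
#include <cstddef>
#include <vector>

// Minimal sketch (not part of this change) of the interface a range-based
// for loop needs: begin()/end() plus an iterator supporting operator++,
// operator*, and operator!=. Element type is int for brevity;
// VarStackNodeWrapper yields at::Tensor instead.
class MinimalContainer {
 public:
  class Iter {
   public:
    Iter(const MinimalContainer* container, size_t idx)
        : container_(container), idx_(idx) {}
    Iter& operator++() {  // advance to the next element
      ++idx_;
      return *this;
    }
    const int& operator*() const {  // read the current element
      return container_->data_[idx_];
    }
    bool operator!=(const Iter& other) const {
      return idx_ != other.idx_;
    }

   private:
    const MinimalContainer* container_;
    size_t idx_;
  };

  Iter begin() const {
    return Iter(this, 0);
  }
  Iter end() const {
    return Iter(this, data_.size());
  }

 private:
  std::vector<int> data_{1, 2, 3};
};

// With begin()/end() in place, `for (const auto& x : MinimalContainer{})`
// compiles; without them, the loop fails exactly as described above.
```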
Test Plan:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- Stack`
I consider this a `VarStack` implementation detail, so I'd prefer not to test it directly. We can test it implicitly by adding some code to the serial stack implementation that uses the iterator.
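To illustrate the kind of code that exercises the iterator (a hypothetical helper, not part of this change): with the iterator in place, the same generic loop compiles whether `TensorListType` is `at::ArrayRef<at::Tensor>` or `VarStackNodeWrapper`, which is exactly what the serial stack kernel relies on.
```cpp
#include <cstdint>

// Hypothetical helper, for illustration only: counts the non-empty tensors in
// any container that provides begin()/end() and yields at::Tensor, e.g.
// at::ArrayRef<at::Tensor> or VarStackNodeWrapper.
template <typename TensorListType>
int64_t count_nonempty(const TensorListType& tensors) {
  int64_t count = 0;
  for (const auto& tensor : tensors) {  // requires the container's iterator
    if (tensor.numel() != 0) {
      ++count;
    }
  }
  return count;
}
```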
Reviewed By: swolchok
Differential Revision: D33101489
fbshipit-source-id: 7cf44c072d230c41bd9113cf2393bc6a6645a5b5
diff --git a/aten/src/ATen/native/cpu/SerialStackImpl.h b/aten/src/ATen/native/cpu/SerialStackImpl.h
index 8c30c49..b35e5b0 100644
--- a/aten/src/ATen/native/cpu/SerialStackImpl.h
+++ b/aten/src/ATen/native/cpu/SerialStackImpl.h
@@ -35,8 +35,7 @@
   int64_t ninputs = tensors.size();
   std::vector<InputMeta> inputs;
   inputs.reserve(ninputs);
-  for (const auto i : c10::irange(tensors.size())) {
-    auto& tensor = tensors[i];
+  for (const auto& tensor : tensors) {
     inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
   }
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index c4a28f9..96e55c0 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -35,6 +35,7 @@
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
 #include <torch/csrc/jit/tensorexpr/loopnest.h>
+#include <iterator>
 #include <mutex>
 #include <unordered_map>
@@ -681,6 +682,68 @@
 class VarStackNodeWrapper {
  public:
+  class VarStackNodeWrapperIter {
+   public:
+    using iterator_category = std::forward_iterator_tag;
+    using value_type = at::Tensor;
+    using difference_type = size_t;
+    using pointer = const at::Tensor*;
+    using reference = const at::Tensor&;
+
+    VarStackNodeWrapperIter() = default;
+
+    VarStackNodeWrapperIter(
+        const VarStackNodeWrapper* container,
+        size_t start_idx)
+        : container_(container), idx_(start_idx) {}
+
+    VarStackNodeWrapperIter& operator++() {
+      DCHECK_NE(idx_, container_->size());
+      ++idx_;
+      return *this;
+    }
+
+    VarStackNodeWrapperIter operator++(int) {
+      VarStackNodeWrapperIter old = *this;
+      ++(*this);
+      return old;
+    }
+
+    reference operator*() const {
+      TORCH_CHECK(container_ != nullptr);
+      return (*container_)[idx_];
+    }
+
+    pointer operator->() const {
+      TORCH_CHECK(container_ != nullptr);
+      return &(*container_)[idx_];
+    }
+
+    friend bool operator==(
+        VarStackNodeWrapperIter lhs,
+        VarStackNodeWrapperIter rhs) {
+      DCHECK_EQ(lhs.container_, rhs.container_);
+      return lhs.idx_ == rhs.idx_;
+    }
+
+    friend bool operator!=(
+        VarStackNodeWrapperIter lhs,
+        VarStackNodeWrapperIter rhs) {
+      return !(lhs == rhs);
+    }
+
+   private:
+    const VarStackNodeWrapper* container_ = nullptr;
+    size_t idx_ = 0;
+  };
+
+  // NB: to mimic the behavior of at::ArrayRef, both iterators are
+  // the const version.
+  using iterator = VarStackNodeWrapperIter;
+  using const_iterator = VarStackNodeWrapperIter;
+  using size_type = size_t;
+  using value_type = at::Tensor;
+
   explicit VarStackNodeWrapper(const ProcessedNode& pnode) : pnode_(pnode) {}
   const at::Tensor& operator[](size_t idx) const {
@@ -692,6 +755,43 @@
     return pnode_.num_inputs() - 1;
   }
+  iterator begin() {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  iterator end() {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  const_iterator begin() const {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  const_iterator end() const {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  const_iterator cbegin() const {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  const_iterator cend() const {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  bool empty() const {
+    return size() == 0;
+  }
+
+  const at::Tensor& front() const {
+    TORCH_CHECK(
+        !empty(), "Attempted to access front() of empty VarStackNodeWrapper");
+    return pnode_.Input(0).toTensor();
+  }
+
+  const at::Tensor& back() const {
+    TORCH_CHECK(
+        !empty(), "Attempted to access back() of empty VarStackNodeWrapper");
+    return pnode_.Input(size() - 1).toTensor();
+  }
+
  private:
   const ProcessedNode& pnode_;
 };