[SR] Give VarStackNodeWrapper an iterator (#69922)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69922

D32596934 (https://github.com/pytorch/pytorch/commit/65f54bc000c4824a4e999ebfb6a27b252b696b0d) made the serial stack implementation a bit brittle: it introduced a new container type, `VarStackNodeWrapper`, which is used as a template argument to the serial stack implementation.

The other type used with the serial stack implementation is `at::ArrayRef<at::Tensor>`, so ideally the interface of `VarStackNodeWrapper` should match it as closely as possible. However, because the new container type had no iterator, range-based for loops like this one failed to compile:
```
for (const auto& tensor : tensors) {
  // do something
}
```
Introducing this iterator will make the code easier to maintain going forward.
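
For reference, a range-based for loop only needs `begin()`/`end()` on the container and `operator++`/`operator*`/`operator!=` on the returned iterator; the compiler rewrites the loop roughly as in the sketch below (a sketch of the language rule, not code from this diff), which is why the missing iterator broke compilation:
```
// Approximate desugaring of `for (const auto& tensor : tensors)` for a
// generic container type T. If T has no begin()/end() (as VarStackNodeWrapper
// did not), this fails to compile when the template is instantiated with T.
template <typename T>
void iterate(const T& tensors) {
  auto&& range = tensors;
  auto first = range.begin();
  auto last = range.end();
  for (; first != last; ++first) {
    const auto& tensor = *first;
    (void)tensor; // do something
  }
}
```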

Test Plan:
`buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- Stack`

I consider this a `VarStack` implementation detail, so I'd prefer not to test it directly. We can test it implicitly by having the serial stack implementation use the iterator, as the `SerialStackImpl.h` hunk below does.
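
A hypothetical sketch of the kind of op-level check that provides this implicit coverage (not the actual cpptest; the `testStaticRuntime` helper and its `(script, args)` signature are assumptions about the existing suite). Stacking through the variadic op routes its inputs through `VarStackNodeWrapper`, so the range-based for loop in `SerialStackImpl.h` exercises the new iterator:
```
// Hypothetical illustration only; assumes the suite's testStaticRuntime
// helper, which runs the script through Static Runtime and compares the
// result against the JIT interpreter.
TEST(StaticRuntime, StackExercisesWrapperIterator) {
  const std::string stack_script = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        return torch.stack([a, b, c], dim=0).clone()
  )JIT";
  const auto a = at::randn({2, 3});
  const auto b = at::randn({2, 3});
  const auto c = at::randn({2, 3});
  testStaticRuntime(stack_script, {a, b, c});
}
```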

Reviewed By: swolchok

Differential Revision: D33101489

fbshipit-source-id: 7cf44c072d230c41bd9113cf2393bc6a6645a5b5
diff --git a/aten/src/ATen/native/cpu/SerialStackImpl.h b/aten/src/ATen/native/cpu/SerialStackImpl.h
index 8c30c49..b35e5b0 100644
--- a/aten/src/ATen/native/cpu/SerialStackImpl.h
+++ b/aten/src/ATen/native/cpu/SerialStackImpl.h
@@ -35,8 +35,7 @@
   int64_t ninputs = tensors.size();
   std::vector<InputMeta> inputs;
   inputs.reserve(ninputs);
-  for (const auto i : c10::irange(tensors.size())) {
-    auto& tensor = tensors[i];
+  for (const auto& tensor : tensors) {
     inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
   }
 
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index c4a28f9..96e55c0 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -35,6 +35,7 @@
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
 #include <torch/csrc/jit/tensorexpr/loopnest.h>
+#include <iterator>
 #include <mutex>
 #include <unordered_map>
 
@@ -681,6 +682,68 @@
 
 class VarStackNodeWrapper {
  public:
+  class VarStackNodeWrapperIter {
+   public:
+    using iterator_category = std::forward_iterator_tag;
+    using value_type = at::Tensor;
+    using difference_type = size_t;
+    using pointer = const at::Tensor*;
+    using reference = const at::Tensor&;
+
+    VarStackNodeWrapperIter() = default;
+
+    VarStackNodeWrapperIter(
+        const VarStackNodeWrapper* container,
+        size_t start_idx)
+        : container_(container), idx_(start_idx) {}
+
+    VarStackNodeWrapperIter& operator++() {
+      DCHECK_NE(idx_, container_->size());
+      ++idx_;
+      return *this;
+    }
+
+    VarStackNodeWrapperIter operator++(int) {
+      VarStackNodeWrapperIter old = *this;
+      ++(*this);
+      return old;
+    }
+
+    reference operator*() const {
+      TORCH_CHECK(container_ != nullptr);
+      return (*container_)[idx_];
+    }
+
+    pointer operator->() const {
+      TORCH_CHECK(container_ != nullptr);
+      return &(*container_)[idx_];
+    }
+
+    friend bool operator==(
+        VarStackNodeWrapperIter lhs,
+        VarStackNodeWrapperIter rhs) {
+      DCHECK_EQ(lhs.container_, rhs.container_);
+      return lhs.idx_ == rhs.idx_;
+    }
+
+    friend bool operator!=(
+        VarStackNodeWrapperIter lhs,
+        VarStackNodeWrapperIter rhs) {
+      return !(lhs == rhs);
+    }
+
+   private:
+    const VarStackNodeWrapper* container_ = nullptr;
+    size_t idx_ = 0;
+  };
+
+  // NB: to mimic the behavior of at::ArrayRef, both iterators are
+  // the const version.
+  using iterator = VarStackNodeWrapperIter;
+  using const_iterator = VarStackNodeWrapperIter;
+  using size_type = size_t;
+  using value_type = at::Tensor;
+
   explicit VarStackNodeWrapper(const ProcessedNode& pnode) : pnode_(pnode) {}
 
   const at::Tensor& operator[](size_t idx) const {
@@ -692,6 +755,43 @@
     return pnode_.num_inputs() - 1;
   }
 
+  iterator begin() {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  iterator end() {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  const_iterator begin() const {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  const_iterator end() const {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  const_iterator cbegin() const {
+    return VarStackNodeWrapperIter(this, 0);
+  }
+  const_iterator cend() const {
+    return VarStackNodeWrapperIter(this, size());
+  }
+
+  bool empty() const {
+    return size() == 0;
+  }
+
+  const at::Tensor& front() const {
+    TORCH_CHECK(
+        !empty(), "Attempted to access front() of empty VarStackNodeWrapper");
+    return pnode_.Input(0).toTensor();
+  }
+
+  const at::Tensor& back() const {
+    TORCH_CHECK(
+        !empty(), "Attempted to access back() of empty VarStackNodeWrapper");
+    return pnode_.Input(size() - 1).toTensor();
+  }
+
  private:
   const ProcessedNode& pnode_;
 };