Jaliyae/samplers (#13870)

Summary:
Make Samplers optionally accept a new size in their reset() method. This lets a dataloader or dataset reset the sampler for an epoch, or for a chunk of data, with a different size.
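
A minimal usage sketch (not part of this diff; the chunk sizes and the driving loop are hypothetical), based on the sampler API exercised by the tests below:

```cpp
#include <torch/data/samplers/sequential.h>

#include <cstddef>
#include <vector>

using namespace torch::data;

int main() {
  // Hypothetical per-chunk sizes; in practice these would come from the
  // dataloader or chunked dataset driving the sampler.
  const std::vector<size_t> chunk_sizes = {5, 7, 3};

  samplers::SequentialSampler sampler(chunk_sizes.front());
  for (size_t chunk_size : chunk_sizes) {
    // Resize the sampler for the next chunk instead of reconstructing it.
    sampler.reset(chunk_size);
    while (auto indices = sampler.next(/*batch_size=*/2)) {
      // Fetch the examples at `*indices` from the current chunk here.
    }
  }
}
```
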
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13870

Differential Revision: D13240120

Pulled By: soumith

fbshipit-source-id: 19c53f8be13c0fdcf504f0637b0d3e6009a8e599
diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp
index 56cee13..0d8d546 100644
--- a/test/cpp/api/dataloader.cpp
+++ b/test/cpp/api/dataloader.cpp
@@ -207,6 +207,19 @@
   ASSERT_FALSE(sampler.next(2).has_value());
 }
 
+TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
+  samplers::SequentialSampler sampler(5);
+  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(7);
+  ASSERT_EQ(
+      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(3);
+  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
+  ASSERT_FALSE(sampler.next(2).has_value());
+}
+
 TEST(DataTest, CanSaveAndLoadSequentialSampler) {
   {
     samplers::SequentialSampler a(10);
@@ -272,6 +285,18 @@
   ASSERT_FALSE(sampler.next(2).has_value());
 }
 
+TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
+  samplers::RandomSampler sampler(5);
+  ASSERT_EQ(sampler.next(5).value().size(), 5);
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(7);
+  ASSERT_EQ(sampler.next(7).value().size(), 7);
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(3);
+  ASSERT_EQ(sampler.next(3).value().size(), 3);
+  ASSERT_FALSE(sampler.next(2).has_value());
+}
+
 TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
   {
     samplers::RandomSampler a(10);
@@ -320,6 +345,18 @@
   ASSERT_FALSE(sampler.next(2).has_value());
 }
 
+TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
+  samplers::StreamSampler sampler(/*epoch_size=*/5);
+  ASSERT_EQ(sampler.next(5).value().size(), 5);
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(7);
+  ASSERT_EQ(sampler.next(7).value().size(), 7);
+  ASSERT_FALSE(sampler.next(2).has_value());
+  sampler.reset(3);
+  ASSERT_EQ(sampler.next(3).value().size(), 3);
+  ASSERT_FALSE(sampler.next(2).has_value());
+}
+
 TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
   datasets::TensorDataset dataset(torch::eye(5));
   ASSERT_TRUE(
@@ -618,7 +655,7 @@
 
 struct TestIndexSampler : public samplers::Sampler<TestIndex> {
   explicit TestIndexSampler(size_t size) : size_(size) {}
-  void reset() override {}
+  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
   torch::optional<TestIndex> next(size_t batch_size) override {
     if (index_ >= size_) {
       return torch::nullopt;
diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h
index 1767d65..ed2aa30 100644
--- a/torch/csrc/api/include/torch/data/samplers/base.h
+++ b/torch/csrc/api/include/torch/data/samplers/base.h
@@ -1,7 +1,7 @@
 #pragma once
 
-#include <torch/types.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
 
 #include <cstddef>
 #include <vector>
@@ -27,7 +27,9 @@
 
   /// Resets the `Sampler`'s internal state.
   /// Typically called before a new epoch.
-  TORCH_API virtual void reset() = 0;
+
+  /// Optionally accepts a new size when resetting the sampler.
+  TORCH_API virtual void reset(optional<size_t> new_size) = 0;
 
   /// Returns the next index if possible, or an empty optional if the
   /// sampler is exhausted for this epoch.
diff --git a/torch/csrc/api/include/torch/data/samplers/random.h b/torch/csrc/api/include/torch/data/samplers/random.h
index b18a36b..f0e6a86 100644
--- a/torch/csrc/api/include/torch/data/samplers/random.h
+++ b/torch/csrc/api/include/torch/data/samplers/random.h
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/data/samplers/base.h>
 #include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <cstddef>
 #include <vector>
@@ -26,10 +26,12 @@
   /// The constructor will eagerly allocate all required indices, which is the
   /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
   /// indices. You can change it to influence memory usage.
-  TORCH_API explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
+  TORCH_API explicit RandomSampler(
+      int64_t size,
+      Dtype index_dtype = torch::kInt64);
 
   /// Resets the `RandomSampler` to a new set of indices.
-  TORCH_API void reset() override;
+  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
   TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
diff --git a/torch/csrc/api/include/torch/data/samplers/sequential.h b/torch/csrc/api/include/torch/data/samplers/sequential.h
index bd14d3b..5f83014 100644
--- a/torch/csrc/api/include/torch/data/samplers/sequential.h
+++ b/torch/csrc/api/include/torch/data/samplers/sequential.h
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/data/samplers/base.h>
 #include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <cstddef>
 #include <vector>
@@ -26,7 +26,7 @@
   TORCH_API explicit SequentialSampler(size_t size);
 
   /// Resets the `SequentialSampler` to zero.
-  TORCH_API void reset() override;
+  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns the next batch of indices.
   TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
diff --git a/torch/csrc/api/include/torch/data/samplers/stream.h b/torch/csrc/api/include/torch/data/samplers/stream.h
index fefc301..6f376ac 100644
--- a/torch/csrc/api/include/torch/data/samplers/stream.h
+++ b/torch/csrc/api/include/torch/data/samplers/stream.h
@@ -1,9 +1,9 @@
 #pragma once
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/data/samplers/base.h>
 #include <torch/data/samplers/custom_batch_request.h>
 #include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <cstddef>
 
@@ -39,7 +39,7 @@
   TORCH_API explicit StreamSampler(size_t epoch_size);
 
   /// Resets the internal state of the sampler.
-  TORCH_API void reset() override;
+  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
 
   /// Returns a `BatchSize` object with the number of elements to fetch in the
   /// next batch. This number is the minimum of the supplied `batch_size` and
diff --git a/torch/csrc/api/src/data/samplers/random.cpp b/torch/csrc/api/src/data/samplers/random.cpp
index 4ea975b..0edbc8c 100644
--- a/torch/csrc/api/src/data/samplers/random.cpp
+++ b/torch/csrc/api/src/data/samplers/random.cpp
@@ -12,10 +12,11 @@
 RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
     : indices_(torch::randperm(size, index_dtype)) {}
 
-void RandomSampler::reset() {
+void RandomSampler::reset(optional<size_t> new_size) {
   // This allocates a new chunk of memory every time (just FYI). It should be
   // amortized over the entire epoch hopefully.
-  indices_ = torch::randperm(indices_.numel(), indices_.options());
+  const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
+  indices_ = torch::randperm(size, indices_.options());
   index_ = 0;
 }
 
diff --git a/torch/csrc/api/src/data/samplers/sequential.cpp b/torch/csrc/api/src/data/samplers/sequential.cpp
index 3072346..9c294cb 100644
--- a/torch/csrc/api/src/data/samplers/sequential.cpp
+++ b/torch/csrc/api/src/data/samplers/sequential.cpp
@@ -11,7 +11,10 @@
 namespace samplers {
 SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
 
-void SequentialSampler::reset() {
+void SequentialSampler::reset(optional<size_t> new_size) {
+  if (new_size.has_value()) {
+    size_ = *new_size;
+  }
   index_ = 0;
 }
 
diff --git a/torch/csrc/api/src/data/samplers/stream.cpp b/torch/csrc/api/src/data/samplers/stream.cpp
index 2ac1755..6972846 100644
--- a/torch/csrc/api/src/data/samplers/stream.cpp
+++ b/torch/csrc/api/src/data/samplers/stream.cpp
@@ -20,7 +20,10 @@
 
 StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
 
-void StreamSampler::reset() {
+void StreamSampler::reset(optional<size_t> new_size) {
+  if (new_size.has_value()) {
+    epoch_size_ = *new_size;
+  }
   examples_retrieved_so_far_ = 0;
 }