Jaliyae/samplers (#13870)
Summary:
Make samplers optionally accept a new size in their reset() method. This lets a dataloader or dataset reset the sampler for an epoch, or for a chunk of data, with a different size.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13870
Differential Revision: D13240120
Pulled By: soumith
fbshipit-source-id: 19c53f8be13c0fdcf504f0637b0d3e6009a8e599
diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp
index 56cee13..0d8d546 100644
--- a/test/cpp/api/dataloader.cpp
+++ b/test/cpp/api/dataloader.cpp
@@ -207,6 +207,19 @@
ASSERT_FALSE(sampler.next(2).has_value());
}
+TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
+ samplers::SequentialSampler sampler(5);
+ ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(7);
+ ASSERT_EQ(
+ sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(3);
+ ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
+ ASSERT_FALSE(sampler.next(2).has_value());
+}
+
TEST(DataTest, CanSaveAndLoadSequentialSampler) {
{
samplers::SequentialSampler a(10);
@@ -272,6 +285,18 @@
ASSERT_FALSE(sampler.next(2).has_value());
}
+TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
+ samplers::RandomSampler sampler(5);
+ ASSERT_EQ(sampler.next(5).value().size(), 5);
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(7);
+ ASSERT_EQ(sampler.next(7).value().size(), 7);
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(3);
+ ASSERT_EQ(sampler.next(3).value().size(), 3);
+ ASSERT_FALSE(sampler.next(2).has_value());
+}
+
TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
{
samplers::RandomSampler a(10);
@@ -320,6 +345,18 @@
ASSERT_FALSE(sampler.next(2).has_value());
}
+TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
+ samplers::StreamSampler sampler(/*epoch_size=*/5);
+ ASSERT_EQ(sampler.next(5).value().size(), 5);
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(7);
+ ASSERT_EQ(sampler.next(7).value().size(), 7);
+ ASSERT_FALSE(sampler.next(2).has_value());
+ sampler.reset(3);
+ ASSERT_EQ(sampler.next(3).value().size(), 3);
+ ASSERT_FALSE(sampler.next(2).has_value());
+}
+
TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
datasets::TensorDataset dataset(torch::eye(5));
ASSERT_TRUE(
@@ -618,7 +655,7 @@
struct TestIndexSampler : public samplers::Sampler<TestIndex> {
explicit TestIndexSampler(size_t size) : size_(size) {}
- void reset() override {}
+ void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
torch::optional<TestIndex> next(size_t batch_size) override {
if (index_ >= size_) {
return torch::nullopt;
diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h
index 1767d65..ed2aa30 100644
--- a/torch/csrc/api/include/torch/data/samplers/base.h
+++ b/torch/csrc/api/include/torch/data/samplers/base.h
@@ -1,7 +1,7 @@
#pragma once
-#include <torch/types.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
#include <cstddef>
#include <vector>
@@ -27,7 +27,9 @@
/// Resets the `Sampler`'s internal state.
/// Typically called before a new epoch.
- TORCH_API virtual void reset() = 0;
+
+  /// Optionally, accepts a new size when resetting the sampler.
+ TORCH_API virtual void reset(optional<size_t> new_size) = 0;
/// Returns the next index if possible, or an empty optional if the
/// sampler is exhausted for this epoch.
diff --git a/torch/csrc/api/include/torch/data/samplers/random.h b/torch/csrc/api/include/torch/data/samplers/random.h
index b18a36b..f0e6a86 100644
--- a/torch/csrc/api/include/torch/data/samplers/random.h
+++ b/torch/csrc/api/include/torch/data/samplers/random.h
@@ -1,8 +1,8 @@
#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/data/samplers/base.h>
#include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cstddef>
#include <vector>
@@ -26,10 +26,12 @@
/// The constructor will eagerly allocate all required indices, which is the
/// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored
/// indices. You can change it to influence memory usage.
- TORCH_API explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64);
+ TORCH_API explicit RandomSampler(
+ int64_t size,
+ Dtype index_dtype = torch::kInt64);
/// Resets the `RandomSampler` to a new set of indices.
- TORCH_API void reset() override;
+ TORCH_API void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
diff --git a/torch/csrc/api/include/torch/data/samplers/sequential.h b/torch/csrc/api/include/torch/data/samplers/sequential.h
index bd14d3b..5f83014 100644
--- a/torch/csrc/api/include/torch/data/samplers/sequential.h
+++ b/torch/csrc/api/include/torch/data/samplers/sequential.h
@@ -1,8 +1,8 @@
#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/data/samplers/base.h>
#include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cstddef>
#include <vector>
@@ -26,7 +26,7 @@
TORCH_API explicit SequentialSampler(size_t size);
/// Resets the `SequentialSampler` to zero.
- TORCH_API void reset() override;
+ TORCH_API void reset(optional<size_t> new_size = nullopt) override;
/// Returns the next batch of indices.
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
diff --git a/torch/csrc/api/include/torch/data/samplers/stream.h b/torch/csrc/api/include/torch/data/samplers/stream.h
index fefc301..6f376ac 100644
--- a/torch/csrc/api/include/torch/data/samplers/stream.h
+++ b/torch/csrc/api/include/torch/data/samplers/stream.h
@@ -1,9 +1,9 @@
#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/data/samplers/base.h>
#include <torch/data/samplers/custom_batch_request.h>
#include <torch/types.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cstddef>
@@ -39,7 +39,7 @@
TORCH_API explicit StreamSampler(size_t epoch_size);
/// Resets the internal state of the sampler.
- TORCH_API void reset() override;
+ TORCH_API void reset(optional<size_t> new_size = nullopt) override;
/// Returns a `BatchSize` object with the number of elements to fetch in the
/// next batch. This number is the minimum of the supplied `batch_size` and
diff --git a/torch/csrc/api/src/data/samplers/random.cpp b/torch/csrc/api/src/data/samplers/random.cpp
index 4ea975b..0edbc8c 100644
--- a/torch/csrc/api/src/data/samplers/random.cpp
+++ b/torch/csrc/api/src/data/samplers/random.cpp
@@ -12,10 +12,11 @@
RandomSampler::RandomSampler(int64_t size, Dtype index_dtype)
: indices_(torch::randperm(size, index_dtype)) {}
-void RandomSampler::reset() {
+void RandomSampler::reset(optional<size_t> new_size) {
// This allocates a new chunk of memory every time (just FYI). It should be
// amortized over the entire epoch hopefully.
- indices_ = torch::randperm(indices_.numel(), indices_.options());
+ const auto size = new_size.value_or(static_cast<size_t>(indices_.numel()));
+ indices_ = torch::randperm(size, indices_.options());
index_ = 0;
}
diff --git a/torch/csrc/api/src/data/samplers/sequential.cpp b/torch/csrc/api/src/data/samplers/sequential.cpp
index 3072346..9c294cb 100644
--- a/torch/csrc/api/src/data/samplers/sequential.cpp
+++ b/torch/csrc/api/src/data/samplers/sequential.cpp
@@ -11,7 +11,10 @@
namespace samplers {
SequentialSampler::SequentialSampler(size_t size) : size_(size) {}
-void SequentialSampler::reset() {
+void SequentialSampler::reset(optional<size_t> new_size) {
+ if (new_size.has_value()) {
+ size_ = *new_size;
+ }
index_ = 0;
}
diff --git a/torch/csrc/api/src/data/samplers/stream.cpp b/torch/csrc/api/src/data/samplers/stream.cpp
index 2ac1755..6972846 100644
--- a/torch/csrc/api/src/data/samplers/stream.cpp
+++ b/torch/csrc/api/src/data/samplers/stream.cpp
@@ -20,7 +20,10 @@
StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
-void StreamSampler::reset() {
+void StreamSampler::reset(optional<size_t> new_size) {
+ if (new_size.has_value()) {
+ epoch_size_ = *new_size;
+ }
examples_retrieved_so_far_ = 0;
}