| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/core/data/service/multi_trainer_cache.h" |
| |
| #include <cstdint> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/time/time.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_testutil.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/lib/monitoring/cell_reader.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/random.h" |
| #include "tensorflow/core/platform/status.h" |
| #include "tensorflow/core/platform/status_matchers.h" |
| #include "tensorflow/core/platform/statusor.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| |
| namespace tensorflow { |
| namespace data { |
| namespace { |
| |
| using ::tensorflow::monitoring::testing::CellReader; |
| using ::tensorflow::testing::IsOkAndHolds; |
| using ::tensorflow::testing::StatusIs; |
| using ::testing::Gt; |
| using ::testing::HasSubstr; |
| using ::testing::Pointee; |
| using ::testing::UnorderedElementsAreArray; |
| |
| class InfiniteRange : public CachableSequence<int64_t> { |
| public: |
| StatusOr<int64_t> GetNext() override { return next_++; } |
| size_t GetElementSizeBytes(const int64_t& element) const override { |
| return sizeof(element); |
| } |
| |
| private: |
| // No need to guard this variable because only one thread can write the cache. |
| int64_t next_ = 0; |
| }; |
| |
| class TensorDataset : public CachableSequence<Tensor> { |
| public: |
| StatusOr<Tensor> GetNext() override { return Tensor("Test Tensor"); } |
| size_t GetElementSizeBytes(const Tensor& element) const override { |
| return element.TotalBytes(); |
| } |
| }; |
| |
| class SlowDataset : public CachableSequence<Tensor> { |
| public: |
| explicit SlowDataset(absl::Duration delay) : delay_(delay) {} |
| |
| StatusOr<Tensor> GetNext() override { |
| Env::Default()->SleepForMicroseconds(absl::ToInt64Microseconds(delay_)); |
| return Tensor("Test Tensor"); |
| } |
| |
| size_t GetElementSizeBytes(const Tensor& element) const override { |
| return element.TotalBytes(); |
| } |
| |
| private: |
| absl::Duration delay_; |
| }; |
| |
| template <class T> |
| class ElementOrErrorDataset : public CachableSequence<T> { |
| public: |
| explicit ElementOrErrorDataset(const std::vector<StatusOr<T>>& elements) |
| : elements_(elements) {} |
| |
| StatusOr<T> GetNext() override { |
| if (next_ >= elements_.size()) { |
| return errors::OutOfRange("Out of range."); |
| } |
| |
| return elements_[next_++]; |
| } |
| |
| size_t GetElementSizeBytes(const T& element) const override { |
| return sizeof(element); |
| } |
| |
| private: |
| const std::vector<StatusOr<T>> elements_; |
| int64_t next_ = 0; |
| }; |
| |
| template <> |
| size_t ElementOrErrorDataset<std::string>::GetElementSizeBytes( |
| const std::string& element) const { |
| return element.size(); |
| } |
| |
| template <> |
| size_t ElementOrErrorDataset<Tensor>::GetElementSizeBytes( |
| const Tensor& element) const { |
| return element.TotalBytes(); |
| } |
| |
| std::vector<int64_t> GetRange(const size_t range) { |
| std::vector<int64_t> result; |
| for (int64_t i = 0; i < range; ++i) { |
| result.push_back(i); |
| } |
| return result; |
| } |
| |
| bool SequenceIsIncreasing(const std::vector<int64_t> sequence) { |
| for (int i = 1; i < sequence.size(); ++i) { |
| if (sequence[i - 1] > sequence[i - 1]) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| TEST(MultiTrainerCacheTest, GetFromOneTrainer) { |
| const size_t num_elements = 10; |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/1024, absl::make_unique<InfiniteRange>()); |
| for (size_t i = 0; i < num_elements; ++i) { |
| EXPECT_THAT(cache.Get("Trainer ID"), IsOkAndHolds(Pointee(i))); |
| } |
| } |
| |
| TEST(MultiTrainerCacheTest, GetFromMultipleTrainers) { |
| const size_t num_elements = 10; |
| const size_t num_trainers = 10; |
| |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/1024, absl::make_unique<InfiniteRange>()); |
| for (size_t i = 0; i < num_elements; ++i) { |
| // All the readers get the same element in one step. |
| for (size_t j = 0; j < num_trainers; ++j) { |
| const std::string trainer_id = absl::StrCat("Trainer ", j); |
| EXPECT_THAT(cache.Get(trainer_id), IsOkAndHolds(Pointee(i))); |
| } |
| } |
| } |
| |
| TEST(MultiTrainerCacheTest, SlowTrainersSkipData) { |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/5 * sizeof(int64_t), |
| absl::make_unique<InfiniteRange>()); |
| EXPECT_THAT(cache.Get("Fast trainer 1"), IsOkAndHolds(Pointee(0))); |
| EXPECT_THAT(cache.Get("Fast trainer 2"), IsOkAndHolds(Pointee(0))); |
| EXPECT_THAT(cache.Get("Slow trainer 1"), IsOkAndHolds(Pointee(0))); |
| EXPECT_THAT(cache.Get("Slow trainer 2"), IsOkAndHolds(Pointee(0))); |
| |
| for (int i = 1; i < 20; ++i) { |
| EXPECT_THAT(cache.Get("Fast trainer 1"), IsOkAndHolds(Pointee(i))); |
| EXPECT_THAT(cache.Get("Fast trainer 2"), IsOkAndHolds(Pointee(i))); |
| } |
| |
| // When 19 is cached, 14 must have been discarded. |
| EXPECT_THAT(cache.Get("Slow trainer 1"), IsOkAndHolds(Pointee(Gt(14)))); |
| EXPECT_THAT(cache.Get("Slow trainer 2"), IsOkAndHolds(Pointee(Gt(14)))); |
| |
| for (int i = 20; i < 100; ++i) { |
| EXPECT_THAT(cache.Get("Fast trainer 1"), IsOkAndHolds(Pointee(i))); |
| EXPECT_THAT(cache.Get("Fast trainer 2"), IsOkAndHolds(Pointee(i))); |
| } |
| |
| // When 99 is cached, 94 must have been discarded. |
| EXPECT_THAT(cache.Get("Slow trainer 1"), IsOkAndHolds(Pointee(Gt(94)))); |
| EXPECT_THAT(cache.Get("Slow trainer 2"), IsOkAndHolds(Pointee(Gt(94)))); |
| } |
| |
| TEST(MultiTrainerCacheTest, NewTrainersStartLate) { |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/5 * sizeof(int64_t), |
| absl::make_unique<InfiniteRange>()); |
| for (int i = 0; i < 100; ++i) { |
| EXPECT_THAT(cache.Get("Old trainer"), IsOkAndHolds(Pointee(i))); |
| } |
| |
| // New trainers start to read after the first trainer has finished. |
| for (int j = 0; j < 100; ++j) { |
| EXPECT_THAT(cache.Get(absl::StrCat("New trainer ", j)), |
| IsOkAndHolds(Pointee(Gt(94)))); |
| } |
| } |
| |
| TEST(MultiTrainerCacheTest, AlternateTrainerExtendsCache) { |
| // The cache size is smaller than one int64_t. |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/sizeof(int64_t), |
| absl::make_unique<InfiniteRange>()); |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(0))); |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(1))); |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(2))); |
| |
| // When 2 is cached, 0 must have been discarded. |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(0)))); |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(1)))); |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(2)))); |
| |
| // When 3 is cached, 1 must have been discarded. |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(Gt(1)))); |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(Gt(2)))); |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(Gt(3)))); |
| |
| // When 4 is cached, 2 must have been discarded. |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(2)))); |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(3)))); |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(Gt(4)))); |
| |
| // When 5 is cached, 3 must have been discarded. |
| EXPECT_THAT(cache.Get("Trainer 3"), IsOkAndHolds(Pointee(Gt(3)))); |
| EXPECT_THAT(cache.Get("Trainer 3"), IsOkAndHolds(Pointee(Gt(4)))); |
| EXPECT_THAT(cache.Get("Trainer 3"), IsOkAndHolds(Pointee(Gt(5)))); |
| } |
| |
| TEST(MultiTrainerCacheTest, CacheHitMetrics) { |
| CellReader<int64_t> cell_reader( |
| "/tensorflow/data/service/multi_trainer_cache_queries"); |
| EXPECT_EQ(cell_reader.Delta("true"), 0); |
| EXPECT_EQ(cell_reader.Delta("false"), 0); |
| EXPECT_EQ(cell_reader.Read("true"), 0); |
| EXPECT_EQ(cell_reader.Read("false"), 0); |
| |
| const size_t num_elements = 10; |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/1024, absl::make_unique<InfiniteRange>()); |
| for (size_t i = 0; i < num_elements; ++i) { |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(i))); |
| } |
| EXPECT_EQ(cell_reader.Delta("true"), 0); |
| EXPECT_EQ(cell_reader.Delta("false"), 10); |
| EXPECT_EQ(cell_reader.Read("true"), 0); |
| EXPECT_EQ(cell_reader.Read("false"), 10); |
| |
| for (size_t i = 0; i < num_elements; ++i) { |
| EXPECT_THAT(cache.Get("Trainer 2"), IsOkAndHolds(Pointee(i))); |
| } |
| EXPECT_EQ(cell_reader.Delta("true"), 10); |
| EXPECT_EQ(cell_reader.Delta("false"), 0); |
| EXPECT_EQ(cell_reader.Read("true"), 10); |
| EXPECT_EQ(cell_reader.Read("false"), 10); |
| } |
| |
| TEST(MultiTrainerCacheTest, CacheSizeMetrics) { |
| CellReader<int64_t> cell_reader( |
| "/tensorflow/data/service/multi_trainer_cache_size_bytes"); |
| |
| const size_t num_elements = 5; |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/num_elements * sizeof(int64_t), |
| std::make_unique<InfiniteRange>()); |
| |
| for (size_t i = 0; i < num_elements; ++i) { |
| EXPECT_THAT(cache.Get("Trainer 1"), IsOkAndHolds(Pointee(i))); |
| EXPECT_EQ(cell_reader.Read(), (i + 1) * sizeof(int64_t)); |
| } |
| |
| // The cache size does not increase after reaching `num_elements`. |
| for (size_t i = 0; i < 100; ++i) { |
| EXPECT_THAT(cache.Get("Trainer 1"), |
| IsOkAndHolds(Pointee(num_elements + i))); |
| EXPECT_EQ(cell_reader.Read(), 5 * sizeof(int64_t)); |
| } |
| } |
| |
| TEST(MultiTrainerCacheTest, ConcurrentReaders) { |
| size_t num_trainers = 10; |
| size_t num_elements_to_read = 200; |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/3 * sizeof(int64_t), |
| absl::make_unique<InfiniteRange>()); |
| |
| std::vector<std::vector<int64_t>> results; |
| std::vector<std::unique_ptr<Thread>> reader_threads; |
| results.reserve(num_trainers); |
| for (size_t i = 0; i < num_trainers; ++i) { |
| results.emplace_back(); |
| std::vector<int64_t>& result = results.back(); |
| reader_threads.push_back(absl::WrapUnique(Env::Default()->StartThread( |
| /*thread_options=*/{}, /*name=*/absl::StrCat("Trainer_", i), |
| [&cache, num_elements_to_read, &result]() { |
| for (size_t i = 0; i < num_elements_to_read; ++i) { |
| // Randomly slows down some trainers. |
| if (random::New64() % 5 == 0) { |
| Env::Default()->SleepForMicroseconds(2000); |
| } |
| TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const int64_t> next, |
| cache.Get(absl::StrCat("Trainer_", i))); |
| result.push_back(*next); |
| } |
| }))); |
| } |
| reader_threads.clear(); |
| |
| // Verifies all trainers can read `num_elements_to_read` elements. |
| EXPECT_EQ(results.size(), num_trainers); |
| for (const std::vector<int64_t>& result : results) { |
| EXPECT_EQ(result.size(), num_elements_to_read); |
| EXPECT_TRUE(SequenceIsIncreasing(result)); |
| } |
| } |
| |
| TEST(MultiTrainerCacheTest, ConcurrentReadersFromOneTrainer) { |
| size_t num_trainers = 10; |
| size_t num_elements_to_read = 100; |
| MultiTrainerCache<int64_t> cache( |
| /*max_cache_size_bytes=*/3 * sizeof(int64_t), |
| absl::make_unique<InfiniteRange>()); |
| |
| mutex mu; |
| std::vector<int64_t> results; // Guarded by `mu`. |
| std::vector<std::unique_ptr<Thread>> reader_threads; |
| for (size_t i = 0; i < num_trainers; ++i) { |
| reader_threads.push_back(absl::WrapUnique(Env::Default()->StartThread( |
| /*thread_options=*/{}, /*name=*/absl::StrCat("Thread_", i), |
| [&cache, num_elements_to_read, &results, &mu]() { |
| for (size_t i = 0; i < num_elements_to_read; ++i) { |
| // Randomly slows down some trainers. |
| if (random::New64() % 5 == 0) { |
| Env::Default()->SleepForMicroseconds(1000); |
| } |
| TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const int64_t> next, |
| cache.Get("Trainer ID")); |
| mutex_lock l(mu); |
| results.push_back(*next); |
| } |
| }))); |
| } |
| reader_threads.clear(); |
| |
| // Verifies the readers have read all elements because they have the same |
| // trainer ID. |
| EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(1000))); |
| } |
| |
| TEST(MultiTrainerCacheTest, Cancel) { |
| size_t num_trainers = 10; |
| MultiTrainerCache<Tensor> cache( |
| /*max_cache_size_bytes=*/1000, absl::make_unique<TensorDataset>()); |
| EXPECT_FALSE(cache.IsCancelled()); |
| |
| mutex mu; |
| Status status; // Guarded by `mu`. |
| std::vector<std::unique_ptr<Thread>> reader_threads; |
| for (size_t i = 0; i < num_trainers; ++i) { |
| reader_threads.push_back(absl::WrapUnique(Env::Default()->StartThread( |
| /*thread_options=*/{}, /*name=*/absl::StrCat("Trainer_", i), |
| [&cache, &status, &mu]() { |
| for (int j = 0; true; ++j) { |
| StatusOr<std::shared_ptr<const Tensor>> tensor = |
| cache.Get(absl::StrCat("Trainer_", j % 1000)); |
| { |
| mutex_lock l(mu); |
| status = tensor.status(); |
| } |
| if (!tensor.status().ok()) { |
| return; |
| } |
| test::ExpectEqual(*tensor.ValueOrDie(), Tensor("Test Tensor")); |
| } |
| }))); |
| } |
| |
| Env::Default()->SleepForMicroseconds(1000000); |
| cache.Cancel(errors::Cancelled("Cancelled")); |
| reader_threads.clear(); |
| |
| mutex_lock l(mu); |
| EXPECT_THAT(status, StatusIs(error::CANCELLED)); |
| EXPECT_THAT(cache.Get("New trainer"), StatusIs(error::CANCELLED)); |
| EXPECT_TRUE(cache.IsCancelled()); |
| } |
| |
| TEST(MultiTrainerCacheTest, Errors) { |
| auto elements = absl::make_unique<ElementOrErrorDataset<std::string>>( |
| std::vector<StatusOr<std::string>>{ |
| std::string("First element"), |
| errors::Cancelled("Cancelled"), |
| std::string("Second element"), |
| errors::InvalidArgument("InvalidArgument"), |
| std::string("Third element"), |
| errors::Unavailable("Unavailable"), |
| }); |
| MultiTrainerCache<std::string> cache( |
| /*max_cache_size_bytes=*/1000, std::move(elements)); |
| |
| EXPECT_THAT(cache.Get("Trainer ID"), |
| IsOkAndHolds(Pointee(std::string("First element")))); |
| EXPECT_THAT(cache.Get("Trainer ID"), StatusIs(error::CANCELLED)); |
| EXPECT_THAT(cache.Get("Trainer ID"), |
| IsOkAndHolds(Pointee(std::string("Second element")))); |
| EXPECT_THAT(cache.Get("Trainer ID"), StatusIs(error::INVALID_ARGUMENT)); |
| EXPECT_THAT(cache.Get("Trainer ID"), |
| IsOkAndHolds(Pointee(std::string("Third element")))); |
| EXPECT_THAT(cache.Get("Trainer ID"), StatusIs(error::UNAVAILABLE)); |
| |
| // Errors are not stored in the cache. |
| EXPECT_THAT(cache.Get("New Trainer"), |
| IsOkAndHolds(Pointee(std::string("First element")))); |
| EXPECT_THAT(cache.Get("New Trainer"), |
| IsOkAndHolds(Pointee(std::string("Second element")))); |
| EXPECT_THAT(cache.Get("New Trainer"), |
| IsOkAndHolds(Pointee(std::string("Third element")))); |
| } |
| |
| TEST(MultiTrainerCacheTest, CacheSizeIsTooSmall) { |
| // The cache size is smaller than one int64_t. |
| MultiTrainerCache<Tensor> cache( |
| /*max_cache_size_bytes=*/1, absl::make_unique<TensorDataset>()); |
| EXPECT_THAT(cache.Get("Trainer ID"), |
| StatusIs(error::INVALID_ARGUMENT, |
| HasSubstr("tf.data service element size is larger than " |
| "cache size in bytes."))); |
| } |
| |
| TEST(MultiTrainerCacheTest, TrainerIDMustBeNonEmpty) { |
| MultiTrainerCache<Tensor> cache( |
| /*max_cache_size_bytes=*/1000, absl::make_unique<TensorDataset>()); |
| EXPECT_THAT( |
| cache.Get(""), |
| StatusIs( |
| error::INTERNAL, |
| "tf.data service multi-trainer cache trainer ID must be non-empty.")); |
| } |
| |
| } // namespace |
| } // namespace data |
| } // namespace tensorflow |