/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_MULTI_TRAINER_CACHE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_MULTI_TRAINER_CACHE_H_
#include <cstddef>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/logging_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace data {
// Sliding-window cache shared across concurrent trainers. Readers call `Get`
// to read elements they haven't read yet. After a trainer reads an element,
// the element remains in the cache and is shared with other trainers. This is
// useful for datasets involving expensive computation where multiple models
// are trained on the same data, for example, in hyperparameter tuning.
//
// The cache progresses when a trainer that has consumed all elements in the
// cache requests additional data. It has a bounded size. Elements are garbage
// collected when the cache becomes full. Consequently, trainers read from a
// sliding window through the dataset and may not read the full dataset.
//
// The `MultiTrainerCache` class is thread-safe.
//
// Example usage:
//
// // `InfiniteRange` returns 1, 2, 3, ... in the `GetNext` calls.
// class InfiniteRange : public CachableSequence<int64_t> {
// public:
// StatusOr<int64_t> GetNext() override {
// return next_++;
// }
//
// size_t GetElementSizeBytes(const int64_t& element) const override {
// return sizeof(element);
// }
//
// private:
// int64_t next_ = 1;
// };
//
// MultiTrainerCache<int64_t> cache(
// /*max_cache_size_bytes=*/10 * (size_t{1} << 30), // 10GB
// absl::make_unique<InfiniteRange>());
//
// std::shared_ptr<int64_t> next;
// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1")); // Returns 1
// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2")); // Returns 1
// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1")); // Returns 2
// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2")); // Returns 2
// To use the cache, the user needs to define a `CachableSequence` to generate
// an infinite sequence of data. It should implement a `GetNext` method to
// produce elements, and a `GetElementSizeBytes` method to estimate the element
// size in bytes.
template <class ElementType>
class CachableSequence {
public:
virtual ~CachableSequence() = default;
// Returns the next element to be cached.
virtual StatusOr<ElementType> GetNext() = 0;
// Returns the estimated size of the element in bytes.
virtual size_t GetElementSizeBytes(const ElementType&) const = 0;
};
// Sliding-window cache shared across concurrent trainers.
template <class ElementType>
class MultiTrainerCache {
public:
// Creates a `MultiTrainerCache` with `max_cache_size_bytes` of memory budget.
// The cache should be able to hold at least one element, i.e.:
// REQUIRES: `max_cache_size_bytes >= max(GetElementSizeBytes(*))`
explicit MultiTrainerCache(
size_t max_cache_size_bytes,
std::unique_ptr<CachableSequence<ElementType>> cachable_sequence);
virtual ~MultiTrainerCache() = default;
// Gets the next element for a trainer. A `trainer_id` identifies the trainer
// reading from the cache. A trainer reads the next element it hasn't read
// before. After a trainer reads data, the data is cached and reused by other
// trainers.
StatusOr<std::shared_ptr<const ElementType>> Get(
const std::string& trainer_id);
// Cancels the cache with `status` and notifies the readers. After cancelling,
// all `Get` calls will return `status`.
// REQUIRES: !status.ok()
void Cancel(Status status);
// Returns true if the cache has been cancelled.
bool IsCancelled() const;
private:
struct CacheQueryResult {
std::shared_ptr<const ElementType> element;
bool cache_hit;
};
// Returns the next element and metrics about this query.
StatusOr<CacheQueryResult> GetCacheQueryResult(const std::string& trainer_id);
// Returns true if the next element is ready for `trainer_id`. An element is
// ready if other trainers have read the data and the data remains in the
// cache. If the data is not ready, one of the trainers needs to extend the
// cache.
bool IsElementReady(const std::string& trainer_id);
// Returns the absolute element index within the dataset (not the index within
// the cached elements).
size_t GetElementIndex(const std::string& trainer_id);
// Returns the next element for `trainer_id`.
StatusOr<std::shared_ptr<const ElementType>> GetElement(
const std::string& trainer_id);
// Reads a new element and writes it into the cache.
Status ExtendCache();
// Frees old elements to keep the cache size below `max_cache_size_bytes_`.
// `new_element_size_bytes` is the size of the new element being inserted.
void FreeSpace(size_t new_element_size_bytes);
// Records the cache hit rate and cache size.
void RecordMetrics(const CacheQueryResult& result);
// Maximum cache size in bytes.
const size_t max_cache_size_bytes_;
// The element sequence over which the sliding window cache operates.
std::unique_ptr<CachableSequence<ElementType>> cachable_sequence_;
mutable mutex mu_;
mutable condition_variable cv_;
// If `status_` is non-OK, the cache is cancelled, and all method calls will
// return this status.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
// `cache_` stores the cached elements.
std::deque<std::shared_ptr<const ElementType>> cache_ TF_GUARDED_BY(mu_);
size_t cache_size_bytes_ TF_GUARDED_BY(mu_) = 0;
size_t cache_start_index_ TF_GUARDED_BY(mu_) = 0;
// True if one thread is extending the cache.
bool extending_cache_ TF_GUARDED_BY(mu_) = false;
// Maps trainer IDs to element indices. The indices are absolute indices
// within the dataset. The actual index to use with `cache_` would be
// `trainer_to_element_index_map_[trainer_id] - cache_start_index_`.
absl::flat_hash_map<std::string, size_t> trainer_to_element_index_map_
TF_GUARDED_BY(mu_);
};
template <class ElementType>
MultiTrainerCache<ElementType>::MultiTrainerCache(
size_t max_cache_size_bytes,
std::unique_ptr<CachableSequence<ElementType>> cachable_sequence)
: max_cache_size_bytes_(max_cache_size_bytes),
cachable_sequence_(std::move(cachable_sequence)) {
DCHECK_GT(max_cache_size_bytes, 0)
<< "MultiTrainerCache size must be greater than 0.";
VLOG(2) << "Initialized tf.data service multi-trainer cache with "
<< FormatBytes(max_cache_size_bytes) << " of memory.";
}
template <class ElementType>
StatusOr<std::shared_ptr<const ElementType>>
MultiTrainerCache<ElementType>::Get(const std::string& trainer_id)
TF_LOCKS_EXCLUDED(mu_) {
if (trainer_id.empty()) {
return errors::Internal(
"tf.data service multi-trainer cache trainer ID must be non-empty.");
}
TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id));
RecordMetrics(result);
return result.element;
}
template <class ElementType>
StatusOr<typename MultiTrainerCache<ElementType>::CacheQueryResult>
MultiTrainerCache<ElementType>::GetCacheQueryResult(
const std::string& trainer_id) {
bool should_extend_cache = false;
while (true) {
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(status_);
if (IsElementReady(trainer_id)) {
TF_ASSIGN_OR_RETURN(std::shared_ptr<const ElementType> element,
GetElement(trainer_id));
        return CacheQueryResult{element,
                                /*cache_hit=*/!should_extend_cache};
}
// Extends the cache or waits for another thread to extend the cache. When
// concurrent trainers wait for the next element, only one of them should
// extend the cache.
if (extending_cache_) {
should_extend_cache = false;
cv_.wait(l);
} else {
should_extend_cache = true;
extending_cache_ = true;
}
}
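    // Only the thread that set `extending_cache_` above extends the cache. It
    // calls `ExtendCache` outside the lock scope so that the potentially slow
    // `GetNext` call does not block readers of already-cached elements;
    // waiting threads are notified via `cv_` once the extension finishes.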
if (should_extend_cache) {
Status s = ExtendCache();
mutex_lock l(mu_);
extending_cache_ = false;
cv_.notify_all();
TF_RETURN_IF_ERROR(s);
}
}
}
template <class ElementType>
bool MultiTrainerCache<ElementType>::IsElementReady(
const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return GetElementIndex(trainer_id) < cache_start_index_ + cache_.size();
}
template <class ElementType>
StatusOr<std::shared_ptr<const ElementType>>
MultiTrainerCache<ElementType>::GetElement(const std::string& trainer_id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
size_t element_index = GetElementIndex(trainer_id);
if (element_index >= std::numeric_limits<size_t>::max()) {
return errors::Internal(
"tf.data service caching element index exceeds integer limit. Got ",
element_index);
}
std::shared_ptr<const ElementType> result =
cache_[element_index - cache_start_index_];
trainer_to_element_index_map_[trainer_id] = element_index + 1;
return result;
}
template <class ElementType>
size_t MultiTrainerCache<ElementType>::GetElementIndex(
const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
size_t element_index = trainer_to_element_index_map_[trainer_id];
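  // If the trainer has fallen behind the sliding window (its next element has
  // already been garbage collected), it resumes from the oldest element still
  // in the cache.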
if (element_index < cache_start_index_) {
element_index = cache_start_index_;
}
return element_index;
}
template <class ElementType>
Status MultiTrainerCache<ElementType>::ExtendCache() TF_LOCKS_EXCLUDED(mu_) {
TF_ASSIGN_OR_RETURN(ElementType element, cachable_sequence_->GetNext());
size_t new_element_size_bytes =
cachable_sequence_->GetElementSizeBytes(element);
if (new_element_size_bytes > max_cache_size_bytes_) {
return errors::InvalidArgument(
"tf.data service element size is larger than cache size in bytes. Got ",
"element size: ", new_element_size_bytes,
" and cache size: ", max_cache_size_bytes_);
}
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(status_);
FreeSpace(new_element_size_bytes);
cache_.push_back(std::make_shared<ElementType>(std::move(element)));
cache_size_bytes_ += new_element_size_bytes;
return Status::OK();
}
template <class ElementType>
void MultiTrainerCache<ElementType>::FreeSpace(size_t new_element_size_bytes)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
size_t num_elements_discarded = 0;
while (!cache_.empty() &&
cache_size_bytes_ + new_element_size_bytes > max_cache_size_bytes_) {
size_t free_bytes =
cachable_sequence_->GetElementSizeBytes(*cache_.front());
cache_.pop_front();
cache_size_bytes_ -= free_bytes;
++cache_start_index_;
++num_elements_discarded;
}
VLOG(3) << "Freed " << num_elements_discarded << " element(s) from "
<< "tf.data service multi-trainer cache. Memory usage: "
<< FormatBytes(cache_size_bytes_) << ".";
}
template <class ElementType>
void MultiTrainerCache<ElementType>::Cancel(Status status)
TF_LOCKS_EXCLUDED(mu_) {
DCHECK(!status.ok())
<< "Cancelling MultiTrainerCache requires a non-OK status. Got "
<< status;
VLOG(2) << "Cancel tf.data service multi-trainer cache with status "
<< status;
mutex_lock l(mu_);
status_ = std::move(status);
cv_.notify_all();
}
template <class ElementType>
bool MultiTrainerCache<ElementType>::IsCancelled() const
TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
return !status_.ok();
}
template <class ElementType>
void MultiTrainerCache<ElementType>::RecordMetrics(
const CacheQueryResult& result) {
metrics::RecordTFDataServiceMultiTrainerCacheQuery(result.cache_hit);
size_t cache_size_bytes = 0;
{
mutex_lock l(mu_);
cache_size_bytes = cache_size_bytes_;
}
metrics::RecordTFDataServiceMultiTrainerCacheSizeBytes(cache_size_bytes);
}
} // namespace data
} // namespace tensorflow
#endif  // TENSORFLOW_CORE_DATA_SERVICE_MULTI_TRAINER_CACHE_H_