blob: 358ef0a622f7d93e3ca07334c575fb2922810170 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
#include "absl/container/node_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
// A simple LRU cache. Not thread-safe.
// Value must be copyable and moveable. The intent is that Value is typically
// a smart-pointer type.
template <typename Key, typename Value,
typename Hash = typename absl::node_hash_map<Key, Value>::hasher,
typename Eq = typename absl::node_hash_map<Key, Value>::key_equal>
class LRUCache {
private:
struct LRUListEntry {
LRUListEntry* next;
LRUListEntry* prev;
};
public:
// Multiple LRUCaches can share a LRU list, meaning that the capacity and
// eviction policy is shared. The user provides an LRU list
// to the cache constructor, and must ensure that it remains alive as long
// as the cache does.
class LRUList {
public:
explicit LRUList(int capacity) : capacity_(capacity) {
head_.next = &head_;
head_.prev = &head_;
}
~LRUList() {
CHECK(head_.next == &head_);
CHECK(head_.prev == &head_);
}
LRUList(const LRUList&) = delete;
LRUList(LRUList&&) = delete;
LRUList& operator=(const LRUList&) = delete;
LRUList& operator=(LRUList&&) = delete;
int Capacity() const { return capacity_; }
int Size() const { return size_; }
void Clear();
private:
friend class LRUCache;
int capacity_;
int size_ = 0;
// Root of a circular doubly-linked list of entries, in order from least
// recently used to most recently used. An "empty" cache always contains
// this element in the LRU list.
LRUListEntry head_;
};
explicit LRUCache(LRUList* lru_list) : lru_list_(lru_list) {}
~LRUCache();
LRUCache(const LRUCache&) = delete;
LRUCache(LRUCache&&) = delete;
LRUCache& operator=(const LRUCache&) = delete;
LRUCache& operator=(LRUCache&&) = delete;
// Returns the `value` associated with `key`. Creates a value with `factory`
// and inserts it if absent.
Value GetOrCreateIfAbsent(const Key& key,
const std::function<Value(const Key&)>& factory);
// Removes all entries from the cache.
void Clear();
int Size() const { return entries_.size(); }
int Capacity() const { return lru_list_->Capacity(); }
private:
LRUList* lru_list_;
struct Entry : public LRUListEntry {
Entry() = default;
// Pointer to the key in `entries_`. absl::node_hash_map<> promises
// pointer stability for keys.
const Key* key;
LRUCache* container;
std::optional<Value> value;
};
// We use `node_hash_map` because we want to guarantee pointer stability for
// keys and values.
absl::node_hash_map<Key, Entry, Hash, Eq> entries_;
};
template <typename Key, typename Value, typename Hash, typename Eq>
void LRUCache<Key, Value, Hash, Eq>::LRUList::Clear() {
while (head_.next != &head_) {
static_cast<Entry*>(head_.next)->container->Clear();
}
size_ = 0;
}
template <typename Key, typename Value, typename Hash, typename Eq>
void LRUCache<Key, Value, Hash, Eq>::Clear() {
for (auto& e : entries_) {
LRUListEntry* l = &e.second;
l->next->prev = l->prev;
l->prev->next = l->next;
--lru_list_->size_;
}
entries_.clear();
}
template <typename Key, typename Value, typename Hash, typename Eq>
LRUCache<Key, Value, Hash, Eq>::~LRUCache() {
Clear();
}
template <typename Key, typename Value, typename Hash, typename Eq>
Value LRUCache<Key, Value, Hash, Eq>::GetOrCreateIfAbsent(
const Key& key, const std::function<Value(const Key&)>& factory) {
typename absl::node_hash_map<Key, Entry, Hash, Eq>::iterator it;
bool inserted;
std::tie(it, inserted) = entries_.try_emplace(key);
Entry& entry = it->second;
if (inserted) {
entry.key = &it->first;
entry.value = factory(*entry.key);
++lru_list_->size_;
} else {
// Removes the entry from the LRU list, in preparation for adding it
// to the back of the list.
entry.prev->next = entry.next;
entry.next->prev = entry.prev;
}
// (Re-)adds entry to the back of the LRU list. Since it is now the
// most recently used element, it goes at the back.
LRUListEntry& lru_head = lru_list_->head_;
entry.container = this;
entry.prev = lru_head.prev;
entry.next = &lru_head;
lru_head.prev->next = &entry;
lru_head.prev = &entry;
Value v = *entry.value;
// Evict an LRU entry if we are over capacity.
if (lru_list_->size_ > lru_list_->capacity_) {
Entry* to_remove = static_cast<Entry*>(lru_head.next);
to_remove->next->prev = &lru_head;
lru_head.next = to_remove->next;
to_remove->container->entries_.erase(*to_remove->key);
--lru_list_->size_;
}
return v;
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_