[CUDNN][CUDNN V8 API] LRU Cache for cuDNN frontend `ExecutionPlan` (#104369)
Adds LRU functionality to the cuDNN frontend `ExecutionPlan` cache to address high memory usage as observed in #98688, #104122 via the `TORCH_CUDNN_V8_LRU_CACHE_LIMIT` environment variable. By default this limit is set to 10000, which corresponds to about 2GiB of host memory usage as observed empirically. Note that we are still following up with cuDNN to see if the size of an `ExecutionPlan` can be reduced, as it appears to currently be around 200KiB (!!) for a single plan.
This implementation is a bit heavy on the internal asserts for now as it's a bit difficult to directly test the state of the cache without instrumenting it explicitly in tests. Once we are confident that the implementation is stable, we can remove the asserts.
CC @malfet who @ptrblck mentioned may have also been looking into this
CC @colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104369
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp
index 600aa02..434b2cd 100644
--- a/aten/src/ATen/native/cudnn/Conv_v8.cpp
+++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp
@@ -32,6 +32,7 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <unordered_map>
+#include <list>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -180,23 +181,95 @@
}
};
+static int getLRUCacheLimit() {
+ constexpr int DEFAULT_LIMIT = 10000; // roughly corresponds to 2GiB assuming 200KiB per ExecutionPlan
+ // 0 is used to indicate no limit
+ // negative values are used to indicate no caching
+ static int limit = [&] {
+ const char * val = getenv("TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT");
+ if (!val) {
+ return DEFAULT_LIMIT;
+ }
+ try {
+ return std::stoi(val);
+ } catch(std::invalid_argument const& e) {
+ TORCH_WARN("invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
+ " using default LRU cache limit of ", DEFAULT_LIMIT, " entries.");
+ } catch(std::out_of_range const& e) {
+ TORCH_WARN("invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
+ " using default LRU cache limit of ", DEFAULT_LIMIT, " entries.");
+ }
+ return DEFAULT_LIMIT;
+ } ();
+ return limit;
+}
+
template <typename T, typename KeyType>
struct BenchmarkCache {
-std::unordered_map<KeyType, cudnn_frontend::ExecutionPlan, ParamsWrapperHash<KeyType>> engine_cache;
+std::list<KeyType> engine_cache_order;
+std::unordered_map<KeyType, std::pair<cudnn_frontend::ExecutionPlan, typename std::list<KeyType>::iterator>, ParamsWrapperHash<KeyType>> engine_cache;
// no mutexes here as caches are now thread local for v8, can also return a pointer
// to the Execution Plan if we know it will not be invalidated by another thread
cudnn_frontend::ExecutionPlan* find(const KeyType& key) {
+ const int lru_cache_limit = getLRUCacheLimit();
+ if (lru_cache_limit < 0) {
+ return nullptr;
+ }
auto it = engine_cache.find(key);
if (it == engine_cache.end()) {
return nullptr;
}
- return &(it->second);
+ if (lru_cache_limit) {
+ TORCH_INTERNAL_ASSERT(*(it->second.second) == key, "CUDNN V8 LRU Cache Corrupted (found key mismatches list). Please report a bug to PyTorch.");
+ auto engine_cache_order_size = engine_cache_order.size();
+ auto engine_cache_size = engine_cache.size();
+ TORCH_INTERNAL_ASSERT(engine_cache_order_size == engine_cache_size, "CUDNN V8 LRU Cache Corrupted (found list vs. map size mismatch). Please report a bug to PyTorch.");
+ // update most recently accessed
+ auto plan = it->second.first;
+ engine_cache_order.erase(it->second.second);
+ engine_cache_order.push_back(key);
+ engine_cache.erase(key);
+ engine_cache.emplace(key, std::make_pair(plan, --engine_cache_order.end()));
+ // iterator was invalidated by the erase, so we grab it again
+ it = engine_cache.find(key);
+ TORCH_INTERNAL_ASSERT(it->first == *(it->second.second), "CUDNN V8 LRU Cache Corrupted (refresh list vs. map key mismatch). Please report a bug to PyTorch.");
+ TORCH_INTERNAL_ASSERT((long) engine_cache_order.size() <= lru_cache_limit, "CUDNN V8 LRU Cache Corrupted (refresh size exceeds limit: ", lru_cache_limit, " please report a bug to PyTorch.");
+ TORCH_INTERNAL_ASSERT(engine_cache_order.size() == engine_cache_order_size, "CUDNN V8 LRU Cache Corrupted (list size unexpectedly changed). Please report a bug to PyTorch.");
+ TORCH_INTERNAL_ASSERT(engine_cache.size() == engine_cache.size(), "CUDNN V8 LRU Cache Corrupted (cache size unexpectedly changed). Please report a bug to PyTorch.");
+ }
+ return &(it->second.first);
}
void update(const KeyType& key, T& results) {
- engine_cache.erase(key);
- engine_cache.emplace(key, std::move(results));
+ int lru_cache_limit = getLRUCacheLimit();
+ if (lru_cache_limit < 0) {
+ return;
+ } else if (lru_cache_limit) {
+ auto it = engine_cache.find(key);
+ if (it == engine_cache.end()) {
+ auto engine_cache_order_size = engine_cache_order.size();
+ auto engine_cache_size = engine_cache.size();
+ TORCH_INTERNAL_ASSERT(engine_cache_order_size == engine_cache_size, "CUDNN V8 LRU Cache Corrupted (list vs. map size mismatch). Please report a bug to PyTorch.");
+ if ((long) engine_cache_order_size >= lru_cache_limit) {
+ // need to perform eviction
+ TORCH_INTERNAL_ASSERT(engine_cache.find(engine_cache_order.front()) != engine_cache.end(), "CUDNN V8 LRU Cache Corrupted (eviction key not in map). Please report a bug to PyTorch.");
+ engine_cache.erase(engine_cache_order.front());
+ engine_cache_order.pop_front();
+ }
+ } else {
+ TORCH_INTERNAL_ASSERT(*(it->second.second) == key, "CUDNN V8 LRU Cache Corrupted (list iterator key mismatch). Please report a bug to PyTorch.");
+ engine_cache_order.erase(it->second.second);
+ }
+ engine_cache_order.push_back(key);
+ engine_cache.erase(key);
+ engine_cache.emplace(key, std::make_pair(results, --engine_cache_order.end()));
+ TORCH_INTERNAL_ASSERT(engine_cache.find(key)->first == *(engine_cache.find(key)->second.second), "CUDNN V8 LRU Cache Corrupted (updated list vs. map key mismatch). Please report a bug to PyTorch.");
+ TORCH_INTERNAL_ASSERT((long) engine_cache_order.size() <= lru_cache_limit, "CUDNN V8 LRU Cache Corrupted (updated size exceeds limit: ", lru_cache_limit, " please report a bug to PyTorch.");
+ } else {
+ engine_cache.erase(key);
+ engine_cache.emplace(key, std::make_pair(results, engine_cache_order.end())); // dummy iterator
+ }
}
};