# blob: 89153399044a09bc64afbb2cc8fdcaf5d6b31942 [file] [edit]
import threading
import time
import unittest
from os import environ
from typing import Any
from cachetools import LRUCache, cached, cachedmethod
# Number of times the body of ``func`` has actually executed; ``func``
# increments it on every real (non-cached) run.
count = 0
@cached(
    cache=LRUCache(1),
    condition=threading.Condition(),
    info=True,
)
def func():
    """Slow zero-argument function memoized in a one-entry ``LRUCache``.

    Every real execution sleeps for one second, increments the module-level
    ``count`` and returns its new value, so ``count`` records how many calls
    actually ran the body instead of being served from the cache.
    """
    global count
    # Stay inside the computation for a while before touching ``count``;
    # presumably this gives concurrent callers time to arrive while the
    # value is still being computed.
    time.sleep(1.0)
    count += 1
    return count
@unittest.skipUnless(environ.get("THREADING_TESTS", False), "THREADING_TESTS not set")
class ThreadingTest(unittest.TestCase):
    """Stampede tests for ``cached`` and ``cachedmethod`` used with a condition.

    Each test calls one slow cached callable from ``NTHREADS`` threads at the
    same time and asserts that its body ran exactly once, with one recorded
    cache miss and ``NTHREADS - 1`` recorded hits.

    The whole class is skipped unless the ``THREADING_TESTS`` environment
    variable is set to a non-empty value (each test sleeps for a second).
    """

    # Number of threads that call the cached callable concurrently.
    NTHREADS = 10

    # One-entry cache and condition object handed to ``cachedmethod`` through
    # the lambdas on ``meth``.
    cache: LRUCache[Any, int] = LRUCache(1)
    cond = threading.Condition()

    # Number of times the body of ``meth`` has actually executed.
    count = 0

    @cachedmethod(
        cache=lambda self: self.cache, condition=lambda self: self.cond, info=True
    )
    def meth(self):
        """Slow cached method: sleep one second, bump ``count``, return 42."""
        time.sleep(1.0)
        # Reads the class-level default on first use, then rebinds ``count``
        # on the instance.
        self.count += 1
        return 42

    def _run_in_threads(self, target):
        """Call *target* from ``NTHREADS`` threads and wait for all of them.

        All threads are started before any is joined so that the calls
        overlap instead of running one after another.
        """
        threads = [threading.Thread(target=target) for _ in range(self.NTHREADS)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_cached_stampede(self):
        """Concurrent calls to ``func`` must execute its body only once."""
        self._run_in_threads(func)

        self.assertEqual(count, 1)
        info = func.cache_info()
        self.assertEqual(info.hits, self.NTHREADS - 1)
        self.assertEqual(info.misses, 1)

    def test_cachedmethod_stampede(self):
        """Concurrent calls to ``meth`` must execute its body only once."""
        self._run_in_threads(self.meth)

        self.assertEqual(self.count, 1)
        info = self.meth.cache_info()
        self.assertEqual(info.hits, self.NTHREADS - 1)
        self.assertEqual(info.misses, 1)