Added support for async loop context
diff --git a/jinja2/asyncsupport.py b/jinja2/asyncsupport.py
index 0a19b6a..0c4fecf 100644
--- a/jinja2/asyncsupport.py
+++ b/jinja2/asyncsupport.py
@@ -4,6 +4,7 @@
from jinja2.utils import concat, internalcode, concat, Markup
from jinja2.environment import TemplateModule
+from jinja2.runtime import LoopContextBase, _last_iteration
async def concat_async(async_gen):
@@ -144,3 +145,45 @@
return
for item in iterable:
yield item
+
+
+class AsyncLoopContext(LoopContextBase):
+
+ def __init__(self, async_iterator, iterable, after, recurse=None, depth0=0):
+ self._async_iterator = async_iterator
+ LoopContextBase.__init__(self, iterable, recurse, depth0)
+ self._after = after
+
+ def __aiter__(self):
+ return AsyncLoopContextIterator(self)
+
+
+class AsyncLoopContextIterator(object):
+ __slots__ = ('context',)
+
+ def __init__(self, context):
+ self.context = context
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ ctx = self.context
+ ctx.index0 += 1
+ if ctx._after is _last_iteration:
+ raise StopAsyncIteration()
+ next_elem = ctx._after
+ try:
+ ctx._after = await ctx._async_iterator.__anext__()
+ except StopAsyncIteration:
+ ctx._after = _last_iteration
+ return next_elem, ctx
+
+
+async def make_async_loop_context(iterable, recurse=None, depth0=0):
+ async_iterator = auto_iter(iterable)
+ try:
+ after = await async_iterator.__anext__()
+ except StopAsyncIteration:
+ after = _last_iteration
+ return AsyncLoopContext(async_iterator, iterable, after, recurse, depth0)
diff --git a/jinja2/compiler.py b/jinja2/compiler.py
index f2714c6..667fec5 100644
--- a/jinja2/compiler.py
+++ b/jinja2/compiler.py
@@ -783,7 +783,8 @@
self.writeline('dummy = lambda *x: None')
if self.environment._async:
- self.writeline('from jinja2.asyncsupport import auto_await, auto_iter')
+ self.writeline('from jinja2.asyncsupport import auto_await, '
+ 'auto_iter, make_async_loop_context')
# if we want a deferred initialization we cannot move the
# environment into a local name
@@ -1132,7 +1133,13 @@
self.writeline(self.environment._async and 'async for ' or 'for ', node)
self.visit(node.target, loop_frame)
- self.write(extended_loop and ', l_loop in LoopContext(' or ' in ')
+ if extended_loop:
+ if self.environment._async:
+ self.write(', l_loop in await make_async_loop_context(')
+ else:
+ self.write(', l_loop in LoopContext(')
+ else:
+ self.write(' in ')
# if we have an extened loop and a node test, we filter in the
# "outer frame".
@@ -1158,10 +1165,10 @@
elif node.recursive:
self.write('reciter')
else:
- if self.environment._async:
+ if self.environment._async and not extended_loop:
self.write('auto_iter(')
self.visit(node.iter, loop_frame)
- if self.environment._async:
+ if self.environment._async and not extended_loop:
self.write(')')
if node.recursive:
diff --git a/jinja2/runtime.py b/jinja2/runtime.py
index 685a12d..e0df9b7 100644
--- a/jinja2/runtime.py
+++ b/jinja2/runtime.py
@@ -280,13 +280,14 @@
return rv
-class LoopContext(object):
+class LoopContextBase(object):
"""A loop context for dynamic iteration."""
+ _after = _last_iteration
+ _length = None
+
def __init__(self, iterable, recurse=None, depth0=0):
- self._iterator = iter(iterable)
self._recurse = recurse
- self._after = self._safe_next()
self.index0 = -1
self.depth0 = depth0
@@ -315,15 +316,6 @@
def __len__(self):
return self.length
- def __iter__(self):
- return LoopContextIterator(self)
-
- def _safe_next(self):
- try:
- return next(self._iterator)
- except StopIteration:
- return _last_iteration
-
@internalcode
def loop(self, iterable):
if self._recurse is None:
@@ -357,6 +349,23 @@
)
+class LoopContext(LoopContextBase):
+
+ def __init__(self, iterable, recurse=None, depth0=0):
+ self._iterator = iter(iterable)
+ LoopContextBase.__init__(self, iterable, recurse, depth0)
+ self._after = self._safe_next()
+
+ def __iter__(self):
+ return LoopContextIterator(self)
+
+ def _safe_next(self):
+ try:
+ return next(self._iterator)
+ except StopIteration:
+ return _last_iteration
+
+
@implements_iterator
class LoopContextIterator(object):
"""The iterator for a loop context."""
diff --git a/tests/test_async.py b/tests/test_async.py
index 44463d1..fff732b 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -97,7 +97,7 @@
@pytest.mark.skipif(not have_async_gen, reason='No async generators')
-def test_async_iteration_in_tmeplates():
+def test_async_iteration_in_templates():
t = Template('{% for x in rng %}{{ x }}{% endfor %}',
enable_async=True)
async def async_iterator():
@@ -107,6 +107,17 @@
assert rv == ['1', '2', '3']
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+def test_async_iteration_in_templates_extended():
+ t = Template('{% for x in rng %}{{ loop.index0 }}/{{ x }}{% endfor %}',
+ enable_async=True)
+ async def async_iterator():
+ for item in [1, 2, 3]:
+ yield item
+ rv = list(t.generate(rng=async_iterator()))
+ assert rv == ['0/1', '1/2', '2/3']
+
+
@pytest.fixture
def test_env_async():
env = Environment(loader=DictLoader(dict(