Skip to content

Commit

Permalink
improv: Use single-lookahead for regen consumption
Browse files Browse the repository at this point in the history
This change introduces a helper method which abstracts the logic of
consuming items one by one in `regen` and also introduces a single
lookahead to ensure that the `__done` property gets set even if the
regen is not fully iterated, fixing an edge case where a repeatable
iterator would get doubled when used as a base for a `regen` instance.
  • Loading branch information
maybe-sybr authored and auvipy committed Jun 16, 2021
1 parent 82f76d9 commit d667f1f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 18 deletions.
50 changes: 33 additions & 17 deletions celery/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def __init__(self, it):
# UserList creates a new list and sets .data, so we don't
# want to call init here.
self.__it = it
self.__index = 0
self.__consumed = []
self.__done = False

Expand All @@ -205,28 +204,45 @@ def __reduce__(self):
def __length_hint__(self):
return self.__it.__length_hint__()

def __lookahead_consume(self, limit=None):
if not self.__done and (limit is None or limit > 0):
it = iter(self.__it)
try:
now = next(it)
except StopIteration:
return
self.__consumed.append(now)
# Maintain a single look-ahead to ensure we set `__done` when the
# underlying iterator gets exhausted
while not self.__done:
try:
next_ = next(it)
self.__consumed.append(next_)
except StopIteration:
self.__done = True
break
finally:
yield now
now = next_
# We can break out when `limit` is exhausted
if limit is not None:
limit -= 1
if limit <= 0:
break

def __iter__(self):
yield from self.__consumed
if not self.__done:
for x in self.__it:
self.__consumed.append(x)
yield x
self.__done = True
yield from self.__lookahead_consume()

def __getitem__(self, index):
if index < 0:
return self.data[index]
try:
return self.__consumed[index]
except IndexError:
it = iter(self)
try:
for _ in range(self.__index, index + 1):
next(it)
except StopIteration:
raise IndexError(index)
else:
return self.__consumed[index]
# Consume elements up to the desired index prior to attempting to
# access it from within `__consumed`
consume_count = index - len(self.__consumed) + 1
for _ in self.__lookahead_consume(limit=consume_count):
pass
return self.__consumed[index]

def __bool__(self):
if len(self.__consumed):
Expand Down
6 changes: 5 additions & 1 deletion t/unit/tasks/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,11 +978,15 @@ def build_generator():
yield self.add.s(1, 1)
self.second_item_returned = True
yield self.add.s(2, 2)
raise pytest.fail("This should never be reached")

self.second_item_returned = False
c = chord(build_generator(), self.add.s(3))
c.app
assert not self.second_item_returned
# The second task gets returned due to lookahead in `regen()`
assert self.second_item_returned
# Access it again to make sure the generator is not further evaluated
c.app

def test_reverse(self):
x = chord([self.add.s(2, 2), self.add.s(4, 4)], body=self.mul.s(4))
Expand Down
68 changes: 68 additions & 0 deletions t/unit/utils/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections

import pytest
import pytest_subtests
from kombu.utils.functional import lazy

from celery.utils.functional import (DummyContext, first, firstmethod,
Expand Down Expand Up @@ -206,6 +207,73 @@ def __iter__(self):
# Finally we xfail this test to keep track of it
raise pytest.xfail(reason="#6794")

def test_length_hint_passthrough(self, g):
assert g.__length_hint__() == 10

def test_getitem_repeated(self, g):
halfway_idx = g.__length_hint__() // 2
assert g[halfway_idx] == halfway_idx
# These are now concretised so they should be returned without any work
assert g[halfway_idx] == halfway_idx
for i in range(halfway_idx + 1):
assert g[i] == i
# This should only need to concretise one more element
assert g[halfway_idx + 1] == halfway_idx + 1

def test_done_does_not_lag(self, g):
"""
Don't allow regen to return from `__iter__()` and check `__done`.
"""
# The range we zip with here should ensure that the `regen.__iter__`
# call never gets to return since we never attempt a failing `next()`
len_g = g.__length_hint__()
for i, __ in zip(range(len_g), g):
assert getattr(g, "_regen__done") is (i == len_g - 1)
# Just for sanity, check against a specific `bool` here
assert getattr(g, "_regen__done") is True

def test_lookahead_consume(self, subtests):
"""
Confirm that regen looks ahead by a single item as expected.
"""
def g():
yield from ["foo", "bar"]
raise pytest.fail("This should never be reached")

with subtests.test(msg="bool does not overconsume"):
assert bool(regen(g()))
with subtests.test(msg="getitem 0th does not overconsume"):
assert regen(g())[0] == "foo"
with subtests.test(msg="single iter does not overconsume"):
assert next(iter(regen(g()))) == "foo"

class ExpectedException(BaseException):
pass

def g2():
yield from ["foo", "bar"]
raise ExpectedException()

with subtests.test(msg="getitem 1th does overconsume"):
r = regen(g2())
with pytest.raises(ExpectedException):
r[1]
# Confirm that the item was concretised anyway
assert r[1] == "bar"
with subtests.test(msg="full iter does overconsume"):
r = regen(g2())
with pytest.raises(ExpectedException):
for _ in r:
pass
# Confirm that the items were concretised anyway
assert r == ["foo", "bar"]
with subtests.test(msg="data access does overconsume"):
r = regen(g2())
with pytest.raises(ExpectedException):
r.data
# Confirm that the items were concretised anyway
assert r == ["foo", "bar"]


class test_head_from_fun:

Expand Down

0 comments on commit d667f1f

Please sign in to comment.