diff --git a/celery/utils/functional.py b/celery/utils/functional.py index a82991b2437..2878bc15ea0 100644 --- a/celery/utils/functional.py +++ b/celery/utils/functional.py @@ -195,7 +195,6 @@ def __init__(self, it): # UserList creates a new list and sets .data, so we don't # want to call init here. self.__it = it - self.__index = 0 self.__consumed = [] self.__done = False @@ -205,28 +204,45 @@ def __reduce__(self): def __length_hint__(self): return self.__it.__length_hint__() + def __lookahead_consume(self, limit=None): + if not self.__done and (limit is None or limit > 0): + it = iter(self.__it) + try: + now = next(it) + except StopIteration: + return + self.__consumed.append(now) + # Maintain a single look-ahead to ensure we set `__done` when the + # underlying iterator gets exhausted + while not self.__done: + try: + next_ = next(it) + self.__consumed.append(next_) + except StopIteration: + self.__done = True + break + finally: + yield now + now = next_ + # We can break out when `limit` is exhausted + if limit is not None: + limit -= 1 + if limit <= 0: + break + def __iter__(self): yield from self.__consumed - if not self.__done: - for x in self.__it: - self.__consumed.append(x) - yield x - self.__done = True + yield from self.__lookahead_consume() def __getitem__(self, index): if index < 0: return self.data[index] - try: - return self.__consumed[index] - except IndexError: - it = iter(self) - try: - for _ in range(self.__index, index + 1): - next(it) - except StopIteration: - raise IndexError(index) - else: - return self.__consumed[index] + # Consume elements up to the desired index prior to attempting to + # access it from within `__consumed` + consume_count = index - len(self.__consumed) + 1 + for _ in self.__lookahead_consume(limit=consume_count): + pass + return self.__consumed[index] def __bool__(self): if len(self.__consumed): diff --git a/t/unit/tasks/test_canvas.py b/t/unit/tasks/test_canvas.py index 1b6064f0db5..487e3b1d6fe 100644 --- a/t/unit/tasks/test_canvas.py +++ b/t/unit/tasks/test_canvas.py @@ -978,11 +978,15 @@ def build_generator(): yield self.add.s(1, 1) self.second_item_returned = True yield self.add.s(2, 2) + raise pytest.fail("This should never be reached") self.second_item_returned = False c = chord(build_generator(), self.add.s(3)) c.app - assert not self.second_item_returned + # The second task gets returned due to lookahead in `regen()` + assert self.second_item_returned + # Access it again to make sure the generator is not further evaluated + c.app def test_reverse(self): x = chord([self.add.s(2, 2), self.add.s(4, 4)], body=self.mul.s(4)) diff --git a/t/unit/utils/test_functional.py b/t/unit/utils/test_functional.py index d7e8b686f5e..fe12f426462 100644 --- a/t/unit/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -1,6 +1,7 @@ import collections import pytest +import pytest_subtests from kombu.utils.functional import lazy from celery.utils.functional import (DummyContext, first, firstmethod, @@ -206,6 +207,73 @@ def __iter__(self): # Finally we xfail this test to keep track of it raise pytest.xfail(reason="#6794") + def test_length_hint_passthrough(self, g): + assert g.__length_hint__() == 10 + + def test_getitem_repeated(self, g): + halfway_idx = g.__length_hint__() // 2 + assert g[halfway_idx] == halfway_idx + # These are now concretised so they should be returned without any work + assert g[halfway_idx] == halfway_idx + for i in range(halfway_idx + 1): + assert g[i] == i + # This should only need to concretise one more element + assert g[halfway_idx + 1] == halfway_idx + 1 + + def test_done_does_not_lag(self, g): + """ + Don't allow regen to return from `__iter__()` and check `__done`. + """ + # The range we zip with here should ensure that the `regen.__iter__` + # call never gets to return since we never attempt a failing `next()` + len_g = g.__length_hint__() + for i, __ in zip(range(len_g), g): + assert getattr(g, "_regen__done") is (i == len_g - 1) + # Just for sanity, check against a specific `bool` here + assert getattr(g, "_regen__done") is True + + def test_lookahead_consume(self, subtests): + """ + Confirm that regen looks ahead by a single item as expected. + """ + def g(): + yield from ["foo", "bar"] + raise pytest.fail("This should never be reached") + + with subtests.test(msg="bool does not overconsume"): + assert bool(regen(g())) + with subtests.test(msg="getitem 0th does not overconsume"): + assert regen(g())[0] == "foo" + with subtests.test(msg="single iter does not overconsume"): + assert next(iter(regen(g()))) == "foo" + + class ExpectedException(BaseException): + pass + + def g2(): + yield from ["foo", "bar"] + raise ExpectedException() + + with subtests.test(msg="getitem 1th does overconsume"): + r = regen(g2()) + with pytest.raises(ExpectedException): + r[1] + # Confirm that the item was concretised anyway + assert r[1] == "bar" + with subtests.test(msg="full iter does overconsume"): + r = regen(g2()) + with pytest.raises(ExpectedException): + for _ in r: + pass + # Confirm that the items were concretised anyway + assert r == ["foo", "bar"] + with subtests.test(msg="data access does overconsume"): + r = regen(g2()) + with pytest.raises(ExpectedException): + r.data + # Confirm that the items were concretised anyway + assert r == ["foo", "bar"] + class test_head_from_fun: