Skip to content

Commit

Permalink
fix: Count chord "final elements" correctly
Browse files Browse the repository at this point in the history
This change amends the implementation of `chord.__length_hint__()` to
ensure that all child task types are correctly counted. Specifically:

 * all sub-tasks of a group are counted recursively
 * the final task of a chain is counted recursively
 * the body of a chord is counted recursively
 * all other simple signatures count as a single "final element"

There is also a deserialisation step if a `dict` is seen while counting
the final elements in a chord; however, this should become less important
with the merge of celery#6342 which ensures that tasks are recursively
deserialized by `.from_dict()`.
  • Loading branch information
maybe-sybr authored and auvipy committed Oct 14, 2020
1 parent f1dbf3f commit a7af4b2
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 20 deletions.
35 changes: 22 additions & 13 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,21 +1383,30 @@ def apply(self, args=None, kwargs=None,
args=(tasks.apply(args, kwargs).get(propagate=propagate),),
)

# Queue-based walk over `tasks` that yields one item per element it does not
# expand: the element itself when `value` is None, otherwise `value` (so
# `sum(self._traverse_tasks(tasks, 1))` counts those elements).
# NOTE(review): this looks like the pre-change side of the diff, replaced by
# `__descend` - confirm before relying on it.
def _traverse_tasks(self, tasks, value=None):
stack = deque(tasks)
while stack:
task = stack.popleft()
if isinstance(task, group):
# A group is never yielded itself; its members are queued instead
stack.extend(task.tasks)
elif isinstance(task, _chain) and isinstance(task.tasks[-1], group):
# Only a chain whose last element is a group is expanded; any other
# chain falls through to the `else` branch and is yielded as one item
stack.extend(task.tasks[-1].tasks)
else:
yield task if value is None else value
@classmethod
def __descend(cls, sig_obj):
    """Count the "final elements" that ``sig_obj`` contributes to a chord.

    The count is computed recursively:

    * a ``dict`` that is not yet a ``Signature`` is first rebuilt with
      ``Signature.from_dict()``;
    * a group contributes the sum of the counts of its sub-tasks;
    * a chain contributes the count of its last task, or 0 when the chain
      has no tasks at all;
    * a chord contributes the count of its body;
    * any other ``Signature`` counts as 1;
    * anything else is assumed to be an iterable of simple signatures and
      contributes ``len(sig_obj)``.

    :param sig_obj: a signature, a serialized signature ``dict`` or an
        iterable of simple signatures.
    :returns: the number of final elements as an ``int``.
    """
    # Sometimes serialized signatures might make their way here
    if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict):
        sig_obj = Signature.from_dict(sig_obj)
    if isinstance(sig_obj, group):
        # Each task in a group counts toward this chord
        subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks)
        return sum(cls.__descend(task) for task in subtasks)
    elif isinstance(sig_obj, _chain):
        # An empty chain has no last element: indexing `tasks[-1]` would
        # raise IndexError, so it simply contributes nothing
        if not sig_obj.tasks:
            return 0
        # The last element in a chain counts toward this chord
        return cls.__descend(sig_obj.tasks[-1])
    elif isinstance(sig_obj, chord):
        # The child chord's body counts toward this chord
        return cls.__descend(sig_obj.body)
    elif isinstance(sig_obj, Signature):
        # Each simple signature counts as 1 completion for this chord
        return 1
    # Any other types are assumed to be iterables of simple signatures
    return len(sig_obj)

def __length_hint__(self):
    """Return the number of "final elements" in this chord's header.

    Every header task is counted recursively with ``__descend``, so groups,
    chains and nested chords are all taken into account.
    """
    # `self.tasks` may be a group-like object exposing `.tasks` or already
    # a plain iterable of tasks; `getattr` handles both without type checks
    tasks = getattr(self.tasks, "tasks", self.tasks)
    return sum(self.__descend(task) for task in tasks)

def run(self, header, body, partial_args, app=None, interval=None,
countdown=1, max_retries=None, eager=False,
Expand Down
22 changes: 22 additions & 0 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,25 @@ def test_priority_chain(self, manager):
c = return_priority.signature(priority=3) | return_priority.signature(
priority=5)
assert c().get(timeout=TIMEOUT) == "Priority: 5"

def test_nested_chord_group_chain_group_tail(self, manager):
"""
Sanity check that a deeply nested group is completed as expected.
Groups at the end of chains nested in chords have had issues and this
simple test sanity checks that such a task structure can be completed.
"""
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chord(group(chain(
identity.s(42), # -> 42
group(
identity.s(), # -> 42
identity.s(), # -> 42
), # [42, 42]
)), identity.s()) # [[42, 42]]
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [[42, 42]]
181 changes: 174 additions & 7 deletions t/unit/tasks/test_canvas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from unittest.mock import MagicMock, Mock, patch, sentinel
from unittest.mock import MagicMock, Mock, call, patch, sentinel

import pytest

Expand Down Expand Up @@ -808,12 +808,179 @@ def test_app_fallback_to_current(self):
x = chord([t1], body=t1)
assert x.app is current_app

def test_chord_size_with_groups(self):
x = chord([
self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]),
self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]),
], body=self.add.si(2, 2))
assert x.__length_hint__() == 4
def test_chord_size_simple(self):
sig = chord(self.add.s())
assert sig.__length_hint__() == 1

def test_chord_size_with_body(self):
sig = chord(self.add.s(), self.add.s())
assert sig.__length_hint__() == 1

def test_chord_size_explicit_group_single(self):
sig = chord(group(self.add.s()))
assert sig.__length_hint__() == 1

def test_chord_size_explicit_group_many(self):
sig = chord(group([self.add.s()] * 42))
assert sig.__length_hint__() == 42

def test_chord_size_implicit_group_single(self):
sig = chord([self.add.s()])
assert sig.__length_hint__() == 1

def test_chord_size_implicit_group_many(self):
sig = chord([self.add.s()] * 42)
assert sig.__length_hint__() == 42

def test_chord_size_chain_single(self):
sig = chord(chain(self.add.s()))
assert sig.__length_hint__() == 1

def test_chord_size_chain_many(self):
# Chains get flattened into the encapsulating chord so even though the
# chain would only count for 1, its tasks are pulled into the chord's
# header and are counted as a bunch of simple signature objects
sig = chord(chain([self.add.s()] * 42))
assert sig.__length_hint__() == 42

def test_chord_size_nested_chain_chain_single(self):
sig = chord(chain(chain(self.add.s())))
assert sig.__length_hint__() == 1

def test_chord_size_nested_chain_chain_many(self):
# The outer chain will be pulled up into the chord but the lower one
# remains and will only count as a single final element
sig = chord(chain(chain([self.add.s()] * 42)))
assert sig.__length_hint__() == 1

def test_chord_size_implicit_chain_single(self):
sig = chord([self.add.s()])
assert sig.__length_hint__() == 1

def test_chord_size_implicit_chain_many(self):
# This isn't a chain object so the `tasks` attribute can't be lifted
# into the chord - this isn't actually valid and would blow up if we
# tried to run it but it sanity checks our recursion
sig = chord([[self.add.s()] * 42])
assert sig.__length_hint__() == 1

def test_chord_size_nested_implicit_chain_chain_single(self):
sig = chord([chain(self.add.s())])
assert sig.__length_hint__() == 1

def test_chord_size_nested_implicit_chain_chain_many(self):
sig = chord([chain([self.add.s()] * 42)])
assert sig.__length_hint__() == 1

def test_chord_size_nested_chord_body_simple(self):
sig = chord(chord(tuple(), self.add.s()))
assert sig.__length_hint__() == 1

def test_chord_size_nested_chord_body_implicit_group_single(self):
sig = chord(chord(tuple(), [self.add.s()]))
assert sig.__length_hint__() == 1

def test_chord_size_nested_chord_body_implicit_group_many(self):
sig = chord(chord(tuple(), [self.add.s()] * 42))
assert sig.__length_hint__() == 42

# Nested groups in a chain only affect the chord size if they are the last
# element in the chain - in that case each group element is counted
def test_chord_size_nested_group_chain_group_head_single(self):
x = chord(
group(
[group(self.add.s()) | self.add.s()] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_head_many(self):
x = chord(
group(
[group([self.add.s()] * 4) | self.add.s()] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 2

def test_chord_size_nested_group_chain_group_mid_single(self):
x = chord(
group(
[self.add.s() | group(self.add.s()) | self.add.s()] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_mid_many(self):
x = chord(
group(
[self.add.s() | group([self.add.s()] * 4) | self.add.s()] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 2

def test_chord_size_nested_group_chain_group_tail_single(self):
x = chord(
group(
[self.add.s() | group(self.add.s())] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_tail_many(self):
x = chord(
group(
[self.add.s() | group([self.add.s()] * 4)] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 4 * 2

def test_chord_size_nested_implicit_group_chain_group_tail_single(self):
x = chord(
[self.add.s() | group(self.add.s())] * 42,
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_implicit_group_chain_group_tail_many(self):
x = chord(
[self.add.s() | group([self.add.s()] * 4)] * 2,
body=self.add.s()
)
assert x.__length_hint__() == 4 * 2

def test_chord_size_deserialized_element_single(self):
child_sig = self.add.s()
deserialized_child_sig = json.loads(json.dumps(child_sig))
# We have to break in to be sure that a child remains as a `dict` so we
# can confirm that the length hint will instantiate a `Signature`
# object and then descend as expected
chord_sig = chord(tuple())
chord_sig.tasks = [deserialized_child_sig]
with patch(
"celery.canvas.Signature.from_dict", return_value=child_sig
) as mock_from_dict:
assert chord_sig. __length_hint__() == 1
mock_from_dict.assert_called_once_with(deserialized_child_sig)

def test_chord_size_deserialized_element_many(self):
child_sig = self.add.s()
deserialized_child_sig = json.loads(json.dumps(child_sig))
# We have to break in to be sure that a child remains as a `dict` so we
# can confirm that the length hint will instantiate a `Signature`
# object and then descend as expected
chord_sig = chord(tuple())
chord_sig.tasks = [deserialized_child_sig] * 42
with patch(
"celery.canvas.Signature.from_dict", return_value=child_sig
) as mock_from_dict:
assert chord_sig. __length_hint__() == 42
mock_from_dict.assert_has_calls([call(deserialized_child_sig)] * 42)

def test_set_immutable(self):
x = chord([Mock(name='t1'), Mock(name='t2')], app=self.app)
Expand Down

0 comments on commit a7af4b2

Please sign in to comment.