Skip to content

Commit

Permalink
Preserve order of group results with Redis result backend (celery#6218)
Browse files Browse the repository at this point in the history
* Preserve order of return values from groups

Fixes celery#3781.

* Update for zadd arguments changed in redis-py 3

* Use more explicit loop variable name

* Handle group_index not set

* Use zrange instead of zrangebyscore

* test: Fix Redis sorted set mocks in backend tests

* test: Make canvas integration tests use `zrange()`

The test suite still uses `lrange()` and `rpush()` to implement its
`redis-echo` task chain integration tests, but these are unrelated to
the handling of group results and remain unchanged.

* test: Add unit tests for `group_index` handling

* fix: Add `group_index` to `Context`, chord uplift

* test: Sanity check `Request.group_index` property

This adds a test to make sure the property exists and also changes the
property to use the private `_request_dict` rather than the public
property.

Co-authored-by: Leo Singer <[email protected]>
  • Loading branch information
maybe-sybr and lpsinger authored Jul 19, 2020
1 parent 9dddf8c commit 455e0a0
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 45 deletions.
6 changes: 4 additions & 2 deletions celery/app/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def TaskConsumer(self, channel, queues=None, accept=None, **kw):
)

def as_task_v2(self, task_id, name, args=None, kwargs=None,
countdown=None, eta=None, group_id=None,
countdown=None, eta=None, group_id=None, group_index=None,
expires=None, retries=0, chord=None,
callbacks=None, errbacks=None, reply_to=None,
time_limit=None, soft_time_limit=None,
Expand Down Expand Up @@ -363,6 +363,7 @@ def as_task_v2(self, task_id, name, args=None, kwargs=None,
'eta': eta,
'expires': expires,
'group': group_id,
'group_index': group_index,
'retries': retries,
'timelimit': [time_limit, soft_time_limit],
'root_id': root_id,
Expand Down Expand Up @@ -397,7 +398,7 @@ def as_task_v2(self, task_id, name, args=None, kwargs=None,
)

def as_task_v1(self, task_id, name, args=None, kwargs=None,
countdown=None, eta=None, group_id=None,
countdown=None, eta=None, group_id=None, group_index=None,
expires=None, retries=0,
chord=None, callbacks=None, errbacks=None, reply_to=None,
time_limit=None, soft_time_limit=None,
Expand Down Expand Up @@ -442,6 +443,7 @@ def as_task_v1(self, task_id, name, args=None, kwargs=None,
'args': args,
'kwargs': kwargs,
'group': group_id,
'group_index': group_index,
'retries': retries,
'eta': eta,
'expires': expires,
Expand Down
5 changes: 3 additions & 2 deletions celery/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ def send_task(self, name, args=None, kwargs=None, countdown=None,
eta=None, task_id=None, producer=None, connection=None,
router=None, result_cls=None, expires=None,
publisher=None, link=None, link_error=None,
add_to_parent=True, group_id=None, retries=0, chord=None,
add_to_parent=True, group_id=None, group_index=None,
retries=0, chord=None,
reply_to=None, time_limit=None, soft_time_limit=None,
root_id=None, parent_id=None, route_name=None,
shadow=None, chain=None, task_type=None, **options):
Expand Down Expand Up @@ -720,7 +721,7 @@ def send_task(self, name, args=None, kwargs=None, countdown=None,
parent.request.delivery_info.get('priority'))

message = amqp.create_task_message(
task_id, name, args, kwargs, countdown, eta, group_id,
task_id, name, args, kwargs, countdown, eta, group_id, group_index,
expires, retries, chord,
maybe_list(link), maybe_list(link_error),
reply_to or self.oid, time_limit, soft_time_limit,
Expand Down
4 changes: 4 additions & 0 deletions celery/app/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class Context(object):
correlation_id = None
taskset = None # compat alias to group
group = None
group_index = None
chord = None
chain = None
utc = None
Expand Down Expand Up @@ -116,6 +117,7 @@ def as_execution_options(self):
'root_id': self.root_id,
'parent_id': self.parent_id,
'group_id': self.group,
'group_index': self.group_index,
'chord': self.chord,
'chain': self.chain,
'link': self.callbacks,
Expand Down Expand Up @@ -891,6 +893,7 @@ def replace(self, sig):
sig.set(
chord=chord,
group_id=self.request.group,
group_index=self.request.group_index,
root_id=self.request.root_id,
)
sig.freeze(self.request.id)
Expand All @@ -917,6 +920,7 @@ def add_to_chord(self, sig, lazy=False):
raise ValueError('Current task is not member of any chord')
sig.set(
group_id=self.request.group,
group_index=self.request.group_index,
chord=self.request.chord,
root_id=self.request.root_id,
)
Expand Down
11 changes: 7 additions & 4 deletions celery/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,18 +413,21 @@ def apply_chord(self, header_result, body, **kwargs):
def on_chord_part_return(self, request, state, result,
propagate=None, **kwargs):
app = self.app
tid, gid = request.id, request.group
tid, gid, group_index = request.id, request.group, request.group_index
if not gid or not tid:
return
if group_index is None:
group_index = '+inf'

client = self.client
jkey = self.get_key_for_group(gid, '.j')
tkey = self.get_key_for_group(gid, '.t')
result = self.encode_result(result, state)
with client.pipeline() as pipe:
pipeline = pipe \
.rpush(jkey, self.encode([1, tid, state, result])) \
.llen(jkey) \
.zadd(jkey,
{self.encode([1, tid, state, result]): group_index}) \
.zcount(jkey, '-inf', '+inf') \
.get(tkey)

if self.expires is not None:
Expand All @@ -443,7 +446,7 @@ def on_chord_part_return(self, request, state, result,
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
resl, = pipe \
.lrange(jkey, 0, total) \
.zrange(jkey, 0, -1) \
.execute()
try:
callback.delay([unpack(tup, decode) for tup in resl])
Expand Down
27 changes: 18 additions & 9 deletions celery/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def clone(self, args=None, kwargs=None, **opts):
partial = clone

def freeze(self, _id=None, group_id=None, chord=None,
root_id=None, parent_id=None):
root_id=None, parent_id=None, group_index=None):
"""Finalize the signature by adding a concrete task id.
The task won't be called and you shouldn't call the signature
Expand All @@ -303,6 +303,8 @@ def freeze(self, _id=None, group_id=None, chord=None,
opts['group_id'] = group_id
if chord:
opts['chord'] = chord
if group_index is not None:
opts['group_index'] = group_index
# pylint: disable=too-many-function-args
# Borks on this, as it's a property.
return self.AsyncResult(tid)
Expand Down Expand Up @@ -674,19 +676,21 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None,
return results[0]

def freeze(self, _id=None, group_id=None, chord=None,
root_id=None, parent_id=None):
root_id=None, parent_id=None, group_index=None):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
_, results = self._frozen = self.prepare_steps(
self.args, self.kwargs, self.tasks, root_id, parent_id, None,
self.app, _id, group_id, chord, clone=False,
group_index=group_index,
)
return results[0]

def prepare_steps(self, args, kwargs, tasks,
root_id=None, parent_id=None, link_error=None, app=None,
last_task_id=None, group_id=None, chord_body=None,
clone=True, from_dict=Signature.from_dict):
clone=True, from_dict=Signature.from_dict,
group_index=None):
app = app or self.app
# use chain message field for protocol 2 and later.
# this avoids pickle blowing the stack on the recursion
Expand Down Expand Up @@ -763,6 +767,7 @@ def prepare_steps(self, args, kwargs, tasks,
res = task.freeze(
last_task_id,
root_id=root_id, group_id=group_id, chord=chord_body,
group_index=group_index,
)
else:
res = task.freeze(root_id=root_id)
Expand Down Expand Up @@ -1189,7 +1194,7 @@ def _freeze_gid(self, options):
return options, group_id, options.get('root_id')

def freeze(self, _id=None, group_id=None, chord=None,
root_id=None, parent_id=None):
root_id=None, parent_id=None, group_index=None):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
opts = self.options
Expand All @@ -1201,6 +1206,8 @@ def freeze(self, _id=None, group_id=None, chord=None,
opts['group_id'] = group_id
if chord:
opts['chord'] = chord
if group_index is not None:
opts['group_index'] = group_index
root_id = opts.setdefault('root_id', root_id)
parent_id = opts.setdefault('parent_id', parent_id)
new_tasks = []
Expand All @@ -1221,6 +1228,7 @@ def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
stack = deque(self.tasks)
group_index = 0
while stack:
task = maybe_signature(stack.popleft(), app=self._app).clone()
if isinstance(task, group):
Expand All @@ -1229,7 +1237,9 @@ def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id):
new_tasks.append(task)
yield task.freeze(group_id=group_id,
chord=chord, root_id=root_id,
parent_id=parent_id)
parent_id=parent_id,
group_index=group_index)
group_index += 1

def __repr__(self):
if self.tasks:
Expand Down Expand Up @@ -1308,17 +1318,16 @@ def __call__(self, body=None, **options):
return self.apply_async((), {'body': body} if body else {}, **options)

def freeze(self, _id=None, group_id=None, chord=None,
root_id=None, parent_id=None):
root_id=None, parent_id=None, group_index=None):
# pylint: disable=redefined-outer-name
# XXX chord is also a class in outer scope.
if not isinstance(self.tasks, group):
self.tasks = group(self.tasks, app=self.app)
header_result = self.tasks.freeze(
parent_id=parent_id, root_id=root_id, chord=self.body)

body_result = self.body.freeze(
_id, root_id=root_id, chord=chord, group_id=group_id)

_id, root_id=root_id, chord=chord, group_id=group_id,
group_index=group_index)
# we need to link the body result back to the group result,
# but the body may actually be a chain,
# so find the first result without a parent
Expand Down
3 changes: 2 additions & 1 deletion celery/utils/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def clone(self, args=None, kwargs=None):
pass

@abstractmethod
def freeze(self, id=None, group_id=None, chord=None, root_id=None):
def freeze(self, id=None, group_id=None, chord=None, root_id=None,
group_index=None):
pass

@abstractmethod
Expand Down
5 changes: 5 additions & 0 deletions celery/worker/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,11 @@ def _context(self):
request.update(**embed or {})
return Context(request)

@cached_property
def group_index(self):
# used by backend.on_chord_part_return to order return values in group
return self._request_dict.get('group_index')


def create_request_cls(base, task, pool, hostname, eventer,
ref=ref, revoked_tasks=revoked_tasks,
Expand Down
4 changes: 2 additions & 2 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def test_add_to_chord(self, manager):

c = group([add_to_all_to_chord.s([1, 2, 3], 4)]) | identity.s()
res = c()
assert res.get() == [0, 5, 6, 7]
assert sorted(res.get()) == [0, 5, 6, 7]

@pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
def test_add_chord_to_chord(self, manager):
Expand Down Expand Up @@ -857,7 +857,7 @@ def test_chord_on_error(self, manager):
j_key = backend.get_key_for_group(original_group_id, '.j')
redis_connection = get_redis_connection()
chord_results = [backend.decode(t) for t in
redis_connection.lrange(j_key, 0, 3)]
redis_connection.zrange(j_key, 0, 3)]

# Validate group result
assert [cr[3] for cr in chord_results if cr[2] == states.SUCCESS] == \
Expand Down
Loading

0 comments on commit 455e0a0

Please sign in to comment.