From 455e0a0e86679eaaba9f0da533066627b1d79296 Mon Sep 17 00:00:00 2001 From: maybe-sybr <58414429+maybe-sybr@users.noreply.github.com> Date: Sun, 19 Jul 2020 14:51:18 +1000 Subject: [PATCH] Preserve order of group results with Redis result backend (#6218) * Preserve order of return values from groups Fixes #3781. * Update for zadd arguments changed in redis-py 3 * Use more explicit loop variable name * Handle group_index not set * Use zrange instead of zrangebyscore * test: Fix Redis sorted set mocks in backend tests * test: Make canvas integration tests use `zrange()` The test suite still uses `lrange()` and `rpush()` to implement its `redis-echo` task chain integration tests, but these are unrelated to the handling of group results and remain unchanged. * test: Add unit tests for `group_index` handling * fix: Add `group_index` to `Context`, chord uplift * test: Sanity check `Request.group_index` property This adds a test to make sure the property exists and also changes the property to use the private `_request_dict` rather than the public property. 
Co-authored-by: Leo Singer --- celery/app/amqp.py | 6 ++-- celery/app/base.py | 5 +-- celery/app/task.py | 4 +++ celery/backends/redis.py | 11 +++--- celery/canvas.py | 27 ++++++++++----- celery/utils/abstract.py | 3 +- celery/worker/request.py | 5 +++ t/integration/test_canvas.py | 4 +-- t/unit/backends/test_redis.py | 65 +++++++++++++++++++++-------------- t/unit/tasks/test_canvas.py | 17 +++++++++ t/unit/worker/test_request.py | 5 +++ 11 files changed, 107 insertions(+), 45 deletions(-) diff --git a/celery/app/amqp.py b/celery/app/amqp.py index 2bf8c1d8de7..537ebcf8166 100644 --- a/celery/app/amqp.py +++ b/celery/app/amqp.py @@ -305,7 +305,7 @@ def TaskConsumer(self, channel, queues=None, accept=None, **kw): ) def as_task_v2(self, task_id, name, args=None, kwargs=None, - countdown=None, eta=None, group_id=None, + countdown=None, eta=None, group_id=None, group_index=None, expires=None, retries=0, chord=None, callbacks=None, errbacks=None, reply_to=None, time_limit=None, soft_time_limit=None, @@ -363,6 +363,7 @@ def as_task_v2(self, task_id, name, args=None, kwargs=None, 'eta': eta, 'expires': expires, 'group': group_id, + 'group_index': group_index, 'retries': retries, 'timelimit': [time_limit, soft_time_limit], 'root_id': root_id, @@ -397,7 +398,7 @@ def as_task_v2(self, task_id, name, args=None, kwargs=None, ) def as_task_v1(self, task_id, name, args=None, kwargs=None, - countdown=None, eta=None, group_id=None, + countdown=None, eta=None, group_id=None, group_index=None, expires=None, retries=0, chord=None, callbacks=None, errbacks=None, reply_to=None, time_limit=None, soft_time_limit=None, @@ -442,6 +443,7 @@ def as_task_v1(self, task_id, name, args=None, kwargs=None, 'args': args, 'kwargs': kwargs, 'group': group_id, + 'group_index': group_index, 'retries': retries, 'eta': eta, 'expires': expires, diff --git a/celery/app/base.py b/celery/app/base.py index b04dd9e2435..3ced1af7a34 100644 --- a/celery/app/base.py +++ b/celery/app/base.py @@ -680,7 +680,8 @@ def 
send_task(self, name, args=None, kwargs=None, countdown=None, eta=None, task_id=None, producer=None, connection=None, router=None, result_cls=None, expires=None, publisher=None, link=None, link_error=None, - add_to_parent=True, group_id=None, retries=0, chord=None, + add_to_parent=True, group_id=None, group_index=None, + retries=0, chord=None, reply_to=None, time_limit=None, soft_time_limit=None, root_id=None, parent_id=None, route_name=None, shadow=None, chain=None, task_type=None, **options): @@ -720,7 +721,7 @@ def send_task(self, name, args=None, kwargs=None, countdown=None, parent.request.delivery_info.get('priority')) message = amqp.create_task_message( - task_id, name, args, kwargs, countdown, eta, group_id, + task_id, name, args, kwargs, countdown, eta, group_id, group_index, expires, retries, chord, maybe_list(link), maybe_list(link_error), reply_to or self.oid, time_limit, soft_time_limit, diff --git a/celery/app/task.py b/celery/app/task.py index ffb6d83e110..073b41c3091 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -83,6 +83,7 @@ class Context(object): correlation_id = None taskset = None # compat alias to group group = None + group_index = None chord = None chain = None utc = None @@ -116,6 +117,7 @@ def as_execution_options(self): 'root_id': self.root_id, 'parent_id': self.parent_id, 'group_id': self.group, + 'group_index': self.group_index, 'chord': self.chord, 'chain': self.chain, 'link': self.callbacks, @@ -891,6 +893,7 @@ def replace(self, sig): sig.set( chord=chord, group_id=self.request.group, + group_index=self.request.group_index, root_id=self.request.root_id, ) sig.freeze(self.request.id) @@ -917,6 +920,7 @@ def add_to_chord(self, sig, lazy=False): raise ValueError('Current task is not member of any chord') sig.set( group_id=self.request.group, + group_index=self.request.group_index, chord=self.request.chord, root_id=self.request.root_id, ) diff --git a/celery/backends/redis.py b/celery/backends/redis.py index 
aec18284780..9c635ccde0c 100644 --- a/celery/backends/redis.py +++ b/celery/backends/redis.py @@ -413,9 +413,11 @@ def apply_chord(self, header_result, body, **kwargs): def on_chord_part_return(self, request, state, result, propagate=None, **kwargs): app = self.app - tid, gid = request.id, request.group + tid, gid, group_index = request.id, request.group, request.group_index if not gid or not tid: return + if group_index is None: + group_index = '+inf' client = self.client jkey = self.get_key_for_group(gid, '.j') @@ -423,8 +425,9 @@ def on_chord_part_return(self, request, state, result, result = self.encode_result(result, state) with client.pipeline() as pipe: pipeline = pipe \ - .rpush(jkey, self.encode([1, tid, state, result])) \ - .llen(jkey) \ + .zadd(jkey, + {self.encode([1, tid, state, result]): group_index}) \ + .zcount(jkey, '-inf', '+inf') \ .get(tkey) if self.expires is not None: @@ -443,7 +446,7 @@ def on_chord_part_return(self, request, state, result, decode, unpack = self.decode, self._unpack_chord_result with client.pipeline() as pipe: resl, = pipe \ - .lrange(jkey, 0, total) \ + .zrange(jkey, 0, -1) \ .execute() try: callback.delay([unpack(tup, decode) for tup in resl]) diff --git a/celery/canvas.py b/celery/canvas.py index 6a060e08806..cb4ac1ab76d 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -276,7 +276,7 @@ def clone(self, args=None, kwargs=None, **opts): partial = clone def freeze(self, _id=None, group_id=None, chord=None, - root_id=None, parent_id=None): + root_id=None, parent_id=None, group_index=None): """Finalize the signature by adding a concrete task id. The task won't be called and you shouldn't call the signature @@ -303,6 +303,8 @@ def freeze(self, _id=None, group_id=None, chord=None, opts['group_id'] = group_id if chord: opts['chord'] = chord + if group_index is not None: + opts['group_index'] = group_index # pylint: disable=too-many-function-args # Borks on this, as it's a property. 
return self.AsyncResult(tid) @@ -674,19 +676,21 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, return results[0] def freeze(self, _id=None, group_id=None, chord=None, - root_id=None, parent_id=None): + root_id=None, parent_id=None, group_index=None): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. _, results = self._frozen = self.prepare_steps( self.args, self.kwargs, self.tasks, root_id, parent_id, None, self.app, _id, group_id, chord, clone=False, + group_index=group_index, ) return results[0] def prepare_steps(self, args, kwargs, tasks, root_id=None, parent_id=None, link_error=None, app=None, last_task_id=None, group_id=None, chord_body=None, - clone=True, from_dict=Signature.from_dict): + clone=True, from_dict=Signature.from_dict, + group_index=None): app = app or self.app # use chain message field for protocol 2 and later. # this avoids pickle blowing the stack on the recursion @@ -763,6 +767,7 @@ def prepare_steps(self, args, kwargs, tasks, res = task.freeze( last_task_id, root_id=root_id, group_id=group_id, chord=chord_body, + group_index=group_index, ) else: res = task.freeze(root_id=root_id) @@ -1189,7 +1194,7 @@ def _freeze_gid(self, options): return options, group_id, options.get('root_id') def freeze(self, _id=None, group_id=None, chord=None, - root_id=None, parent_id=None): + root_id=None, parent_id=None, group_index=None): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. 
opts = self.options @@ -1201,6 +1206,8 @@ def freeze(self, _id=None, group_id=None, chord=None, opts['group_id'] = group_id if chord: opts['chord'] = chord + if group_index is not None: + opts['group_index'] = group_index root_id = opts.setdefault('root_id', root_id) parent_id = opts.setdefault('parent_id', parent_id) new_tasks = [] @@ -1221,6 +1228,7 @@ def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. stack = deque(self.tasks) + group_index = 0 while stack: task = maybe_signature(stack.popleft(), app=self._app).clone() if isinstance(task, group): @@ -1229,7 +1237,9 @@ def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id): new_tasks.append(task) yield task.freeze(group_id=group_id, chord=chord, root_id=root_id, - parent_id=parent_id) + parent_id=parent_id, + group_index=group_index) + group_index += 1 def __repr__(self): if self.tasks: @@ -1308,17 +1318,16 @@ def __call__(self, body=None, **options): return self.apply_async((), {'body': body} if body else {}, **options) def freeze(self, _id=None, group_id=None, chord=None, - root_id=None, parent_id=None): + root_id=None, parent_id=None, group_index=None): # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. 
if not isinstance(self.tasks, group): self.tasks = group(self.tasks, app=self.app) header_result = self.tasks.freeze( parent_id=parent_id, root_id=root_id, chord=self.body) - body_result = self.body.freeze( - _id, root_id=root_id, chord=chord, group_id=group_id) - + _id, root_id=root_id, chord=chord, group_id=group_id, + group_index=group_index) # we need to link the body result back to the group result, # but the body may actually be a chain, # so find the first result without a parent diff --git a/celery/utils/abstract.py b/celery/utils/abstract.py index 3dfb3d5e067..8465a2a5efd 100644 --- a/celery/utils/abstract.py +++ b/celery/utils/abstract.py @@ -118,7 +118,8 @@ def clone(self, args=None, kwargs=None): pass @abstractmethod - def freeze(self, id=None, group_id=None, chord=None, root_id=None): + def freeze(self, id=None, group_id=None, chord=None, root_id=None, + group_index=None): pass @abstractmethod diff --git a/celery/worker/request.py b/celery/worker/request.py index 73f7e227b7b..8f1b07cc548 100644 --- a/celery/worker/request.py +++ b/celery/worker/request.py @@ -626,6 +626,11 @@ def _context(self): request.update(**embed or {}) return Context(request) + @cached_property + def group_index(self): + # used by backend.on_chord_part_return to order return values in group + return self._request_dict.get('group_index') + def create_request_cls(base, task, pool, hostname, eventer, ref=ref, revoked_tasks=revoked_tasks, diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index da6c99294a7..2fe8ffbb384 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -618,7 +618,7 @@ def test_add_to_chord(self, manager): c = group([add_to_all_to_chord.s([1, 2, 3], 4)]) | identity.s() res = c() - assert res.get() == [0, 5, 6, 7] + assert sorted(res.get()) == [0, 5, 6, 7] @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception) def test_add_chord_to_chord(self, manager): @@ -857,7 +857,7 @@ def 
test_chord_on_error(self, manager): j_key = backend.get_key_for_group(original_group_id, '.j') redis_connection = get_redis_connection() chord_results = [backend.decode(t) for t in - redis_connection.lrange(j_key, 0, 3)] + redis_connection.zrange(j_key, 0, 3)] # Validate group result assert [cr[3] for cr in chord_results if cr[2] == states.SUCCESS] == \ diff --git a/t/unit/backends/test_redis.py b/t/unit/backends/test_redis.py index bcb1800344b..8f088d445b5 100644 --- a/t/unit/backends/test_redis.py +++ b/t/unit/backends/test_redis.py @@ -114,21 +114,33 @@ def delete(self, key): def pipeline(self): return self.Pipeline(self) - def _get_list(self, key): - try: - return self.keyspace[key] - except KeyError: - l = self.keyspace[key] = [] - return l + def _get_sorted_set(self, key): + return self.keyspace.setdefault(key, []) + + def zadd(self, key, mapping): + # Store elements as 2-tuples with the score first so we can sort it + # once the new items have been inserted + fake_sorted_set = self._get_sorted_set(key) + fake_sorted_set.extend( + (score, value) for value, score in mapping.items() + ) + fake_sorted_set.sort() - def rpush(self, key, value): - self._get_list(key).append(value) + def zrange(self, key, start, stop): + # `stop` is inclusive in Redis so we use `stop + 1` unless that would + # cause us to move from negative (right-most) indices to positive + stop = stop + 1 if stop != -1 else None + return [e[1] for e in self._get_sorted_set(key)[start:stop]] - def lrange(self, key, start, stop): - return self._get_list(key)[start:stop] + def zrangebyscore(self, key, min_, max_): + return [ + e[1] for e in self._get_sorted_set(key) + if (min_ == "-inf" or e[0] >= min_) and + (max_ == "+inf" or e[0] <= max_) + ] - def llen(self, key): - return len(self.keyspace.get(key) or []) + def zcount(self, key, min_, max_): + return len(self.zrangebyscore(key, min_, max_)) class Sentinel(mock.MockCallbacks): @@ -540,7 +552,7 @@ def test_unpack_chord_result(self): def 
test_on_chord_part_return_no_gid_or_tid(self): request = Mock(name='request') - request.id = request.group = None + request.id = request.group = request.group_index = None assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None def test_ConnectionPool(self): @@ -580,7 +592,7 @@ def test_set_no_expire(self): self.b.expires = None self.b._set_with_state('foo', 'bar', states.SUCCESS) - def create_task(self): + def create_task(self, i): tid = uuid() task = Mock(name='task-{0}'.format(tid)) task.name = 'foobarbaz' @@ -589,17 +601,19 @@ def create_task(self): task.request.id = tid task.request.chord['chord_size'] = 10 task.request.group = 'group_id' + task.request.group_index = i return task @patch('celery.result.GroupResult.restore') def test_on_chord_part_return(self, restore): - tasks = [self.create_task() for i in range(10)] + tasks = [self.create_task(i) for i in range(10)] + random.shuffle(tasks) for i in range(10): self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i) - assert self.b.client.rpush.call_count - self.b.client.rpush.reset_mock() - assert self.b.client.lrange.call_count + assert self.b.client.zadd.call_count + self.b.client.zadd.reset_mock() + assert self.b.client.zrangebyscore.call_count jkey = self.b.get_key_for_group('group_id', '.j') tkey = self.b.get_key_for_group('group_id', '.t') self.b.client.delete.assert_has_calls([call(jkey), call(tkey)]) @@ -611,13 +625,13 @@ def test_on_chord_part_return(self, restore): def test_on_chord_part_return_no_expiry(self, restore): old_expires = self.b.expires self.b.expires = None - tasks = [self.create_task() for i in range(10)] + tasks = [self.create_task(i) for i in range(10)] for i in range(10): self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i) - assert self.b.client.rpush.call_count - self.b.client.rpush.reset_mock() - assert self.b.client.lrange.call_count + assert self.b.client.zadd.call_count + self.b.client.zadd.reset_mock() + assert 
self.b.client.zrangebyscore.call_count jkey = self.b.get_key_for_group('group_id', '.j') tkey = self.b.get_key_for_group('group_id', '.t') self.b.client.delete.assert_has_calls([call(jkey), call(tkey)]) @@ -645,7 +659,7 @@ def test_on_chord_part_return__ChordError(self): with self.chord_context(1) as (_, request, callback): self.b.client.pipeline = ContextMock() raise_on_second_call(self.b.client.pipeline, ChordError()) - self.b.client.pipeline.return_value.rpush().llen().get().expire( + self.b.client.pipeline.return_value.zadd().zcount().get().expire( ).expire().execute.return_value = (1, 1, 0, 4, 5) task = self.app._tasks['add'] = Mock(name='add_task') self.b.on_chord_part_return(request, states.SUCCESS, 10) @@ -657,7 +671,7 @@ def test_on_chord_part_return__other_error(self): with self.chord_context(1) as (_, request, callback): self.b.client.pipeline = ContextMock() raise_on_second_call(self.b.client.pipeline, RuntimeError()) - self.b.client.pipeline.return_value.rpush().llen().get().expire( + self.b.client.pipeline.return_value.zadd().zcount().get().expire( ).expire().execute.return_value = (1, 1, 0, 4, 5) task = self.app._tasks['add'] = Mock(name='add_task') self.b.on_chord_part_return(request, states.SUCCESS, 10) @@ -668,10 +682,11 @@ def test_on_chord_part_return__other_error(self): @contextmanager def chord_context(self, size=1): with patch('celery.backends.redis.maybe_signature') as ms: - tasks = [self.create_task() for i in range(size)] + tasks = [self.create_task(i) for i in range(size)] request = Mock(name='request') request.id = 'id1' request.group = 'gid1' + request.group_index = None callback = ms.return_value = Signature('add') callback.id = 'id1' callback['chord_size'] = size diff --git a/t/unit/tasks/test_canvas.py b/t/unit/tasks/test_canvas.py index b7224bda2e0..967fd284df2 100644 --- a/t/unit/tasks/test_canvas.py +++ b/t/unit/tasks/test_canvas.py @@ -776,6 +776,23 @@ def test_repr(self): x.kwargs['body'] = None assert 'without body' in repr(x) 
+ def test_freeze_tasks_body_is_group(self): + # Confirm that `group_index` is passed from a chord to elements of its + # body when the chord itself is encapsulated in a group + body_elem = self.add.s() + chord_body = group([body_elem]) + chord_obj = chord(self.add.s(), body=chord_body) + top_group = group([chord_obj]) + # We expect the body to be the signature we passed in before we freeze + (embedded_body_elem, ) = chord_obj.body.tasks + assert embedded_body_elem is body_elem + assert embedded_body_elem.options == dict() + # When we freeze the chord, its body will be cloned and options set + top_group.freeze() + (embedded_body_elem, ) = chord_obj.body.tasks + assert embedded_body_elem is not body_elem + assert embedded_body_elem.options["group_index"] == 0 # 0th task + def test_freeze_tasks_is_not_group(self): x = chord([self.add.s(2, 2)], body=self.add.s(), app=self.app) x.freeze() diff --git a/t/unit/worker/test_request.py b/t/unit/worker/test_request.py index b002197ebeb..2f0d0cac2cb 100644 --- a/t/unit/worker/test_request.py +++ b/t/unit/worker/test_request.py @@ -1067,6 +1067,11 @@ def test_group(self): job = self.xRequest(id=uuid(), group=gid) assert job.group == gid + def test_group_index(self): + group_index = 42 + job = self.xRequest(id=uuid(), group_index=group_index) + assert job.group_index == group_index + class test_create_request_class(RequestCase):