Skip to content

Commit

Permalink
Redis Result Consumer: unsubscribe on message success (celery#4666)
Browse files: browse the repository at this point in the history
* Add manager assertion which checks AsyncResult state

* Redis Result Consumer: unsubscribe on message success

- Use on_after_fork in consumer to reset PubSub and connection pool
  internal states.
- Improve Canvas integration test.
  • Loading branch information
georgepsarakis authored and Omer Katz committed Apr 28, 2018
1 parent 0c277bb commit a035680
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 12 deletions.
14 changes: 14 additions & 0 deletions celery/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def __init__(self, *args, **kwargs):
self._decode_result = self.backend.decode_result
self.subscribed_to = set()

def on_after_fork(self):
    """Reset Redis connection state, then run the base class fork hook.

    Resets the backend client's connection pool and closes the PubSub
    object when one has been created. The parent ``on_after_fork`` is
    always invoked afterwards, whether or not a PubSub object existed.
    """
    pool = self.backend.client.connection_pool
    pool.reset()
    pubsub = self._pubsub
    if pubsub is not None:
        pubsub.close()
    super(ResultConsumer, self).on_after_fork()

def _maybe_cancel_ready_task(self, meta):
    """Call ``cancel_for`` for the task when its status is a ready state.

    :param meta: result metadata mapping; ``meta['status']`` is checked
        against ``states.READY_STATES`` and ``meta['task_id']`` is handed
        to ``cancel_for``. Nothing happens for any other status.
    """
    task_is_ready = meta['status'] in states.READY_STATES
    if not task_is_ready:
        return
    self.cancel_for(meta['task_id'])

def on_state_change(self, meta, message):
    """Run the base class handler, then cancel the task if it is ready.

    :param meta: result metadata mapping, forwarded unchanged.
    :param message: the received message, forwarded unchanged to the
        parent handler.
    """
    base = super(ResultConsumer, self)
    base.on_state_change(meta, message)
    # The parent handler runs first; only afterwards is the task
    # considered for cancellation.
    self._maybe_cancel_ready_task(meta)

def start(self, initial_task_id, **kwargs):
self._pubsub = self.backend.client.pubsub(
ignore_subscribe_messages=True,
Expand Down
26 changes: 26 additions & 0 deletions celery/contrib/testing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from kombu.utils.functional import retry_over_time

from celery import states
from celery.exceptions import TimeoutError
from celery.five import items
from celery.result import ResultSet
Expand Down Expand Up @@ -145,6 +146,31 @@ def assert_received(self, ids, interval=0.5,
self.is_accepted, ids, interval=interval, desc=desc, **policy
)

def assert_result_tasks_in_progress_or_completed(
    self,
    async_results,
    interval=0.5,
    desc='waiting for tasks to be started or completed',
    **policy
):
    """Assert via ``is_result_task_in_progress`` on the given results.

    Thin wrapper around ``assert_task_state_from_result`` that uses
    ``self.is_result_task_in_progress`` as the predicate.

    :param async_results: result objects handed to the predicate.
    :param interval: forwarded as ``interval``.
    :param desc: forwarded as ``desc``.
    :param policy: any further keyword arguments, forwarded untouched.
    :returns: whatever ``assert_task_state_from_result`` returns.
    """
    predicate = self.is_result_task_in_progress
    # ``interval`` and ``desc`` are named parameters, so they can never
    # also appear in ``policy``; merging keeps the same keyword order.
    options = dict(interval=interval, desc=desc)
    options.update(policy)
    return self.assert_task_state_from_result(
        predicate, async_results, **options
    )

def assert_task_state_from_result(self, fun, results,
                                  interval=0.5, **policy):
    """Hand a ``true_or_raise`` check for ``fun(results)`` to ``wait_for``.

    :param fun: predicate passed to ``self.true_or_raise`` together
        with ``results``.
    :param results: result objects forwarded to the predicate.
    :param interval: passed to ``true_or_raise`` as ``timeout``.
    :param policy: extra keyword arguments forwarded to ``wait_for``.
    :returns: whatever ``self.wait_for`` returns.
    """
    check = partial(self.true_or_raise, fun, results, timeout=interval)
    # Second positional argument of wait_for; presumably the exception
    # types that trigger another attempt -- verify against wait_for.
    catch = (Sentinel,)
    return self.wait_for(check, catch, **policy)

@staticmethod
def is_result_task_in_progress(results, **kwargs):
    """Return True when every result is in state STARTED or SUCCESS.

    :param results: iterable of objects exposing a ``state`` attribute.
    :param kwargs: accepted and ignored.
    :returns: ``True`` for an empty iterable, otherwise ``True`` only if
        no result is in any other state.
    """
    accepted = (states.STARTED, states.SUCCESS)
    for res in results:
        if res.state not in accepted:
            return False
    return True

def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
return self.wait_for(
partial(self.true_or_raise, fun, ids, timeout=interval),
Expand Down
54 changes: 42 additions & 12 deletions t/integration/test_canvas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import, unicode_literals

from datetime import datetime, timedelta
from time import sleep

import pytest

Expand Down Expand Up @@ -257,23 +256,54 @@ def assert_ids(r, expected_value, expected_root_id, expected_parent_id):

class test_chord:

@staticmethod
def _get_active_redis_channels(client):
    """Return the reply of the ``PUBSUB CHANNELS`` command.

    :param client: object exposing ``execute_command``.
    """
    command = 'PUBSUB CHANNELS'
    return client.execute_command(command)

@flaky
def test_redis_subscribed_channels_leak(self, manager):
# Checks that the Redis PubSub channels that appear while chords are
# running are gone again once all chord results have been fetched.
# NOTE(review): indentation is missing here, and this span looks like a
# rendered diff with removed and added lines interleaved -- confirm
# against the repository before reading it as one function body.
if not manager.app.conf.result_backend.startswith('redis'):
raise pytest.skip('Requires redis result backend.')

redis_client = get_redis_connection()
# NOTE(review): from here down to ``assert channels_after <
# channels_before`` the code polls with ``sleep`` and handles a single
# chord; presumably this is the pre-change version of the test that the
# commit removes -- TODO confirm.
async_result = chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
for _ in range(TIMEOUT):
if async_result.state == 'STARTED':
break
sleep(0.2)
channels_before = \
len(redis_client.execute_command('PUBSUB CHANNELS'))
assert async_result.get(timeout=TIMEOUT) == 24
channels_after = \
len(redis_client.execute_command('PUBSUB CHANNELS'))
assert channels_after < channels_before

# Run the result consumer's fork hook before taking the baseline
# snapshot of channels (presumably to start from a clean PubSub and
# connection pool state -- verify).
manager.app.backend.result_consumer.on_after_fork()
initial_channels = self._get_active_redis_channels(redis_client)
initial_channels_count = len(initial_channels)

# Launch several identical chords: header add(5, 6) + add(6, 7),
# body delayed_sum.
total_chords = 10
async_results = [
chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
for _ in range(total_chords)
]

# Wait until every chord result reports STARTED or SUCCESS.
manager.assert_result_tasks_in_progress_or_completed(async_results)

channels_before = self._get_active_redis_channels(redis_client)
channels_before_count = len(channels_before)

assert set(channels_before) != set(initial_channels)
assert channels_before_count > initial_channels_count

# The expected number of active Redis channels at this point is the
# number of header tasks per chord multiplied by the number of
# chords, plus the channels that were already present at the start
# (existing from previous tests).
chord_header_task_count = 2
assert channels_before_count == \
chord_header_task_count * total_chords + initial_channels_count

# Fetching each result blocks until that chord finished (or TIMEOUT).
result_values = [
result.get(timeout=TIMEOUT)
for result in async_results
]
assert result_values == [24] * total_chords

channels_after = self._get_active_redis_channels(redis_client)
channels_after_count = len(channels_after)

# No leak: the channel set is back to exactly the initial one.
assert channels_after_count == initial_channels_count
assert set(channels_after) == set(initial_channels)

@flaky
def test_replaced_nested_chord(self, manager):
Expand Down
48 changes: 48 additions & 0 deletions t/unit/backends/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,54 @@ class sentinel(object):
Sentinel = Sentinel


class test_RedisResultConsumer:
    """Unit tests for the fork and state-change hooks of the consumer."""

    def get_backend(self):
        """Build a ``RedisBackend`` subclass instance for ``self.app``.

        The subclass binds its ``redis`` attribute to the ``redis`` name
        visible from the enclosing module scope.
        """
        from celery.backends.redis import RedisBackend

        class _RedisBackend(RedisBackend):
            redis = redis

        backend_cls = _RedisBackend
        return backend_cls(app=self.app)

    def get_consumer(self):
        """Return the result consumer of a freshly built backend."""
        backend = self.get_backend()
        return backend.result_consumer

    @patch('celery.backends.async.BaseResultConsumer.on_after_fork')
    def test_on_after_fork(self, parent_method):
        """The hook resets the pool, closes PubSub and calls the parent."""
        consumer = self.get_consumer()
        consumer.start('none')
        pool_reset = consumer.backend.client.connection_pool.reset

        # First pass: a PubSub object exists, so it has to be closed too.
        consumer.on_after_fork()
        parent_method.assert_called_once()
        pool_reset.assert_called_once()
        consumer._pubsub.close.assert_called_once()

        # Second pass: PubSub instance not initialized - an exception
        # would be raised if the hook called .close() on it.
        consumer._pubsub = None
        for mocked in (parent_method, pool_reset):
            mocked.reset_mock()
        consumer.on_after_fork()
        parent_method.assert_called_once()
        pool_reset.assert_called_once()

    @patch('celery.backends.redis.ResultConsumer.cancel_for')
    @patch('celery.backends.async.BaseResultConsumer.on_state_change')
    def test_on_state_change(self, parent_method, cancel_for):
        """``cancel_for`` is called for SUCCESS but not for PENDING."""
        consumer = self.get_consumer()
        message = 'hello'

        # A ready state: parent handler runs and the task is cancelled.
        ready_meta = {'task_id': 'testing', 'status': states.SUCCESS}
        consumer.on_state_change(ready_meta, message)
        parent_method.assert_called_once_with(ready_meta, message)
        cancel_for.assert_called_once_with(ready_meta['task_id'])

        parent_method.reset_mock()
        cancel_for.reset_mock()

        # Does not call cancel_for for other states
        pending_meta = {'task_id': 'testing2', 'status': states.PENDING}
        consumer.on_state_change(pending_meta, message)
        parent_method.assert_called_once_with(pending_meta, message)
        cancel_for.assert_not_called()


class test_RedisBackend:
def get_backend(self):
from celery.backends.redis import RedisBackend
Expand Down

0 comments on commit a035680

Please sign in to comment.