Skip to content

Commit

Permalink
Tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
ask committed May 21, 2012
1 parent 632d80f commit 73ea354
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 80 deletions.
33 changes: 33 additions & 0 deletions celery/app/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,21 @@ def update_state(self, task_id=None, state=None, meta=None):
task_id = self.request.id
self.backend.store_result(task_id, meta, state)

def on_success(self, retval, task_id, args, kwargs):
"""Success handler.
Run by the worker if the task executes successfully.
:param retval: The return value of the task.
:param task_id: Unique id of the executed task.
:param args: Original arguments for the executed task.
:param kwargs: Original keyword arguments for the executed task.
The return value of this handler is ignored.
"""
pass

def on_retry(self, exc, task_id, args, kwargs, einfo):
"""Retry handler.
Expand Down Expand Up @@ -815,6 +830,24 @@ def on_failure(self, exc, task_id, args, kwargs, einfo):
"""
pass

def after_return(self, status, retval, task_id, args, kwargs, einfo):
"""Handler called after the task returns.
:param status: Current task state.
:param retval: Task return value/exception.
:param task_id: Unique id of the task.
:param args: Original arguments for the task that failed.
:param kwargs: Original keyword arguments for the task
that failed.
:keyword einfo: :class:`~celery.datastructures.ExceptionInfo`
instance, containing the traceback (if any).
The return value of this handler is ignored.
"""
pass

def send_error_email(self, context, exc, **kwargs):
if self.send_error_emails and not self.disable_error_emails:
self.ErrorMail(self, **kwargs).send(context, exc)
Expand Down
16 changes: 10 additions & 6 deletions celery/task/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def mro_lookup(cls, attr, stop=()):
return node


def defines_custom_call(task):
def task_has_custom(task, attr):
"""Returns true if the task or one of its bases
defines __call__ (excluding the one in BaseTask)."""
return mro_lookup(task.__class__, "__call__", stop=(BaseTask, object))
defines ``attr`` (excluding the one in BaseTask)."""
return mro_lookup(task.__class__, attr, stop=(BaseTask, object))


class TraceInfo(object):
Expand Down Expand Up @@ -157,7 +157,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
# If the task doesn't define a custom __call__ method
# we optimize it away by simply calling the run method directly,
# saving the extra method call and a line less in the stack trace.
fun = task if defines_custom_call(task) else task.run
fun = task if task_has_custom(task, "__call__") else task.run

loader = loader or current_app.loader
backend = task.backend
Expand All @@ -170,8 +170,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
loader_task_init = loader.on_task_init
loader_cleanup = loader.on_process_cleanup

task_on_success = getattr(task, "on_success", None)
task_after_return = getattr(task, "after_return", None)
task_on_success = None
task_after_return = None
if task_has_custom(task, "on_success"):
task_on_success = task.on_success
if task_has_custom(task, "after_return"):
task_after_return = task.after_return

store_result = backend.store_result
backend_cleanup = backend.process_cleanup
Expand Down
178 changes: 128 additions & 50 deletions celery/tests/bin/test_celeryd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from celery import current_app
from celery.apps import worker as cd
from celery.bin.celeryd import WorkerCommand, main as celeryd_main
from celery.exceptions import ImproperlyConfigured
from celery.exceptions import ImproperlyConfigured, SystemTerminate
from celery.utils.log import ensure_process_aware_logger
from celery.worker import state

Expand All @@ -32,12 +32,17 @@ def disable_stdouts(fun):

@wraps(fun)
def disable(*args, **kwargs):
sys.stdout, sys.stderr = WhateverIO(), WhateverIO()
prev_out, prev_err = sys.stdout, sys.stderr
prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
sys.stdout = sys.__stdout__ = WhateverIO()
sys.stderr = sys.__stderr__ = WhateverIO()
try:
return fun(*args, **kwargs)
finally:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
sys.stdout = prev_out
sys.stderr = prev_err
sys.__stdout__ = prev_rout
sys.__stderr__ = prev_rerr

return disable

Expand All @@ -58,6 +63,9 @@ class Worker(cd.Worker):
class test_Worker(AppCase):
Worker = Worker

def teardown(self):
self.app.conf.CELERY_INCLUDE = ()

@disable_stdouts
def test_queues_string(self):
celery = Celery(set_as_current=False)
Expand Down Expand Up @@ -402,19 +410,33 @@ class Signals(platforms.Signals):
def __setitem__(self, sig, handler):
next_handlers[sig] = handler

p, platforms.signals = platforms.signals, Signals()
try:
handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_stop)
finally:
platforms.signals = p
state.should_stop = False
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
p, platforms.signals = platforms.signals, Signals()
try:
handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_stop)
finally:
platforms.signals = p
state.should_stop = False

try:
next_handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_terminate)
finally:
state.should_terminate = False
try:
next_handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_terminate)
finally:
state.should_terminate = False

with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
p, platforms.signals = platforms.signals, Signals()
try:
with self.assertRaises(SystemExit):
handlers["SIGINT"]("SIGINT", object())
finally:
platforms.signals = p

with self.assertRaises(SystemTerminate):
next_handlers["SIGINT"]("SIGINT", object())

@disable_stdouts
def test_worker_int_handler_only_stop_MainProcess(self):
Expand All @@ -424,14 +446,27 @@ def test_worker_int_handler_only_stop_MainProcess(self):
raise SkipTest("only relevant for multiprocessing")
process = current_process()
name, process.name = process.name, "OtherProcess"
try:
worker = self._Worker()
handlers = self.psig(cd.install_worker_int_handler, worker)
handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_stop)
finally:
process.name = name
state.should_stop = False
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
try:
worker = self._Worker()
handlers = self.psig(cd.install_worker_int_handler, worker)
handlers["SIGINT"]("SIGINT", object())
self.assertTrue(state.should_stop)
finally:
process.name = name
state.should_stop = False

with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
try:
worker = self._Worker()
handlers = self.psig(cd.install_worker_int_handler, worker)
with self.assertRaises(SystemExit):
handlers["SIGINT"]("SIGINT", object())
finally:
process.name = name
state.should_stop = False

@disable_stdouts
def test_install_HUP_not_supported_handler(self):
Expand All @@ -448,25 +483,49 @@ def test_worker_term_hard_handler_only_stop_MainProcess(self):
process = current_process()
name, process.name = process.name, "OtherProcess"
try:
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
worker = self._Worker()
handlers = self.psig(
cd.install_worker_term_hard_handler, worker)
try:
handlers["SIGQUIT"]("SIGQUIT", object())
self.assertTrue(state.should_terminate)
finally:
state.should_terminate = False
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
worker = self._Worker()
handlers = self.psig(
cd.install_worker_term_hard_handler, worker)
with self.assertRaises(SystemTerminate):
handlers["SIGQUIT"]("SIGQUIT", object())
finally:
process.name = name

@disable_stdouts
def test_worker_term_handler_when_threads(self):
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_hard_handler, worker)
handlers = self.psig(cd.install_worker_term_handler, worker)
try:
handlers["SIGQUIT"]("SIGQUIT", object())
self.assertTrue(state.should_terminate)
handlers["SIGTERM"]("SIGTERM", object())
self.assertTrue(state.should_stop)
finally:
state.should_terminate = False
finally:
process.name = name
state.should_stop = False

@disable_stdouts
def test_worker_term_handler(self):
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_handler, worker)
try:
handlers["SIGTERM"]("SIGTERM", object())
self.assertTrue(state.should_stop)
finally:
state.should_stop = False
def test_worker_term_handler_when_single_thread(self):
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_handler, worker)
try:
with self.assertRaises(SystemExit):
handlers["SIGTERM"]("SIGTERM", object())
finally:
state.should_stop = False

@patch("sys.__stderr__")
def test_worker_cry_handler(self, stderr):
Expand All @@ -490,10 +549,18 @@ def test_worker_term_handler_only_stop_MainProcess(self):
process = current_process()
name, process.name = process.name, "OtherProcess"
try:
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_handler, worker)
handlers["SIGTERM"]("SIGTERM", object())
self.assertTrue(state.should_stop)
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_handler, worker)
handlers["SIGTERM"]("SIGTERM", object())
self.assertTrue(state.should_stop)
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_handler, worker)
with self.assertRaises(SystemExit):
handlers["SIGTERM"]("SIGTERM", object())
finally:
process.name = name
state.should_stop = False
Expand Down Expand Up @@ -521,11 +588,22 @@ def _execv(*args):
state.should_stop = False

@disable_stdouts
def test_worker_term_hard_handler(self):
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_hard_handler, worker)
try:
handlers["SIGQUIT"]("SIGQUIT", object())
self.assertTrue(state.should_terminate)
finally:
state.should_terminate = False
def test_worker_term_hard_handler_when_threaded(self):
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 3
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_hard_handler, worker)
try:
handlers["SIGQUIT"]("SIGQUIT", object())
self.assertTrue(state.should_terminate)
finally:
state.should_terminate = False

@disable_stdouts
def test_worker_term_hard_handler_when_single_threaded(self):
with patch("celery.apps.worker.active_thread_count") as c:
c.return_value = 1
worker = self._Worker()
handlers = self.psig(cd.install_worker_term_hard_handler, worker)
with self.assertRaises(SystemTerminate):
handlers["SIGQUIT"]("SIGQUIT", object())
50 changes: 27 additions & 23 deletions celery/tests/utilities/test_timer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ class test_Timer(Case):
@skip_if_quick
def test_enter_after(self):
t = timer2.Timer()
done = [False]
try:
done = [False]

def set_done():
done[0] = True
def set_done():
done[0] = True

try:
t.apply_after(300, set_done)
while not done[0]:
time.sleep(0.1)
Expand All @@ -88,25 +88,29 @@ def test_exit_after(self):

def test_apply_interval(self):
t = timer2.Timer()
t.schedule.enter_after = Mock()

myfun = Mock()
t.apply_interval(30, myfun)

self.assertEqual(t.schedule.enter_after.call_count, 1)
args1, _ = t.schedule.enter_after.call_args_list[0]
msec1, tref1, _ = args1
self.assertEqual(msec1, 30)
tref1()

self.assertEqual(t.schedule.enter_after.call_count, 2)
args2, _ = t.schedule.enter_after.call_args_list[1]
msec2, tref2, _ = args2
self.assertEqual(msec2, 30)
tref2.cancelled = True
tref2()

self.assertEqual(t.schedule.enter_after.call_count, 2)
try:
t.schedule.enter_after = Mock()

myfun = Mock()
myfun.__name__ = "myfun"
t.apply_interval(30, myfun)

self.assertEqual(t.schedule.enter_after.call_count, 1)
args1, _ = t.schedule.enter_after.call_args_list[0]
msec1, tref1, _ = args1
self.assertEqual(msec1, 30)
tref1()

self.assertEqual(t.schedule.enter_after.call_count, 2)
args2, _ = t.schedule.enter_after.call_args_list[1]
msec2, tref2, _ = args2
self.assertEqual(msec2, 30)
tref2.cancelled = True
tref2()

self.assertEqual(t.schedule.enter_after.call_count, 2)
finally:
t.stop()

@patch("celery.utils.timer2.logger")
def test_apply_entry_error_handled(self, logger):
Expand Down
Loading

0 comments on commit 73ea354

Please sign in to comment.