diff --git a/celery/app/task.py b/celery/app/task.py index c019be2d330..b33d9a87125 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -778,6 +778,21 @@ def update_state(self, task_id=None, state=None, meta=None): task_id = self.request.id self.backend.store_result(task_id, meta, state) + def on_success(self, retval, task_id, args, kwargs): + """Success handler. + + Run by the worker if the task executes successfully. + + :param retval: The return value of the task. + :param task_id: Unique id of the executed task. + :param args: Original arguments for the executed task. + :param kwargs: Original keyword arguments for the executed task. + + The return value of this handler is ignored. + + """ + pass + def on_retry(self, exc, task_id, args, kwargs, einfo): """Retry handler. @@ -815,6 +830,24 @@ def on_failure(self, exc, task_id, args, kwargs, einfo): """ pass + def after_return(self, status, retval, task_id, args, kwargs, einfo): + """Handler called after the task returns. + + :param status: Current task state. + :param retval: Task return value/exception. + :param task_id: Unique id of the task. + :param args: Original arguments for the task that failed. + :param kwargs: Original keyword arguments for the task + that failed. + + :keyword einfo: :class:`~celery.datastructures.ExceptionInfo` + instance, containing the traceback (if any). + + The return value of this handler is ignored. + + """ + pass + def send_error_email(self, context, exc, **kwargs): if self.send_error_emails and not self.disable_error_emails: self.ErrorMail(self, **kwargs).send(context, exc) diff --git a/celery/task/trace.py b/celery/task/trace.py index aa3c0b867e0..356d3feb556 100644 --- a/celery/task/trace.py +++ b/celery/task/trace.py @@ -65,10 +65,10 @@ def mro_lookup(cls, attr, stop=()): return node -def defines_custom_call(task): +def task_has_custom(task, attr): """Returns true if the task or one of its bases - defines __call__ (excluding the one in BaseTask).""" - return mro_lookup(task.__class__, "__call__", stop=(BaseTask, object)) + defines ``attr`` (excluding the one in BaseTask).""" + return mro_lookup(task.__class__, attr, stop=(BaseTask, object)) class TraceInfo(object): @@ -157,7 +157,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True, # If the task doesn't define a custom __call__ method # we optimize it away by simply calling the run method directly, # saving the extra method call and a line less in the stack trace. - fun = task if defines_custom_call(task) else task.run + fun = task if task_has_custom(task, "__call__") else task.run loader = loader or current_app.loader backend = task.backend @@ -170,8 +170,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True, loader_task_init = loader.on_task_init loader_cleanup = loader.on_process_cleanup - task_on_success = getattr(task, "on_success", None) - task_after_return = getattr(task, "after_return", None) + task_on_success = None + task_after_return = None + if task_has_custom(task, "on_success"): + task_on_success = task.on_success + if task_has_custom(task, "after_return"): + task_after_return = task.after_return store_result = backend.store_result backend_cleanup = backend.process_cleanup diff --git a/celery/tests/bin/test_celeryd.py b/celery/tests/bin/test_celeryd.py index 46a802fa4a3..d3740cdbf92 100644 --- a/celery/tests/bin/test_celeryd.py +++ b/celery/tests/bin/test_celeryd.py @@ -19,7 +19,7 @@ from celery import current_app from celery.apps import worker as cd from celery.bin.celeryd import WorkerCommand, main as celeryd_main -from celery.exceptions import ImproperlyConfigured +from celery.exceptions import ImproperlyConfigured, SystemTerminate from celery.utils.log import ensure_process_aware_logger from celery.worker import state @@ -32,12 +32,17 @@ def disable_stdouts(fun): @wraps(fun) def disable(*args, **kwargs): - sys.stdout, sys.stderr = WhateverIO(), WhateverIO() + prev_out, prev_err = sys.stdout, sys.stderr + prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__ + sys.stdout = sys.__stdout__ = WhateverIO() + sys.stderr = sys.__stderr__ = WhateverIO() try: return fun(*args, **kwargs) finally: - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ + sys.stdout = prev_out + sys.stderr = prev_err + sys.__stdout__ = prev_rout + sys.__stderr__ = prev_rerr return disable @@ -58,6 +63,9 @@ class Worker(cd.Worker): class test_Worker(AppCase): Worker = Worker + def teardown(self): + self.app.conf.CELERY_INCLUDE = () + @disable_stdouts def test_queues_string(self): celery = Celery(set_as_current=False) @@ -402,19 +410,33 @@ class Signals(platforms.Signals): def __setitem__(self, sig, handler): next_handlers[sig] = handler - p, platforms.signals = platforms.signals, Signals() - try: - handlers["SIGINT"]("SIGINT", object()) - self.assertTrue(state.should_stop) - finally: - platforms.signals = p - state.should_stop = False + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 + p, platforms.signals = platforms.signals, Signals() + try: + handlers["SIGINT"]("SIGINT", object()) + self.assertTrue(state.should_stop) + finally: + platforms.signals = p + state.should_stop = False - try: - next_handlers["SIGINT"]("SIGINT", object()) - self.assertTrue(state.should_terminate) - finally: - state.should_terminate = False + try: + next_handlers["SIGINT"]("SIGINT", object()) + self.assertTrue(state.should_terminate) + finally: + state.should_terminate = False + + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + p, platforms.signals = platforms.signals, Signals() + try: + with self.assertRaises(SystemExit): + handlers["SIGINT"]("SIGINT", object()) + finally: + platforms.signals = p + + with self.assertRaises(SystemTerminate): + next_handlers["SIGINT"]("SIGINT", object()) @disable_stdouts def test_worker_int_handler_only_stop_MainProcess(self): @@ -424,14 +446,27 @@ def test_worker_int_handler_only_stop_MainProcess(self): raise SkipTest("only relevant for multiprocessing") process = current_process() name, process.name = process.name, "OtherProcess" - try: - worker = self._Worker() - handlers = self.psig(cd.install_worker_int_handler, worker) - handlers["SIGINT"]("SIGINT", object()) - self.assertTrue(state.should_stop) - finally: - process.name = name - state.should_stop = False + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 + try: + worker = self._Worker() + handlers = self.psig(cd.install_worker_int_handler, worker) + handlers["SIGINT"]("SIGINT", object()) + self.assertTrue(state.should_stop) + finally: + process.name = name + state.should_stop = False + + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + try: + worker = self._Worker() + handlers = self.psig(cd.install_worker_int_handler, worker) + with self.assertRaises(SystemExit): + handlers["SIGINT"]("SIGINT", object()) + finally: + process.name = name + state.should_stop = False @disable_stdouts def test_install_HUP_not_supported_handler(self): @@ -448,25 +483,49 @@ def test_worker_term_hard_handler_only_stop_MainProcess(self): process = current_process() name, process.name = process.name, "OtherProcess" try: + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 + worker = self._Worker() + handlers = self.psig( + cd.install_worker_term_hard_handler, worker) + try: + handlers["SIGQUIT"]("SIGQUIT", object()) + self.assertTrue(state.should_terminate) + finally: + state.should_terminate = False + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + worker = self._Worker() + handlers = self.psig( + cd.install_worker_term_hard_handler, worker) + with self.assertRaises(SystemTerminate): + handlers["SIGQUIT"]("SIGQUIT", object()) + finally: + process.name = name + + @disable_stdouts + def test_worker_term_handler_when_threads(self): + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 worker = self._Worker() - handlers = self.psig(cd.install_worker_term_hard_handler, worker) + handlers = self.psig(cd.install_worker_term_handler, worker) try: - handlers["SIGQUIT"]("SIGQUIT", object()) - self.assertTrue(state.should_terminate) + handlers["SIGTERM"]("SIGTERM", object()) + self.assertTrue(state.should_stop) finally: - state.should_terminate = False - finally: - process.name = name + state.should_stop = False @disable_stdouts - def test_worker_term_handler(self): - worker = self._Worker() - handlers = self.psig(cd.install_worker_term_handler, worker) - try: - handlers["SIGTERM"]("SIGTERM", object()) - self.assertTrue(state.should_stop) - finally: - state.should_stop = False + def test_worker_term_handler_when_single_thread(self): + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + worker = self._Worker() + handlers = self.psig(cd.install_worker_term_handler, worker) + try: + with self.assertRaises(SystemExit): + handlers["SIGTERM"]("SIGTERM", object()) + finally: + state.should_stop = False @patch("sys.__stderr__") def test_worker_cry_handler(self, stderr): @@ -490,10 +549,18 @@ def test_worker_term_handler_only_stop_MainProcess(self): process = current_process() name, process.name = process.name, "OtherProcess" try: - worker = self._Worker() - handlers = self.psig(cd.install_worker_term_handler, worker) - handlers["SIGTERM"]("SIGTERM", object()) - self.assertTrue(state.should_stop) + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 + worker = self._Worker() + handlers = self.psig(cd.install_worker_term_handler, worker) + handlers["SIGTERM"]("SIGTERM", object()) + self.assertTrue(state.should_stop) + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + worker = self._Worker() + handlers = self.psig(cd.install_worker_term_handler, worker) + with self.assertRaises(SystemExit): + handlers["SIGTERM"]("SIGTERM", object()) finally: process.name = name state.should_stop = False @@ -521,11 +588,22 @@ def _execv(*args): state.should_stop = False @disable_stdouts - def test_worker_term_hard_handler(self): - worker = self._Worker() - handlers = self.psig(cd.install_worker_term_hard_handler, worker) - try: - handlers["SIGQUIT"]("SIGQUIT", object()) - self.assertTrue(state.should_terminate) - finally: - state.should_terminate = False + def test_worker_term_hard_handler_when_threaded(self): + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 3 + worker = self._Worker() + handlers = self.psig(cd.install_worker_term_hard_handler, worker) + try: + handlers["SIGQUIT"]("SIGQUIT", object()) + self.assertTrue(state.should_terminate) + finally: + state.should_terminate = False + + @disable_stdouts + def test_worker_term_hard_handler_when_single_threaded(self): + with patch("celery.apps.worker.active_thread_count") as c: + c.return_value = 1 + worker = self._Worker() + handlers = self.psig(cd.install_worker_term_hard_handler, worker) + with self.assertRaises(SystemTerminate): + handlers["SIGQUIT"]("SIGQUIT", object()) diff --git a/celery/tests/utilities/test_timer2.py b/celery/tests/utilities/test_timer2.py index 01487a8063e..0e64e2cba2c 100644 --- a/celery/tests/utilities/test_timer2.py +++ b/celery/tests/utilities/test_timer2.py @@ -68,12 +68,12 @@ class test_Timer(Case): @skip_if_quick def test_enter_after(self): t = timer2.Timer() - done = [False] + try: + done = [False] - def set_done(): - done[0] = True + def set_done(): + done[0] = True - try: t.apply_after(300, set_done) while not done[0]: time.sleep(0.1) @@ -88,25 +88,29 @@ def test_exit_after(self): def test_apply_interval(self): t = timer2.Timer() - t.schedule.enter_after = Mock() - - myfun = Mock() - t.apply_interval(30, myfun) - - self.assertEqual(t.schedule.enter_after.call_count, 1) - args1, _ = t.schedule.enter_after.call_args_list[0] - msec1, tref1, _ = args1 - self.assertEqual(msec1, 30) - tref1() - - self.assertEqual(t.schedule.enter_after.call_count, 2) - args2, _ = t.schedule.enter_after.call_args_list[1] - msec2, tref2, _ = args2 - self.assertEqual(msec2, 30) - tref2.cancelled = True - tref2() - - self.assertEqual(t.schedule.enter_after.call_count, 2) + try: + t.schedule.enter_after = Mock() + + myfun = Mock() + myfun.__name__ = "myfun" + t.apply_interval(30, myfun) + + self.assertEqual(t.schedule.enter_after.call_count, 1) + args1, _ = t.schedule.enter_after.call_args_list[0] + msec1, tref1, _ = args1 + self.assertEqual(msec1, 30) + tref1() + + self.assertEqual(t.schedule.enter_after.call_count, 2) + args2, _ = t.schedule.enter_after.call_args_list[1] + msec2, tref2, _ = args2 + self.assertEqual(msec2, 30) + tref2.cancelled = True + tref2() + + self.assertEqual(t.schedule.enter_after.call_count, 2) + finally: + t.stop() @patch("celery.utils.timer2.logger") def test_apply_entry_error_handled(self, logger): diff --git a/celery/utils/timer2.py b/celery/utils/timer2.py index f206d8c3b3d..d8bcc7b2cfc 100644 --- a/celery/utils/timer2.py +++ b/celery/utils/timer2.py @@ -33,6 +33,7 @@ __docformat__ = "restructuredtext" DEFAULT_MAX_INTERVAL = 2 +TIMER_DEBUG = os.environ.get("TIMER_DEBUG") logger = get_logger("timer2") @@ -215,6 +216,13 @@ class Timer(Thread): on_tick = None _timer_count = count(1).next + if TIMER_DEBUG: + def start(self, *args, **kwargs): + import traceback + print("TIMER START") + traceback.print_stack() + super(Timer, self).start(*args, **kwargs) + def __init__(self, schedule=None, on_error=None, on_tick=None, max_interval=None, **kwargs): self.schedule = schedule or self.Schedule(on_error=on_error, diff --git a/celery/worker/job.py b/celery/worker/job.py index 01c845c804d..d4ad9a9f39f 100644 --- a/celery/worker/job.py +++ b/celery/worker/job.py @@ -419,7 +419,7 @@ def _log_error(self, einfo): "hostname": self.hostname, "internal": internal}}) - self.task.send_error_email(context, exc_info.exception) + self.task.send_error_email(context, einfo.exception) def acknowledge(self): """Acknowledge task."""