Skip to content

Commit

Permalink
add_callback now takes *args, **kwargs.
Browse files Browse the repository at this point in the history
This reduces the need for functools.partial or lambda wrappers, and
works better with stack_context in some cases since binding the
arguments within IOLoop lets it see whether the function is already
wrapped.
  • Loading branch information
bdarnell committed Dec 8, 2012
1 parent 6cf1fa1 commit ea79e8a
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 12 deletions.
16 changes: 9 additions & 7 deletions tornado/ioloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def remove_timeout(self, timeout):
"""
raise NotImplementedError()

def add_callback(self, callback):
def add_callback(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
Expand All @@ -335,7 +335,7 @@ def add_callback(self, callback):
"""
raise NotImplementedError()

def add_callback_from_signal(self, callback):
def add_callback_from_signal(self, callback, *args, **kwargs):
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
Expand Down Expand Up @@ -609,12 +609,13 @@ def remove_timeout(self, timeout):
# collection pass whenever there are too many dead timeouts.
timeout.callback = None

def add_callback(self, callback):
def add_callback(self, callback, *args, **kwargs):
with self._callback_lock:
if self._closing:
raise RuntimeError("IOLoop is closing")
list_empty = not self._callbacks
self._callbacks.append(stack_context.wrap(callback))
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty and thread.get_ident() != self._thread_ident:
# If we're in the IOLoop's thread, we know it's not currently
# polling. If we're not, and we added the first callback to an
Expand All @@ -624,12 +625,12 @@ def add_callback(self, callback):
# avoid it when we can.
self._waker.wake()

def add_callback_from_signal(self, callback):
def add_callback_from_signal(self, callback, *args, **kwargs):
with stack_context.NullContext():
if thread.get_ident() != self._thread_ident:
# if the signal is handled on another thread, we can add
# it normally (modulo the NullContext)
self.add_callback(callback)
self.add_callback(callback, *args, **kwargs)
else:
# If we're on the IOLoop's thread, we cannot use
# the regular add_callback because it may deadlock on
Expand All @@ -639,7 +640,8 @@ def add_callback_from_signal(self, callback):
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks,
# but either way will work.
self._callbacks.append(stack_context.wrap(callback))
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))


class _Timeout(object):
Expand Down
9 changes: 5 additions & 4 deletions tornado/platform/twisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,9 @@ def add_timeout(self, deadline, callback):
def remove_timeout(self, timeout):
timeout.cancel()

def add_callback(self, callback):
self.reactor.callFromThread(wrap(callback))
def add_callback(self, callback, *args, **kwargs):
self.reactor.callFromThread(functools.partial(wrap(callback),
*args, **kwargs))

def add_callback_from_signal(self, callback):
self.add_callback(callback)
def add_callback_from_signal(self, callback, *args, **kwargs):
self.add_callback(callback, *args, **kwargs)
61 changes: 60 additions & 1 deletion tornado/test/ioloop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@


from __future__ import absolute_import, division, with_statement
import contextlib
import datetime
import functools
import threading
import time

from tornado.ioloop import IOLoop
from tornado.stack_context import ExceptionStackContext
from tornado.stack_context import ExceptionStackContext, StackContext, wrap
from tornado.testing import AsyncTestCase, bind_unused_port
from tornado.test.util import unittest

Expand Down Expand Up @@ -111,6 +113,63 @@ def target():
self.assertEqual("IOLoop is closing", str(e))
break


class TestIOLoopAddCallback(AsyncTestCase):
def setUp(self):
super(TestIOLoopAddCallback, self).setUp()
self.active_contexts = []

def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback(callback, *args, **kwargs)

@contextlib.contextmanager
def context(self, name):
self.active_contexts.append(name)
yield
self.assertEqual(self.active_contexts.pop(), name)

def test_pre_wrap(self):
# A pre-wrapped callback is run in the context in which it was
# wrapped, not when it was added to the IOLoop.
def f1():
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop()

with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)

with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped)

self.wait()

def test_pre_wrap_with_args(self):
# Same as test_pre_wrap, but the function takes arguments.
# Implementation note: The function must not be wrapped in a
# functools.partial until after it has been passed through
# stack_context.wrap
def f1(foo, bar):
self.assertIn('c1', self.active_contexts)
self.assertNotIn('c2', self.active_contexts)
self.stop((foo, bar))

with StackContext(functools.partial(self.context, 'c1')):
wrapped = wrap(f1)

with StackContext(functools.partial(self.context, 'c2')):
self.add_callback(wrapped, 1, bar=2)

result = self.wait()
self.assertEqual(result, (1, 2))


class TestIOLoopAddCallbackFromSignal(TestIOLoopAddCallback):
# Repeat the add_callback tests using add_callback_from_signal
def add_callback(self, callback, *args, **kwargs):
self.io_loop.add_callback_from_signal(callback, *args, **kwargs)


class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
Expand Down
1 change: 1 addition & 0 deletions tornado/test/stack_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,6 @@ def f3():
self.io_loop.add_callback(f1)
self.wait()


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions website/sphinx/releases/next.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,5 @@ In progress
that are passed to the ``get``/``post``/etc method. These attributes
are set before those methods are called, so they are available during
``prepare()``
* `IOLoop.add_callback` and `add_callback_from_signal` now take
``*args, **kwargs`` to pass along to the callback.

0 comments on commit ea79e8a

Please sign in to comment.