Skip to content

Commit

Permalink
util: add spawn_pinned (tokio-rs#3370)
Browse files Browse the repository at this point in the history
  • Loading branch information
AzureMarker authored Jan 27, 2022
1 parent 5af9e0d commit 257053e
Show file tree
Hide file tree
Showing 5 changed files with 506 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tokio-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ codec = []
time = ["tokio/time","slab"]
io = []
io-util = ["io", "tokio/rt", "tokio/io-util"]
rt = ["tokio/rt"]
rt = ["tokio/rt", "tokio/sync", "futures-util"]

__docs_rs = ["futures-util"]

Expand Down
1 change: 1 addition & 0 deletions tokio-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cfg_io! {

cfg_rt! {
pub mod context;
pub mod task;
}

cfg_time! {
Expand Down
4 changes: 4 additions & 0 deletions tokio-util/src/task/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
//! Extra utilities for spawning tasks
mod spawn_pinned;
pub use spawn_pinned::LocalPoolHandle;
307 changes: 307 additions & 0 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// A handle to a local pool, used for spawning `!Send` tasks.
#[derive(Clone)]
pub struct LocalPoolHandle {
pool: Arc<LocalPool>,
}

impl LocalPoolHandle {
/// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
/// pool via [`LocalPoolHandle::spawn_pinned`].
///
/// # Panics
/// Panics if the pool size is less than one.
pub fn new(pool_size: usize) -> LocalPoolHandle {
assert!(pool_size > 0);

let workers = (0..pool_size)
.map(|_| LocalWorkerHandle::new_worker())
.collect();

let pool = Arc::new(LocalPool { workers });

LocalPoolHandle { pool }
}

/// Spawn a task onto a worker thread and pin it there so it can't be moved
/// off of the thread. Note that the future is not [`Send`], but the
/// [`FnOnce`] which creates it is.
///
/// # Examples
/// ```
/// use std::rc::Rc;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// // Create the local pool
/// let pool = LocalPoolHandle::new(1);
///
/// // Spawn a !Send future onto the pool and await it
/// let output = pool
/// .spawn_pinned(|| {
/// // Rc is !Send + !Sync
/// let local_data = Rc::new("test");
///
/// // This future holds an Rc, so it is !Send
/// async move { local_data.to_string() }
/// })
/// .await
/// .unwrap();
///
/// assert_eq!(output, "test");
/// }
/// ```
pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
self.pool.spawn_pinned(create_task)
}
}

impl Debug for LocalPoolHandle {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("LocalPoolHandle")
}
}

struct LocalPool {
workers: Vec<LocalWorkerHandle>,
}

impl LocalPool {
/// Spawn a `?Send` future onto a worker
fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
let (sender, receiver) = oneshot::channel();

let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
let worker_spawner = worker.spawner.clone();

// Spawn a future onto the worker's runtime so we can immediately return
// a join handle.
worker.runtime_handle.spawn(async move {
// Move the job guard into the task
let _job_guard = job_guard;

// Propagate aborts via Abortable/AbortHandle
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let _abort_guard = AbortGuard(abort_handle);

// Inside the future we can't run spawn_local yet because we're not
// in the context of a LocalSet. We need to send create_task to the
// LocalSet task for spawning.
let spawn_task = Box::new(move || {
// Once we're in the LocalSet context we can call spawn_local
let join_handle =
spawn_local(
async move { Abortable::new(create_task(), abort_registration).await },
);

// Send the join handle back to the spawner. If sending fails,
// we assume the parent task was canceled, so cancel this task
// as well.
if let Err(join_handle) = sender.send(join_handle) {
join_handle.abort()
}
});

// Send the callback to the LocalSet task
if let Err(e) = worker_spawner.send(spawn_task) {
// Propagate the error as a panic in the join handle.
panic!("Failed to send job to worker: {}", e);
}

// Wait for the task's join handle
let join_handle = match receiver.await {
Ok(handle) => handle,
Err(e) => {
// We sent the task successfully, but failed to get its
// join handle... We assume something happened to the worker
// and the task was not spawned. Propagate the error as a
// panic in the join handle.
panic!("Worker failed to send join handle: {}", e);
}
};

// Wait for the task to complete
let join_result = join_handle.await;

match join_result {
Ok(Ok(output)) => output,
Ok(Err(_)) => {
// Pinned task was aborted. But that only happens if this
// task is aborted. So this is an impossible branch.
unreachable!(
"Reaching this branch means this task was previously \
aborted but it continued running anyways"
)
}
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
} else if e.is_cancelled() {
// No one else should have the join handle, so this is
// unexpected. Forward this error as a panic in the join
// handle.
panic!("spawn_pinned task was canceled: {}", e);
} else {
// Something unknown happened (not a panic or
// cancellation). Forward this error as a panic in the
// join handle.
panic!("spawn_pinned task failed: {}", e);
}
}
}
})
}

/// Find the worker with the least number of tasks, increment its task
/// count, and return its handle. Make sure to actually spawn a task on
/// the worker so the task count is kept consistent with load.
///
/// A job count guard is also returned to ensure the task count gets
/// decremented when the job is done.
fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
loop {
let (worker, task_count) = self
.workers
.iter()
.map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
.min_by_key(|&(_, count)| count)
.expect("There must be more than one worker");

// Make sure the task count hasn't changed since when we choose this
// worker. Otherwise, restart the search.
if worker
.task_count
.compare_exchange(
task_count,
task_count + 1,
Ordering::SeqCst,
Ordering::Relaxed,
)
.is_ok()
{
return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
}
}
}
}

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobCountGuard {
fn drop(&mut self) {
// Decrement the job count
let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
debug_assert!(previous_value >= 1);
}
}

/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl Drop for AbortGuard {
fn drop(&mut self) {
self.0.abort();
}
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
runtime_handle: tokio::runtime::Handle,
spawner: UnboundedSender<PinnedFutureSpawner>,
task_count: Arc<AtomicUsize>,
}

impl LocalWorkerHandle {
/// Create a new worker for executing pinned tasks
fn new_worker() -> LocalWorkerHandle {
let (sender, receiver) = unbounded_channel();
let runtime = Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to start a pinned worker thread runtime");
let runtime_handle = runtime.handle().clone();
let task_count = Arc::new(AtomicUsize::new(0));
let task_count_clone = Arc::clone(&task_count);

std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));

LocalWorkerHandle {
runtime_handle,
spawner: sender,
task_count,
}
}

fn run(
runtime: tokio::runtime::Runtime,
mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
task_count: Arc<AtomicUsize>,
) {
let local_set = LocalSet::new();
local_set.block_on(&runtime, async {
while let Some(spawn_task) = task_receiver.recv().await {
// Calls spawn_local(future)
(spawn_task)();
}
});

// If there are any tasks on the runtime associated with a LocalSet task
// that has already completed, but whose output has not yet been
// reported, let that task complete.
//
// Since the task_count is decremented when the runtime task exits,
// reading that counter lets us know if any such tasks completed during
// the call to `block_on`.
//
// Tasks on the LocalSet can't complete during this loop since they're
// stored on the LocalSet and we aren't accessing it.
let mut previous_task_count = task_count.load(Ordering::SeqCst);
loop {
// This call will also run tasks spawned on the runtime.
runtime.block_on(tokio::task::yield_now());
let new_task_count = task_count.load(Ordering::SeqCst);
if new_task_count == previous_task_count {
break;
} else {
previous_task_count = new_task_count;
}
}

// It's now no longer possible for a task on the runtime to be
// associated with a LocalSet task that has completed. Drop both the
// LocalSet and runtime to let tasks on the runtime be cancelled if and
// only if they are still on the LocalSet.
//
// Drop the LocalSet task first so that anyone awaiting the runtime
// JoinHandle will see the cancelled error after the LocalSet task
// destructor has completed.
drop(local_set);
drop(runtime);
}
}
Loading

0 comments on commit 257053e

Please sign in to comment.