- use ZRuntime executor instead of futures::executor
- terminate more tasks
- make TaskController::terminate_all[_async] accept timeout duration
DenisBiryukov91 committed Mar 26, 2024
1 parent 9d293a7 commit 9841bea
Showing 16 changed files with 257 additions and 172 deletions.
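Taken together, the bullets above change the shutdown API: termination now runs on the zenoh runtime and takes an explicit timeout, reporting how many tasks failed to stop. Below is a minimal sketch of the intended call pattern; the task body is hypothetical, and the `TaskController` methods are assumed to behave as documented in the `commons/zenoh-task` diff further down.

use std::time::Duration;
use zenoh_task::TaskController;

fn shutdown_example() {
    let controller = TaskController::default();
    // Tasks spawned through the controller should watch its cancellation token
    // so that terminate_all() can make them yield.
    let token = controller.get_cancellation_token();
    controller.spawn(async move {
        token.cancelled().await; // exit promptly once termination is requested
    });
    // Blocks until every task yields or the timeout expires;
    // returns 0 on success, or the number of tasks still running.
    let remaining = controller.terminate_all(Duration::from_secs(10));
    assert_eq!(remaining, 0);
}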
6 changes: 3 additions & 3 deletions ci/valgrind-check/src/queryable_get/bin/z_queryable_get.rs
@@ -34,9 +34,9 @@ async fn main() {
queryable_key_expr.clone(),
query.value().unwrap().clone(),
));
futures::executor::block_on(async move {
query.reply(reply).res().await.unwrap();
})
zenoh_runtime::ZRuntime::Application.block_in_place(
async move { query.reply(reply).res().await.unwrap(); }
);
})
.complete(true)
.res()
1 change: 1 addition & 0 deletions commons/zenoh-task/Cargo.toml
@@ -28,5 +28,6 @@ description = "Internal crate for zenoh."
tokio = { workspace = true, features = ["default", "sync"] }
futures = { workspace = true }
log = { workspace = true }
zenoh-core = { workspace = true }
zenoh-runtime = { workspace = true }
tokio-util = { workspace = true, features = ["rt"] }
89 changes: 71 additions & 18 deletions commons/zenoh-task/src/lib.rs
@@ -24,6 +24,7 @@ use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use zenoh_core::{ResolveFuture, SyncResolve};
use zenoh_runtime::ZRuntime;

#[derive(Clone)]
@@ -102,37 +103,89 @@ impl TaskController {
self.tracker.spawn_on(future.map(|_f| ()), &rt)
}

/// Terminates all previously spawned tasks
/// Attempts to terminate all previously spawned tasks
/// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
/// or [`TaskController::spawn_with_rt()`] can yield in a finite amount of time, either because they will run to completion
/// or due to cancellation of token acquired via [`TaskController::get_cancellation_token()`].
/// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon the next await call).
pub fn terminate_all(&self) {
/// The call blocks until all tasks yield or the timeout duration expires.
/// Returns 0 on success, or the number of non-terminated tasks otherwise.
pub fn terminate_all(&self, timeout: Duration) -> usize {
ResolveFuture::new(async move { self.terminate_all_async(timeout).await }).res_sync()
}

/// Async version of [`TaskController::terminate_all()`].
pub async fn terminate_all_async(&self, timeout: Duration) -> usize {
self.tracker.close();
self.token.cancel();
if tokio::time::timeout(timeout, self.tracker.wait())
.await
.is_err()
{
log::error!("Failed to terminate {} tasks", self.tracker.len());
return self.tracker.len();
}
0
}
}
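As a usage note, the async variant obeys the same contract. A sketch of a shutdown path like the call sites updated in `zenoh-transport` below (the wrapper function is hypothetical):

use std::time::Duration;
use zenoh_task::TaskController;

async fn close(task_controller: &TaskController) {
    // Wait up to 10 s for all tracked tasks to yield.
    let remaining = task_controller
        .terminate_all_async(Duration::from_secs(10))
        .await;
    if remaining != 0 {
        log::warn!("{} tasks outlived the shutdown timeout", remaining);
    }
}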

pub struct TerminatableTask {
handle: JoinHandle<()>,
token: CancellationToken,
}

impl TerminatableTask {
pub fn create_cancellation_token() -> CancellationToken {
CancellationToken::new()
}

/// Spawns a task that can be later terminated by [`TerminatableTask::terminate()`].
/// Prior to the termination attempt, the specified cancellation token will be cancelled.
pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
TerminatableTask {
handle: rt.spawn(future.map(|_f| ())),
token,
}
}

/// Spawns a task that can be later aborted by [`TerminatableTask::terminate()`].
pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let token = CancellationToken::new();
let token2 = token.clone();
let task = async move {
    tokio::select! {
        _ = token2.cancelled() => {},
        _ = future => {}
    }
};

TerminatableTask {
handle: rt.spawn(task),
token,
}
}

/// Attempts to terminate the task.
/// Returns true if the task completed or was aborted within the timeout duration, false otherwise.
pub fn terminate(self, timeout: Duration) -> bool {
    ResolveFuture::new(async move { self.terminate_async(timeout).await }).res_sync()
}

/// Async version of [`TerminatableTask::terminate()`].
pub async fn terminate_async(self, timeout: Duration) -> bool {
self.token.cancel();
if tokio::time::timeout(timeout, self.handle).await.is_err() {
    log::error!("Failed to terminate the task");
    return false;
};
true
}
}
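To illustrate the two spawn flavours added above, here is a hedged sketch (hypothetical future bodies; `ZRuntime::TX` is borrowed from the `publication_cache` change below): `spawn` relies on the caller-supplied token for cooperative shutdown, while `spawn_abortable` races the future against an internally created token so `terminate()` can abort it at the next await point.

use std::time::Duration;
use zenoh_runtime::ZRuntime;
use zenoh_task::TerminatableTask;

fn spawn_examples() {
    // Cooperative variant: the future must watch the token itself.
    let token = TerminatableTask::create_cancellation_token();
    let token2 = token.clone();
    let cooperative = TerminatableTask::spawn(
        ZRuntime::TX,
        async move {
            tokio::select! {
                _ = token2.cancelled() => { /* clean up and exit */ }
                _ = tokio::time::sleep(Duration::from_secs(3600)) => { /* work */ }
            }
        },
        token,
    );

    // Abortable variant: the wrapper selects over its own token and the future.
    let abortable = TerminatableTask::spawn_abortable(ZRuntime::TX, async {
        tokio::time::sleep(Duration::from_secs(3600)).await;
    });

    // Both stop the same way; false means the timeout elapsed first.
    assert!(cooperative.terminate(Duration::from_secs(10)));
    assert!(abortable.terminate(Duration::from_secs(10)));
}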
4 changes: 3 additions & 1 deletion io/zenoh-transport/src/manager.rs
@@ -397,7 +397,9 @@ impl TransportManager {

pub async fn close(&self) {
self.close_unicast().await;
self.task_controller.terminate_all();
self.task_controller
.terminate_all_async(Duration::from_secs(10))
.await;
}

/*************************************/
4 changes: 3 additions & 1 deletion io/zenoh-transport/src/multicast/transport.rs
@@ -184,7 +184,9 @@ impl TransportMulticastInner {
cb.closed();
}

self.task_controller.terminate_all_async().await;
self.task_controller
.terminate_all_async(Duration::from_secs(10))
.await;

Ok(())
}
1 change: 1 addition & 0 deletions zenoh-ext/Cargo.toml
@@ -45,6 +45,7 @@ zenoh-result = { workspace = true }
zenoh-sync = { workspace = true }
zenoh-util = { workspace = true }
zenoh-runtime = { workspace = true }
zenoh-task = { workspace = true }

[dev-dependencies]
clap = { workspace = true, features = ["derive"] }
123 changes: 63 additions & 60 deletions zenoh-ext/src/publication_cache.rs
@@ -11,16 +11,17 @@
// Contributors:
// ZettaScale Zenoh Team, <[email protected]>
//
use flume::{bounded, Sender};
use std::collections::{HashMap, VecDeque};
use std::convert::TryInto;
use std::future::Ready;
use std::time::Duration;
use zenoh::prelude::r#async::*;
use zenoh::queryable::{Query, Queryable};
use zenoh::subscriber::FlumeSubscriber;
use zenoh::SessionRef;
use zenoh_core::{AsyncResolve, Resolvable, SyncResolve};
use zenoh_result::{bail, ZResult};
use zenoh_task::TerminatableTask;
use zenoh_util::core::ResolveFuture;

/// The builder of PublicationCache, allowing to configure it.
@@ -110,7 +111,7 @@ impl<'a> AsyncResolve for PublicationCacheBuilder<'a, '_, '_> {
pub struct PublicationCache<'a> {
local_sub: FlumeSubscriber<'a>,
_queryable: Queryable<'a, flume::Receiver<Query>>,
_stoptx: Sender<bool>,
task: TerminatableTask,
}

impl<'a> PublicationCache<'a> {
@@ -166,58 +167,46 @@ impl<'a> PublicationCache<'a> {
let history = conf.history;

// TODO(yuyuan): use CancellationToken to manage it
let (stoptx, stoprx) = bounded::<bool>(1);
zenoh_runtime::ZRuntime::TX.spawn(async move {
let mut cache: HashMap<OwnedKeyExpr, VecDeque<Sample>> =
HashMap::with_capacity(resources_limit.unwrap_or(32));
let limit = resources_limit.unwrap_or(usize::MAX);
let token = TerminatableTask::create_cancellation_token();
let token2 = token.clone();
let task = TerminatableTask::spawn(
zenoh_runtime::ZRuntime::TX,
async move {
let mut cache: HashMap<OwnedKeyExpr, VecDeque<Sample>> =
HashMap::with_capacity(resources_limit.unwrap_or(32));
let limit = resources_limit.unwrap_or(usize::MAX);
loop {
tokio::select! {
// on publication received by the local subscriber, store it
sample = sub_recv.recv_async() => {
if let Ok(sample) = sample {
let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix {
prefix.join(&sample.key_expr).unwrap().into()
} else {
sample.key_expr.clone()
};

loop {
tokio::select! {
// on publication received by the local subscriber, store it
sample = sub_recv.recv_async() => {
if let Ok(sample) = sample {
let queryable_key_expr: KeyExpr<'_> = if let Some(prefix) = &queryable_prefix {
prefix.join(&sample.key_expr).unwrap().into()
} else {
sample.key_expr.clone()
};

if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) {
if queue.len() >= history {
queue.pop_front();
if let Some(queue) = cache.get_mut(queryable_key_expr.as_keyexpr()) {
if queue.len() >= history {
queue.pop_front();
}
queue.push_back(sample);
} else if cache.len() >= limit {
log::error!("PublicationCache on {}: resource_limit exceeded - can't cache publication for a new resource",
pub_key_expr);
} else {
let mut queue: VecDeque<Sample> = VecDeque::new();
queue.push_back(sample);
cache.insert(queryable_key_expr.into(), queue);
}
queue.push_back(sample);
} else if cache.len() >= limit {
log::error!("PublicationCache on {}: resource_limit exceeded - can't cache publication for a new resource",
pub_key_expr);
} else {
let mut queue: VecDeque<Sample> = VecDeque::new();
queue.push_back(sample);
cache.insert(queryable_key_expr.into(), queue);
}
}
},
},

// on query, reply with cache content
query = quer_recv.recv_async() => {
if let Ok(query) = query {
if !query.selector().key_expr.as_str().contains('*') {
if let Some(queue) = cache.get(query.selector().key_expr.as_keyexpr()) {
for sample in queue {
if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) {
if !time_range.contains(timestamp.get_time().to_system_time()){
continue;
}
}
if let Err(e) = query.reply(Ok(sample.clone())).res_async().await {
log::warn!("Error replying to query: {}", e);
}
}
}
} else {
for (key_expr, queue) in cache.iter() {
if query.selector().key_expr.intersects(unsafe{ keyexpr::from_str_unchecked(key_expr) }) {
// on query, reply with cache content
query = quer_recv.recv_async() => {
if let Ok(query) = query {
if !query.selector().key_expr.as_str().contains('*') {
if let Some(queue) = cache.get(query.selector().key_expr.as_keyexpr()) {
for sample in queue {
if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) {
if !time_range.contains(timestamp.get_time().to_system_time()){
@@ -229,21 +218,35 @@
}
}
}
}
} else {
for (key_expr, queue) in cache.iter() {
if query.selector().key_expr.intersects(unsafe{ keyexpr::from_str_unchecked(key_expr) }) {
for sample in queue {
if let (Ok(Some(time_range)), Some(timestamp)) = (query.selector().time_range(), sample.timestamp) {
if !time_range.contains(timestamp.get_time().to_system_time()){
continue;
}
}
if let Err(e) = query.reply(Ok(sample.clone())).res_async().await {
log::warn!("Error replying to query: {}", e);
}
}
}
}
}
}
}
},

// When stoptx is dropped, stop the task
_ = stoprx.recv_async() => return
},
_ = token2.cancelled() => return
}
}
}
});
},
token,
);

Ok(PublicationCache {
local_sub,
_queryable: queryable,
_stoptx: stoptx,
task,
})
}

@@ -254,11 +257,11 @@
let PublicationCache {
_queryable,
local_sub,
_stoptx,
task,
} = self;
_queryable.undeclare().res_async().await?;
local_sub.undeclare().res_async().await?;
drop(_stoptx);
task.terminate(Duration::from_secs(10));
Ok(())
})
}
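The net effect of this file's change, reduced to a sketch (not the actual `PublicationCache` internals): the flume stop channel is replaced by a `TerminatableTask` whose cancellation token ends the select loop, so shutdown becomes an explicit, time-bounded `terminate()` call instead of relying on a sender being dropped.

use zenoh_task::TerminatableTask;

fn spawn_cache_worker() -> TerminatableTask {
    let token = TerminatableTask::create_cancellation_token();
    let token2 = token.clone();
    TerminatableTask::spawn(
        zenoh_runtime::ZRuntime::TX,
        async move {
            loop {
                tokio::select! {
                    // ... handle incoming publications and queries here ...
                    _ = token2.cancelled() => return,
                }
            }
        },
        token,
    )
}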
3 changes: 3 additions & 0 deletions zenoh/src/net/routing/dispatcher/face.rs
@@ -27,6 +27,7 @@ use zenoh_protocol::{
network::{Mapping, Push, Request, RequestId, Response, ResponseFinal},
};
use zenoh_sync::get_mut_unchecked;
use zenoh_task::TaskController;
use zenoh_transport::multicast::TransportMulticast;
#[cfg(feature = "stats")]
use zenoh_transport::stats::TransportStats;
@@ -45,6 +46,7 @@ pub struct FaceState {
pub(crate) mcast_group: Option<TransportMulticast>,
pub(crate) in_interceptors: Option<Arc<InterceptorsChain>>,
pub(crate) hat: Box<dyn Any + Send + Sync>,
pub(crate) task_controller: TaskController,
}

impl FaceState {
@@ -73,6 +75,7 @@ impl FaceState {
mcast_group,
in_interceptors,
hat,
task_controller: TaskController::default(),
})
}
