Skip to content

Commit

Permalink
Use 'spawn_blocking' to drop sync database pools.
Browse files Browse the repository at this point in the history
This was already done for the connections, but pools might also do
synchronous/blocking work on Drop.

Fixes rwf2#1466.
  • Loading branch information
jebrosen authored and SergioBenitez committed Nov 5, 2020
1 parent 2f98299 commit c6298b9
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
17 changes: 14 additions & 3 deletions contrib/lib/src/databases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,8 @@ impl Poolable for memcache::Client {
#[doc(hidden)]
pub struct ConnectionPool<K, C: Poolable> {
config: Config,
pool: r2d2::Pool<C::Manager>,
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<r2d2::Pool<C::Manager>>,
semaphore: Arc<Semaphore>,
_marker: PhantomData<fn() -> K>,
}
Expand Down Expand Up @@ -766,7 +767,8 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
let pool_size = config.pool_size;
match C::pool(db, &rocket) {
Ok(pool) => Ok(rocket.manage(ConnectionPool::<K, C> {
pool, config,
config,
pool: Some(pool),
semaphore: Arc::new(Semaphore::new(pool_size as usize)),
_marker: PhantomData,
})),
Expand All @@ -787,7 +789,9 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
}
};

let pool = self.pool.clone();
let pool = self.pool.as_ref().cloned()
.expect("internal invariant broken: self.pool is Some");

match run_blocking(move || pool.get_timeout(duration)).await {
Ok(c) => Ok(Connection {
connection: Arc::new(Mutex::new(Some(c))),
Expand Down Expand Up @@ -849,6 +853,13 @@ impl<K, C: Poolable> Drop for Connection<K, C> {
}
}

impl<K, C: Poolable> Drop for ConnectionPool<K, C> {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
}
}

#[rocket::async_trait]
impl<'a, 'r, K: 'static, C: Poolable> FromRequest<'a, 'r> for Connection<K, C> {
type Error = ();
Expand Down
52 changes: 52 additions & 0 deletions contrib/lib/tests/databases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,55 @@ mod rusqlite_integration_test {
}).await;
}
}

#[cfg(feature = "databases")]
#[cfg(test)]
mod drop_runtime_test {
use r2d2::{ManageConnection, Pool};
use rocket_contrib::databases::{database, Poolable, PoolResult};
use tokio::runtime::Runtime;

struct ContainsRuntime(Runtime);
struct TestConnection;

impl ManageConnection for ContainsRuntime {
type Connection = TestConnection;
type Error = std::convert::Infallible;

fn connect(&self) -> Result<Self::Connection, Self::Error> {
Ok(TestConnection)
}

fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> {
Ok(())
}

fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
false
}
}

impl Poolable for TestConnection {
type Manager = ContainsRuntime;
type Error = ();

fn pool(_db_name: &str, _rocket: &rocket::Rocket) -> PoolResult<Self> {
let manager = ContainsRuntime(tokio::runtime::Runtime::new().unwrap());
Ok(Pool::builder().build(manager)?)
}
}

#[database("test_db")]
struct TestDb(TestConnection);

#[rocket::async_test]
async fn test_drop_runtime() {
use rocket::figment::{Figment, util::map};

let config = Figment::from(rocket::Config::default())
.merge(("databases", map!["test_db" => map!["url" => ""]]));

let rocket = rocket::custom(config).attach(TestDb::fairing());
drop(rocket);
}
}

0 comments on commit c6298b9

Please sign in to comment.