Skip to content

Commit

Permalink
DataFrame owned SessionState (apache#4617)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Dec 14, 2022
1 parent e38e76d commit d2091d9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 51 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async fn search_accounts(
)?
.build()?;

let mut dataframe = DataFrame::new(ctx.state, logical_plan)
let mut dataframe = DataFrame::new(ctx.state(), logical_plan)
.select_columns(&["id", "bank_account"])?;

if let Some(f) = filter {
Expand Down
58 changes: 16 additions & 42 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;

use datafusion_common::{Column, DFSchema};
Expand Down Expand Up @@ -74,40 +73,22 @@ use crate::prelude::SessionContext;
/// ```
#[derive(Debug, Clone)]
pub struct DataFrame {
session_state: Arc<RwLock<SessionState>>,
session_state: SessionState,
plan: LogicalPlan,
}

impl DataFrame {
/// Create a new [`DataFrame`] based on an existing logical plan
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self {
pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
Self {
session_state,
plan,
}
}

/// Create a physical plan
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
// this function is copied from SessionContext function of the
// same name
let state_cloned = {
let mut state = self.session_state.write();
state.execution_props.start_execution();

// We need to clone `state` to release the lock that is not `Send`. We could
// make the lock `Send` by using `tokio::sync::Mutex`, but that would require to
// propagate async even to the `LogicalPlan` building methods.
// Cloning `state` here is fine as we then pass it as immutable `&state`, which
// means that we avoid write consistency issues as the cloned version will not
// be written to. As for eventual modifications that would be applied to the
// original state after it has been cloned, they will not be picked up by the
// clone but that is okay, as it is equivalent to postponing the state update
// by keeping the lock until the end of the function scope.
state.clone()
};

state_cloned.create_physical_plan(&self.plan).await
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
self.session_state.create_physical_plan(&self.plan).await
}

/// Filter the DataFrame by column. Returns a new DataFrame only containing the
Expand Down Expand Up @@ -437,8 +418,7 @@ impl DataFrame {
}

fn task_ctx(&self) -> TaskContext {
let lock = self.session_state.read();
TaskContext::from(&*lock)
TaskContext::from(&self.session_state)
}

/// Executes this DataFrame and returns a stream over a single partition
Expand Down Expand Up @@ -527,8 +507,7 @@ impl DataFrame {
/// Return the optimized logical plan represented by this DataFrame.
pub fn to_logical_plan(self) -> Result<LogicalPlan> {
// Optimize the plan first for better UX
let state = self.session_state.read().clone();
state.optimize(&self.plan)
self.session_state.optimize(&self.plan)
}

/// Return a DataFrame with the explanation of its plan so far.
Expand Down Expand Up @@ -567,9 +546,8 @@ impl DataFrame {
/// # Ok(())
/// # }
/// ```
pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
let registry = self.session_state.read().clone();
Arc::new(registry)
pub fn registry(&self) -> &dyn FunctionRegistry {
&self.session_state
}

/// Calculate the intersection of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema
Expand Down Expand Up @@ -621,9 +599,8 @@ impl DataFrame {

/// Write a `DataFrame` to a CSV file.
pub async fn write_csv(self, path: &str) -> Result<()> {
let state = self.session_state.read().clone();
let plan = self.create_physical_plan().await?;
plan_to_csv(&state, plan, path).await
plan_to_csv(&self.session_state, plan, path).await
}

/// Write a `DataFrame` to a Parquet file.
Expand All @@ -632,16 +609,14 @@ impl DataFrame {
path: &str,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let state = self.session_state.read().clone();
let plan = self.create_physical_plan().await?;
plan_to_parquet(&state, plan, path, writer_properties).await
plan_to_parquet(&self.session_state, plan, path, writer_properties).await
}

/// Executes a query and writes the results to a partitioned JSON file.
pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
let state = self.session_state.read().clone();
let plan = self.create_physical_plan().await?;
plan_to_json(&state, plan, path).await
plan_to_json(&self.session_state, plan, path).await
}

/// Add an additional column to the DataFrame.
Expand Down Expand Up @@ -747,7 +722,7 @@ impl DataFrame {
/// # }
/// ```
pub async fn cache(self) -> Result<DataFrame> {
let context = SessionContext::with_state(self.session_state.read().clone());
let context = SessionContext::with_state(self.session_state.clone());
let mem_table = MemTable::try_new(
SchemaRef::from(self.schema().clone()),
self.collect_partitioned().await?,
Expand Down Expand Up @@ -1029,9 +1004,8 @@ mod tests {
// build query with a UDF using DataFrame API
let df = ctx.table("aggregate_test_100")?;

let f = df.registry();

let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
let df = df.select(vec![expr])?;

// build query using SQL
let sql_plan =
Expand Down Expand Up @@ -1088,7 +1062,7 @@ mod tests {
async fn register_table() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());

// register a dataframe as a table
ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
Expand Down Expand Up @@ -1180,7 +1154,7 @@ mod tests {
async fn with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());

let df = df_impl
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
Expand Down
16 changes: 8 additions & 8 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl SessionContext {
(false, true, Ok(_)) => {
self.deregister_table(&name)?;
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state.clone(), input);
let physical = DataFrame::new(self.state(), input);

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
Expand All @@ -286,7 +286,7 @@ impl SessionContext {
)),
(_, _, Err(_)) => {
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state.clone(), input);
let physical = DataFrame::new(self.state(), input);

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
Expand Down Expand Up @@ -475,14 +475,14 @@ impl SessionContext {
}
}

plan => Ok(DataFrame::new(self.state.clone(), plan)),
plan => Ok(DataFrame::new(self.state(), plan)),
}
}

// return an empty dataframe
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(DataFrame::new(self.state.clone(), plan))
Ok(DataFrame::new(self.state(), plan))
}

async fn create_external_table(
Expand Down Expand Up @@ -661,7 +661,7 @@ impl SessionContext {
/// Creates an empty DataFrame.
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
Expand Down Expand Up @@ -716,7 +716,7 @@ impl SessionContext {
/// Creates a [`DataFrame`] for reading a custom [`TableProvider`].
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
))
Expand All @@ -726,7 +726,7 @@ impl SessionContext {
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
Expand Down Expand Up @@ -946,7 +946,7 @@ impl SessionContext {
None,
)?
.build()?;
Ok(DataFrame::new(self.state.clone(), plan))
Ok(DataFrame::new(self.state(), plan))
}

/// Return a [`TableProvider`] for the specified table.
Expand Down

0 comments on commit d2091d9

Please sign in to comment.