Skip to content

Commit

Permalink
Encode all join conditions in a single expression field (apache#7612)
Browse files Browse the repository at this point in the history
* Encode all join conditions in a single expression field

* Removed all references to post_join_filter

* Simplify from_substrait_rel()

* Clippy fix

* Added test to ensure that Substrait plans produced from DF do not contain a post_join_filter

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
nseekhao and alamb authored Oct 13, 2023
1 parent e0fa75f commit f5a6d01
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 54 deletions.
108 changes: 66 additions & 42 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};

use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
Expand Down Expand Up @@ -129,6 +130,51 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
}
}

fn split_eq_and_noneq_join_predicate_with_nulls_equality(
filter: &Expr,
) -> (Vec<(Column, Column)>, bool, Option<Expr>) {
let exprs = split_conjunction(filter);

let mut accum_join_keys: Vec<(Column, Column)> = vec![];
let mut accum_filters: Vec<Expr> = vec![];
let mut nulls_equal_nulls = false;

for expr in exprs {
match expr {
Expr::BinaryExpr(binary_expr) => match binary_expr {
x @ (BinaryExpr {
left,
op: Operator::Eq,
right,
}
| BinaryExpr {
left,
op: Operator::IsNotDistinctFrom,
right,
}) => {
nulls_equal_nulls = match x.op {
Operator::Eq => false,
Operator::IsNotDistinctFrom => true,
_ => unreachable!(),
};

match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
accum_join_keys.push((l.clone(), r.clone()));
}
_ => accum_filters.push(expr.clone()),
}
}
_ => accum_filters.push(expr.clone()),
},
_ => accum_filters.push(expr.clone()),
}
}

let join_filter = accum_filters.into_iter().reduce(Expr::and);
(accum_join_keys, nulls_equal_nulls, join_filter)
}

/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
Expand Down Expand Up @@ -336,7 +382,13 @@ pub async fn from_substrait_rel(
}
}
Some(RelType::Join(join)) => {
let left = LogicalPlanBuilder::from(
if join.post_join_filter.is_some() {
return not_impl_err!(
"JoinRel with post_join_filter is not yet supported"
);
}

let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?,
);
let right = LogicalPlanBuilder::from(
Expand All @@ -346,60 +398,32 @@ pub async fn from_substrait_rel(
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
// certain join types such as semi and anti joins
let in_join_schema = left.schema().join(right.schema())?;
// Parse post join filter if exists
let join_filter = match &join.post_join_filter {
Some(filter) => {
let parsed_filter =
from_substrait_rex(filter, &in_join_schema, extensions).await?;
Some(parsed_filter.as_ref().clone())
}
None => None,
};

// If join expression exists, parse the `on` condition expression, build join and return
// Otherwise, build join with koin filter, without join keys
// Otherwise, build join with only the filter, without join keys
match &join.expression.as_ref() {
Some(expr) => {
let on =
from_substrait_rex(expr, &in_join_schema, extensions).await?;
let predicates = split_conjunction(&on);
// TODO: collect only one null_eq_null
let join_exprs: Vec<(Column, Column, bool)> = predicates
.iter()
.map(|p| match p {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => match op {
Operator::Eq => Ok((l.clone(), r.clone(), false)),
Operator::IsNotDistinctFrom => {
Ok((l.clone(), r.clone(), true))
}
_ => plan_err!("invalid join condition op"),
},
_ => plan_err!("invalid join condition expression"),
}
}
_ => plan_err!(
"Non-binary expression is not supported in join condition"
),
})
.collect::<Result<Vec<_>>>()?;
let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
itertools::multiunzip(join_exprs);
// The join expression can contain both equal and non-equal ops.
// As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields.
// So we extract each part as follows:
// - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector
// - Otherwise we add the expression to join_filter (use conjunction if filter already exists)
let (join_ons, nulls_equal_nulls, join_filter) =
split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
let (left_cols, right_cols): (Vec<_>, Vec<_>) =
itertools::multiunzip(join_ons);
left.join_detailed(
right.build()?,
join_type,
(left_cols, right_cols),
join_filter,
null_eq_nulls[0],
nulls_equal_nulls,
)?
.build()
}
None => match &join_filter {
Some(_) => left
.join_on(right.build()?, join_type, join_filter)?
.build(),
None => plan_err!("Join without join keys require a valid filter"),
},
None => plan_err!("JoinRel without join condition is not allowed"),
}
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Expand Down
33 changes: 25 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,39 +278,56 @@ pub fn to_substrait_rel(
// parse filter if exists
let in_join_schema = join.left.schema().join(join.right.schema())?;
let join_filter = match &join.filter {
Some(filter) => Some(Box::new(to_substrait_rex(
Some(filter) => Some(to_substrait_rex(
filter,
&Arc::new(in_join_schema),
0,
extension_info,
)?)),
)?),
None => None,
};

// map the left and right columns to binary expressions in the form `l = r`
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
let eq_op = if join.null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};

let join_expr = to_substrait_join_expr(
let join_on = to_substrait_join_expr(
&join.on,
eq_op,
join.left.schema(),
join.right.schema(),
extension_info,
)?
.map(Box::new);
)?;

// create conjunction between `join_on` and `join_filter` to embed all join conditions,
// whether equal or non-equal in a single expression
let join_expr = match &join_on {
Some(on_expr) => match &join_filter {
Some(filter) => Some(Box::new(make_binary_op_scalar_func(
on_expr,
filter,
Operator::And,
extension_info,
))),
None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist
},
None => match &join_filter {
Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
None => None,
},
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: join_expr,
post_join_filter: join_filter,
expression: join_expr.clone(),
post_join_filter: None,
advanced_extension: None,
}))),
}))
Expand Down
118 changes: 114 additions & 4 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ use std::hash::Hash;
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::{DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;

use substrait::proto::extensions::simple_extension_declaration::MappingType;
use substrait::proto::rel::RelType;
use substrait::proto::{plan_rel, Plan, Rel};

struct MockSerializerRegistry;

Expand Down Expand Up @@ -383,12 +386,15 @@ async fn roundtrip_inner_join() -> Result<()> {

#[tokio::test]
async fn roundtrip_non_equi_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await
roundtrip_verify_post_join_filter(
"SELECT data.a FROM data JOIN data2 ON data.a <> data2.a",
)
.await
}

#[tokio::test]
async fn roundtrip_non_equi_join() -> Result<()> {
roundtrip(
roundtrip_verify_post_join_filter(
"SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a",
)
.await
Expand Down Expand Up @@ -620,6 +626,91 @@ async fn extension_logical_plan() -> Result<()> {
Ok(())
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Some(RelType::Join(join)) => {
// check if join filter is None
if join.post_join_filter.is_some() {
plan_err!(
"DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel"
)
} else {
// recursively check JoinRels
match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) {
Err(e) => Err(e),
Ok(_) => {
check_post_join_filters(join.right.as_ref().unwrap().as_ref())
}
}
}
}
Some(RelType::Project(p)) => {
check_post_join_filters(p.input.as_ref().unwrap().as_ref())
}
Some(RelType::Filter(filter)) => {
check_post_join_filters(filter.input.as_ref().unwrap().as_ref())
}
Some(RelType::Fetch(fetch)) => {
check_post_join_filters(fetch.input.as_ref().unwrap().as_ref())
}
Some(RelType::Sort(sort)) => {
check_post_join_filters(sort.input.as_ref().unwrap().as_ref())
}
Some(RelType::Aggregate(agg)) => {
check_post_join_filters(agg.input.as_ref().unwrap().as_ref())
}
Some(RelType::Set(set)) => {
for input in &set.inputs {
match check_post_join_filters(input) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
Ok(())
}
Some(RelType::ExtensionSingle(ext)) => {
check_post_join_filters(ext.input.as_ref().unwrap().as_ref())
}
Some(RelType::ExtensionMulti(ext)) => {
for input in &ext.inputs {
match check_post_join_filters(input) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
Ok(())
}
Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()),
_ => not_impl_err!(
"Unsupported RelType: {:?} in post join filter check",
rel.rel_type
),
}
}

async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
for relation in &proto.relations {
match relation.rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) {
Err(e) => return Err(e),
Ok(_) => continue,
},
plan_rel::RelType::Root(root) => {
match check_post_join_filters(root.input.as_ref().unwrap()) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
},
None => return plan_err!("Cannot parse plan relation: None"),
}
}

Ok(())
}

async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
Expand Down Expand Up @@ -688,6 +779,25 @@ async fn roundtrip(sql: &str) -> Result<()> {
Ok(())
}

async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

println!("{plan:#?}");
println!("{plan2:#?}");

let plan1str = format!("{plan:?}");
let plan2str = format!("{plan2:?}");
assert_eq!(plan1str, plan2str);

// verify that the join filters are None
verify_post_join_filter_value(proto).await
}

async fn roundtrip_all_types(sql: &str) -> Result<()> {
let mut ctx = create_all_type_context().await?;
let df = ctx.sql(sql).await?;
Expand Down

0 comments on commit f5a6d01

Please sign in to comment.