ARROW-6567: [Rust] [DataFusion] Wrap aggregate in projection when needed
This PR fixes a long-standing bug where the planner assumed that aggregate queries would always list grouping expressions before aggregate expressions in the projection. The SQL query planner now wraps the aggregate query in a projection, when needed, to preserve the intended column order.

I also fixed a couple of non-deterministic tests.
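To make the column-order issue concrete: for a query such as SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1, the Aggregate node emits grouping columns first and aggregate columns last ([c1, avg(c12)]), while the SELECT list expects [avg(c12), c1]. The minimal standalone sketch below (hypothetical SelectItem type and projection_indices helper, not the DataFusion API) illustrates the index mapping the planner computes before deciding whether to wrap the aggregate in a projection.

// Hypothetical, simplified stand-ins for illustration only; DataFusion's real
// planner works on its Expr enum (Expr::Column, Expr::AggregateFunction, ...).
#[derive(Debug)]
enum SelectItem {
    Column(&'static str),    // plain grouping column, e.g. c1
    Aggregate(&'static str), // aggregate call, e.g. avg(c12)
}

// Map each SELECT item to its column index in the Aggregate node's output,
// which lists all grouping columns first, then all aggregate columns.
fn projection_indices(select: &[SelectItem]) -> Vec<usize> {
    let group_by_count = select
        .iter()
        .filter(|e| matches!(e, SelectItem::Column(_)))
        .count();
    let (mut next_group, mut next_aggr) = (0, 0);
    select
        .iter()
        .map(|item| match item {
            SelectItem::Aggregate(_) => {
                let i = group_by_count + next_aggr; // aggregates come after all groups
                next_aggr += 1;
                i
            }
            SelectItem::Column(_) => {
                let i = next_group;
                next_group += 1;
                i
            }
        })
        .collect()
}

fn main() {
    // SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1
    let select = [SelectItem::Aggregate("avg(c12)"), SelectItem::Column("c1")];
    let indices = projection_indices(&select);
    assert_eq!(indices, vec![1, 0]);
    // Indices differ from 0..n, so the planner wraps the aggregate in a
    // projection; an identity mapping would leave the aggregate unwrapped.
    let projection_needed = indices.iter().enumerate().any(|(i, &p)| p != i);
    assert!(projection_needed);
    println!("projection column indices: {:?}", indices);
}

If the mapping is the identity, no projection is added; this mirrors the projection_needed check in the planner diff below.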

Closes apache#5639 from andygrove/ARROW-6567 and squashes the following commits:

46b9ee3 <Andy Grove> fix non deterministic tests
439f190 <Andy Grove> Wrap aggregate in projection if needed

Authored-by: Andy Grove <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
andygrove committed Oct 15, 2019
1 parent a75e1b7 commit 2e53c00
Showing 2 changed files with 86 additions and 26 deletions.
82 changes: 66 additions & 16 deletions rust/datafusion/src/sql/planner.rs
@@ -88,10 +88,7 @@ impl SqlToRel {
// collect aggregate expressions
let aggr_expr: Vec<Expr> = expr
.iter()
.filter(|e| match e {
Expr::AggregateFunction { .. } => true,
_ => false,
})
.filter(|e| is_aggregate_expr(e))
.map(|e| e.clone())
.collect();

@@ -117,28 +114,57 @@ impl SqlToRel {
input_schema,
)?);

//TODO: selection, projection, everything else
Ok(Arc::new(LogicalPlan::Aggregate {
let group_by_count = group_expr.len();
let aggr_count = aggr_expr.len();

let aggregate = Arc::new(LogicalPlan::Aggregate {
input: aggregate_input,
group_expr,
aggr_expr,
schema: Arc::new(aggr_schema),
}))
});

// wrap in projection to preserve final order of fields
let mut projected_fields =
Vec::with_capacity(group_by_count + aggr_count);
let mut group_expr_index = 0;
let mut aggr_expr_index = 0;
for i in 0..expr.len() {
if is_aggregate_expr(&expr[i]) {
projected_fields.push(group_by_count + aggr_expr_index);
aggr_expr_index += 1;
} else {
projected_fields.push(group_expr_index);
group_expr_index += 1;
}
}

// determine if projection is needed or not
// NOTE this would be better done later in a query optimizer rule
let mut projection_needed = false;
for i in 0..projected_fields.len() {
if projected_fields[i] != i {
projection_needed = true;
break;
}
}

if projection_needed {
let projection = create_projection(
projected_fields.iter().map(|i| Expr::Column(*i)).collect(),
aggregate,
)?;
Ok(Arc::new(projection))
} else {
Ok(aggregate)
}
} else {
let projection_input: Arc<LogicalPlan> = match selection_plan {
Some(s) => Arc::new(s),
_ => input.clone(),
};

let projection_schema = Arc::new(Schema::new(
utils::exprlist_to_fields(&expr, input_schema.as_ref())?,
));

let projection = LogicalPlan::Projection {
expr: expr,
input: projection_input,
schema: projection_schema.clone(),
};
let projection = create_projection(expr, projection_input)?;

if let &Some(_) = having {
return Err(ExecutionError::General(
@@ -377,6 +403,30 @@ impl SqlToRel {
}
}

/// Create a projection
fn create_projection(expr: Vec<Expr>, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
let input_schema = input.schema();

let schema = Arc::new(Schema::new(utils::exprlist_to_fields(
&expr,
input_schema.as_ref(),
)?));

Ok(LogicalPlan::Projection {
expr,
input,
schema,
})
}

/// Determine if an expression is an aggregate expression or not
fn is_aggregate_expr(e: &Expr) -> bool {
match e {
Expr::AggregateFunction { .. } => true,
_ => false,
}
}

/// Convert SQL data type to relational representation of data type
pub fn convert_data_type(sql: &SQLType) -> Result<DataType> {
match sql {
30 changes: 20 additions & 10 deletions rust/datafusion/tests/sql.rs
@@ -164,6 +164,17 @@ fn csv_query_group_by_avg() {
assert_eq!(expected, actual.join("\n"));
}

#[test]
fn csv_query_group_by_avg_with_projection() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1";
let mut actual = execute(&mut ctx, sql);
actual.sort();
let expected = "0.41040709263815384\t\"b\"\n0.48600669271341534\t\"e\"\n0.48754517466109415\t\"a\"\n0.48855379387549824\t\"d\"\n0.6600456536439784\t\"c\"".to_string();
assert_eq!(expected, actual.join("\n"));
}

#[test]
fn csv_query_avg_multi_batch() {
let mut ctx = ExecutionContext::new();
@@ -187,7 +198,6 @@ fn csv_query_avg_multi_batch() {
fn csv_query_count() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test determistic
let sql = "SELECT count(c12) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql).join("\n");
let expected = "100".to_string();
@@ -198,23 +208,23 @@ fn csv_query_group_by_int_count() {
fn csv_query_group_by_int_count() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test determistic
let sql = "SELECT count(c12) FROM aggregate_test_100 GROUP BY c1";
let actual = execute(&mut ctx, sql).join("\n");
let expected = "\"a\"\t21\n\"e\"\t21\n\"d\"\t18\n\"c\"\t21\n\"b\"\t19".to_string();
assert_eq!(expected, actual);
let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1";
let mut actual = execute(&mut ctx, sql);
actual.sort();
let expected = "\"a\"\t21\n\"b\"\t19\n\"c\"\t21\n\"d\"\t18\n\"e\"\t21".to_string();
assert_eq!(expected, actual.join("\n"));
}

#[test]
fn csv_query_group_by_string_min_max() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test determistic
let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1";
let actual = execute(&mut ctx, sql).join("\n");
let mut actual = execute(&mut ctx, sql);
actual.sort();
let expected =
"\"a\"\t0.02182578039211991\t0.9800193410444061\n\"e\"\t0.01479305307777301\t0.9965400387585364\n\"d\"\t0.061029375346466685\t0.9748360509016578\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"b\"\t0.04893135681998029\t0.9185813970744787".to_string();
assert_eq!(expected, actual);
"\"a\"\t0.02182578039211991\t0.9800193410444061\n\"b\"\t0.04893135681998029\t0.9185813970744787\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"d\"\t0.061029375346466685\t0.9748360509016578\n\"e\"\t0.01479305307777301\t0.9965400387585364".to_string();
assert_eq!(expected, actual.join("\n"));
}

#[test]
