Skip to content

Commit

Permalink
Add cast expression with bool, integers and decimal128 support (apach…
Browse files Browse the repository at this point in the history
…e#5137)

Address comment on unwrap
  • Loading branch information
nseekhao authored Mar 21, 2023
1 parent 9d60e14 commit 63acd57
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 2 deletions.
42 changes: 41 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
// under the License.

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::expr;
use datafusion::logical_expr::{
aggregate_function, BinaryExpr, Case, Expr, LogicalPlan, Operator,
};
use datafusion::logical_expr::{build_join_schema, LogicalPlanBuilder};
use datafusion::logical_expr::{expr, Cast};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
Expand Down Expand Up @@ -721,12 +722,51 @@ pub async fn from_substrait_rex(
))),
}
}
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new(
Box::new(
from_substrait_rex(
cast.as_ref().input.as_ref().unwrap().as_ref(),
input_schema,
extensions,
)
.await?
.as_ref()
.clone(),
),
from_substrait_type(output_type)?,
)))),
None => Err(DataFusionError::Substrait(
"Cast experssion without output type is not allowed".to_string(),
)),
},
_ => Err(DataFusionError::NotImplemented(
"unsupported rex_type".to_string(),
)),
}
}

fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
r#type::Kind::I8(_) => Ok(DataType::Int8),
r#type::Kind::I16(_) => Ok(DataType::Int16),
r#type::Kind::I32(_) => Ok(DataType::Int32),
r#type::Kind::I64(_) => Ok(DataType::Int64),
r#type::Kind::Decimal(d) => {
Ok(DataType::Decimal128(d.precision as u8, d.scale as i8))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported Substrait type: {s_kind:?}"
))),
},
_ => Err(DataFusionError::NotImplemented(
"`None` Substrait kind is not supported".to_string(),
)),
}
}

fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
Expand Down
69 changes: 68 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::{collections::HashMap, mem, sync::Arc};

use datafusion::{
arrow::datatypes::DataType,
error::{DataFusionError, Result},
prelude::JoinType,
scalar::ScalarValue,
Expand All @@ -26,7 +27,7 @@ use datafusion::{
use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{BinaryExpr, Case, Sort};
use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::{binary_expr, Expr};
use substrait::proto::{
Expand Down Expand Up @@ -577,6 +578,21 @@ pub fn to_substrait_rex(
rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))),
})
}
Expr::Cast(Cast { expr, data_type }) => {
Ok(Expression {
rex_type: Some(RexType::Cast(Box::new(
substrait::proto::expression::Cast {
r#type: Some(to_substrait_type(data_type)?),
input: Some(Box::new(to_substrait_rex(
expr,
schema,
extension_info,
)?)),
failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED
},
))),
})
}
Expr::Literal(value) => {
let literal_type = match value {
ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
Expand Down Expand Up @@ -626,6 +642,57 @@ pub fn to_substrait_rex(
}
}

fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Required as i32;
match dt {
DataType::Null => Err(DataFusionError::Internal(
"Null cast is not valid".to_string(),
)),
DataType::Boolean => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}),
DataType::Int8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}),
DataType::Int16 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}),
DataType::Int32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}),
DataType::Int64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
type_variation_reference: default_type_ref,
nullability: default_nullability,
})),
}),
DataType::Decimal128(p, s) => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
type_variation_reference: default_type_ref,
nullability: default_nullability,
scale: *s as i32,
precision: *p as i32,
})),
}),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported cast type: {dt:?}"
))),
}
}

fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
Expand Down
10 changes: 10 additions & 0 deletions datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ mod tests {
.await
}

/// Passes a query whose filter contains an explicit `CAST(2.5 AS int)` to
/// the `roundtrip` helper and returns its result.
#[tokio::test]
async fn cast_decimal_to_int() -> Result<()> {
    let sql = "SELECT * FROM data WHERE a = CAST(2.5 AS int)";
    roundtrip(sql).await
}

/// Passes a query comparing columns `a` and `b` to the `roundtrip` helper
/// and returns its result.
///
/// NOTE(review): presumably `a` and `b` have different types in the test
/// table, so the planner inserts an implicit cast; confirm against the
/// test data.
#[tokio::test]
async fn implicit_cast() -> Result<()> {
    let sql = "SELECT * FROM data WHERE a = b";
    roundtrip(sql).await
}

#[tokio::test]
async fn aggregate_case() -> Result<()> {
assert_expected_plan(
Expand Down

0 comments on commit 63acd57

Please sign in to comment.