-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
310 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache | ||
|
||
import org.apache.spark.sql.SparkSession | ||
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount | ||
import org.apache.spark.sql.catalyst.expressions.postgresql.{Age, ArrayAppend, ArrayLength, IntervalJustifyLike, Scale, SplitPart, StringToArray, UnNest} | ||
import org.apache.spark.sql.extra.{FunctionAliases, FunctionDescription} | ||
|
||
package object itachi { | ||
|
||
private def registerFunction(function: FunctionDescription): Unit = { | ||
SparkSession.active.sessionState | ||
.functionRegistry | ||
.registerFunction(function._1, function._2, function._3) | ||
} | ||
|
||
def registerPostgresFunctions: Unit = { | ||
registerFunction(Age.fd) | ||
registerFunction(ArrayAppend.fd) | ||
registerFunction(ArrayLength.fd) | ||
registerFunction(IntervalJustifyLike.justifyDays) | ||
registerFunction(IntervalJustifyLike.justifyHours) | ||
registerFunction(IntervalJustifyLike.justifyInterval) | ||
|
||
registerFunction(Scale.fd) | ||
registerFunction(SplitPart.fd) | ||
registerFunction(StringToArray.fd) | ||
registerFunction(UnNest.fd) | ||
|
||
registerFunction(RegrCount.fd) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.itachi.postgres | ||
|
||
import org.apache.spark.sql.Column | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction | ||
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount | ||
|
||
// scalastyle:off | ||
object functions { | ||
// scalastyle:on | ||
|
||
private def withAggregateFunction( | ||
func: AggregateFunction, | ||
isDistinct: Boolean = false): Column = { | ||
new Column(func.toAggregateExpression(isDistinct)) | ||
} | ||
|
||
/** | ||
* Returns the count of all rows in an expression pair. | ||
* The function eliminates expression pairs where either expression in the pair is NULL. | ||
* If no rows remain, the function returns 0. | ||
* | ||
* @group agg_funcs | ||
* @since 0.2.0 | ||
*/ | ||
def regr_count(y: Column, x: Column): Column = withAggregateFunction { | ||
RegrCount(y.expr, x.expr) | ||
} | ||
|
||
/** | ||
* Returns the count of all rows in an expression pair. | ||
* The function eliminates expression pairs where either expression in the pair is NULL. | ||
* If no rows remain, the function returns 0. | ||
* | ||
* @group agg_funcs | ||
* @since 0.2.0 | ||
*/ | ||
def regr_count(y: String, x: String): Column = { | ||
regr_count(new Column(y), new Column(x)) | ||
} | ||
} |
84 changes: 84 additions & 0 deletions
84
src/main/scala/org/apache/spark/sql/catalyst/expressions/ansi/RegrCount.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.expressions.ansi | ||
|
||
import org.apache.spark.sql.catalyst.FunctionIdentifier | ||
import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionDescription, If, ImplicitCastInputTypes, IsNull, Literal, Or} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate | ||
import org.apache.spark.sql.extra.{ExpressionUtils, FunctionDescription} | ||
import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, LongType} | ||
|
||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(expr1, expr2) - Returns the count of all rows in an expression pair. The function eliminates expression pairs where either expression in the pair is NULL.If no rows remain, the function returns 0. | ||
""", | ||
arguments = | ||
""" | ||
expr1 The dependent DOUBLE PRECISION expression | ||
expr2 The independent DOUBLE PRECISION expression | ||
""", | ||
examples = """ | ||
> SELECT _FUNC_(1, 2); | ||
1 | ||
> SELECT _FUNC_(1, null); | ||
0 | ||
""", | ||
since = "0.2.0", | ||
group = "agg_funcs") | ||
// scalastyle:on line.size.limit | ||
case class RegrCount(y: Expression, x: Expression) | ||
extends DeclarativeAggregate with ImplicitCastInputTypes { | ||
override def prettyName: String = "regr_count" | ||
private lazy val regrCount = AttributeReference(prettyName, LongType, nullable = false)() | ||
|
||
override lazy val initialValues = Seq(Literal(0L)) | ||
override lazy val updateExpressions: Seq[Expression] = { | ||
val nullableChildren = children.filter(_.nullable) | ||
if (nullableChildren.isEmpty) { | ||
Seq(regrCount + 1L) | ||
} else { | ||
Seq(If(nullableChildren.map(IsNull).reduce(Or), regrCount, regrCount + 1L)) | ||
} | ||
} | ||
|
||
override lazy val mergeExpressions = Seq(regrCount.left + regrCount.right) | ||
|
||
override lazy val evaluateExpression = regrCount | ||
|
||
override lazy val aggBufferAttributes: Seq[AttributeReference] = regrCount :: Nil | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) | ||
|
||
override def nullable: Boolean = false | ||
|
||
override def dataType: DataType = LongType | ||
|
||
override def children: Seq[Expression] = Seq(y, x) | ||
} | ||
|
||
|
||
object RegrCount { | ||
|
||
val fd: FunctionDescription = ( | ||
new FunctionIdentifier("regr_count"), | ||
ExpressionUtils.getExpressionInfo(classOf[RegrCount], "regr_count"), | ||
(children: Seq[Expression]) => RegrCount(children.head, children.last)) | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.itachi | ||
|
||
import org.apache.spark.sql.{DataFrame, Row} | ||
import org.apache.spark.sql.extra.SparkSessionHelper | ||
import org.apache.spark.unsafe.types.CalendarInterval | ||
|
||
class ItachiTest extends SparkSessionHelper { | ||
|
||
override def beforeAll(): Unit = { | ||
super.beforeAll() | ||
org.apache.itachi.registerPostgresFunctions | ||
} | ||
|
||
def checkAnswer(df: DataFrame, expect: Seq[Row]): Unit = { | ||
assert(df.collect() === expect) | ||
} | ||
|
||
test("age") { | ||
checkAnswer( | ||
spark.sql("select age(timestamp '2001-04-10', timestamp '1957-06-13')"), | ||
Seq(Row(new CalendarInterval(525, 28, 0))) | ||
) | ||
} | ||
|
||
|
||
test("regr_count") { | ||
val query = "select k, count(*), regr_count(v, v2)" + | ||
" from values(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25,null), (2, 30, 35) t(k, v, v2)" + | ||
" group by k" | ||
checkAnswer(sql(query), Seq(Row(1, 1, 0), Row(2, 4, 3))) | ||
|
||
checkAnswer(sql("SELECT REGR_COUNT(1, 2)"), Seq(Row(1))) | ||
checkAnswer(sql("SELECT REGR_COUNT(1, null)"), Seq(Row(0))) | ||
|
||
} | ||
|
||
} |
Oops, something went wrong.