add ansi regr_count functions
yaooqinn committed May 7, 2021
1 parent 80546cb commit 4c96bc6
Showing 11 changed files with 310 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -28,7 +28,7 @@ Config your spark applications with `spark.sql.extensions`, e.g. `spark.sql.exte
Databricks Installation
--------------

Create an [init script](https://docs.databricks.com/clusters/init-scripts.html) in DBFS:
Create an `init script <https://docs.databricks.com/clusters/init-scripts.html>`_ in DBFS:

dbutils.fs.mkdirs("dbfs:/databricks/scripts/")

31 changes: 31 additions & 0 deletions docs/functions/postgres.md
@@ -189,6 +189,37 @@ org.apache.spark.sql.catalyst.expressions.postgresql.JustifyDays

- **Since**
0.1.0
## regr_count
- **Usage**
```scala

regr_count(expr1, expr2) - Returns the count of all rows in an expression pair. The function eliminates expression pairs where either expression in the pair is NULL. If no rows remain, the function returns 0.

```
- **Arguments**
```scala

expr1 The dependent DOUBLE PRECISION expression
expr2 The independent DOUBLE PRECISION expression

```
- **Examples**
```sql

> SELECT regr_count(1, 2);
1
> SELECT regr_count(1, null);
0

```
- **Class**
```scala
org.apache.spark.sql.catalyst.expressions.ansi.RegrCount
```
- **Note**

- **Since**
0.2.0
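
The entry above documents pair-wise NULL elimination; a hedged Scala sketch of the grouped behavior, mirroring the test added later in this commit (`spark` is assumed to be an active session):

```scala
// Rows where either v or v2 is NULL are dropped from the pair count,
// so group k=2 has 4 rows overall but only 3 complete pairs.
spark.sql(
  """SELECT k, count(*), regr_count(v, v2)
    |FROM VALUES (1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) t(k, v, v2)
    |GROUP BY k""".stripMargin
).show()
// expected rows, per the test's checkAnswer: (1, 1, 0) and (2, 4, 3)
```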
## scale
- **Usage**
```scala
…
```
31 changes: 31 additions & 0 deletions docs/functions/presto.md
@@ -216,6 +216,37 @@ NaN

- **Since**

## regr_count
- **Usage**
```scala

regr_count(expr1, expr2) - Returns the count of all rows in an expression pair. The function eliminates expression pairs where either expression in the pair is NULL. If no rows remain, the function returns 0.

```
- **Arguments**
```scala

expr1 The dependent DOUBLE PRECISION expression
expr2 The independent DOUBLE PRECISION expression

```
- **Examples**
```sql

> SELECT regr_count(1, 2);
1
> SELECT regr_count(1, null);
0

```
- **Class**
```scala
org.apache.spark.sql.catalyst.expressions.ansi.RegrCount
```
- **Note**

- **Since**
0.2.0
## to_base
- **Usage**
```scala
…
```
48 changes: 48 additions & 0 deletions src/main/scala/org/apache/itachi/package.scala
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount
import org.apache.spark.sql.catalyst.expressions.postgresql.{Age, ArrayAppend, ArrayLength, IntervalJustifyLike, Scale, SplitPart, StringToArray, UnNest}
import org.apache.spark.sql.extra.{FunctionAliases, FunctionDescription}

package object itachi {

private def registerFunction(function: FunctionDescription): Unit = {
SparkSession.active.sessionState
.functionRegistry
.registerFunction(function._1, function._2, function._3)
}

def registerPostgresFunctions: Unit = {
registerFunction(Age.fd)
registerFunction(ArrayAppend.fd)
registerFunction(ArrayLength.fd)
registerFunction(IntervalJustifyLike.justifyDays)
registerFunction(IntervalJustifyLike.justifyHours)
registerFunction(IntervalJustifyLike.justifyInterval)

registerFunction(Scale.fd)
registerFunction(SplitPart.fd)
registerFunction(StringToArray.fd)
registerFunction(UnNest.fd)

registerFunction(RegrCount.fd)
}
}
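
Because `registerFunction` resolves `SparkSession.active`, the helper above must be invoked after a session exists. A minimal usage sketch (not part of the diff; the app name is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// getOrCreate() also marks the session active, which is what
// registerPostgresFunctions reads via SparkSession.active.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("itachi-demo")
  .getOrCreate()

org.apache.itachi.registerPostgresFunctions
spark.sql("SELECT regr_count(1, 2)").show()  // expect 1
```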
57 changes: 57 additions & 0 deletions src/main/scala/org/apache/itachi/postgres/functions.scala
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.itachi.postgres

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount

// scalastyle:off
object functions {
// scalastyle:on

private def withAggregateFunction(
func: AggregateFunction,
isDistinct: Boolean = false): Column = {
new Column(func.toAggregateExpression(isDistinct))
}

/**
* Returns the count of all rows in an expression pair.
* The function eliminates expression pairs where either expression in the pair is NULL.
* If no rows remain, the function returns 0.
*
* @group agg_funcs
* @since 0.2.0
*/
def regr_count(y: Column, x: Column): Column = withAggregateFunction {
RegrCount(y.expr, x.expr)
}

/**
* Returns the count of all rows in an expression pair.
* The function eliminates expression pairs where either expression in the pair is NULL.
* If no rows remain, the function returns 0.
*
* @group agg_funcs
* @since 0.2.0
*/
def regr_count(y: String, x: String): Column = {
regr_count(new Column(y), new Column(x))
}
}
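
For the `Column`-based overload, a hedged DataFrame sketch (the sample data and the `y`/`x` column names are illustrative, not from the commit):

```scala
import org.apache.itachi.postgres.functions.regr_count
import spark.implicits._  // assumes `spark` is an active SparkSession

// One complete pair plus one pair with a NULL x: regr_count counts only
// the complete pair.
val data: Seq[(Option[Double], Option[Double])] =
  Seq((Some(1.0), Some(2.0)), (Some(1.0), None))
val df = data.toDF("y", "x")
df.agg(regr_count(df("y"), df("x"))).show()  // expect 1
```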
84 changes: 84 additions & 0 deletions src/main/scala/org/apache/spark/sql/catalyst/expressions/ansi/RegrCount.scala
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.ansi

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionDescription, If, ImplicitCastInputTypes, IsNull, Literal, Or}
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.extra.{ExpressionUtils, FunctionDescription}
import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, LongType}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
    _FUNC_(expr1, expr2) - Returns the count of all rows in an expression pair. The function eliminates expression pairs where either expression in the pair is NULL. If no rows remain, the function returns 0.
""",
arguments =
"""
expr1 The dependent DOUBLE PRECISION expression
expr2 The independent DOUBLE PRECISION expression
""",
examples = """
> SELECT _FUNC_(1, 2);
1
> SELECT _FUNC_(1, null);
0
""",
since = "0.2.0",
group = "agg_funcs")
// scalastyle:on line.size.limit
case class RegrCount(y: Expression, x: Expression)
extends DeclarativeAggregate with ImplicitCastInputTypes {
override def prettyName: String = "regr_count"
private lazy val regrCount = AttributeReference(prettyName, LongType, nullable = false)()

override lazy val initialValues = Seq(Literal(0L))
override lazy val updateExpressions: Seq[Expression] = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
Seq(regrCount + 1L)
} else {
Seq(If(nullableChildren.map(IsNull).reduce(Or), regrCount, regrCount + 1L))
}
}

override lazy val mergeExpressions = Seq(regrCount.left + regrCount.right)

override lazy val evaluateExpression = regrCount

override lazy val aggBufferAttributes: Seq[AttributeReference] = regrCount :: Nil

override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

override def nullable: Boolean = false

override def dataType: DataType = LongType

override def children: Seq[Expression] = Seq(y, x)
}


object RegrCount {

val fd: FunctionDescription = (
new FunctionIdentifier("regr_count"),
ExpressionUtils.getExpressionInfo(classOf[RegrCount], "regr_count"),
(children: Seq[Expression]) => RegrCount(children.head, children.last))

}
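
The declarative pieces above reduce to a simple counting algebra; a plain-Scala rendering (not in the commit) of the update, merge, and evaluate steps:

```scala
// updateExpressions: If(IsNull(y) OR IsNull(x), count, count + 1)
def update(count: Long, y: Option[Double], x: Option[Double]): Long =
  if (y.isEmpty || x.isEmpty) count else count + 1L

// mergeExpressions: regrCount.left + regrCount.right
def merge(left: Long, right: Long): Long = left + right

// initialValues = 0L; evaluateExpression is the buffer itself
val buffer = update(update(0L, Some(1.0), None), Some(1.0), Some(2.0))
// buffer == 1: the (1.0, None) pair was eliminated
```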
2 changes: 2 additions & 0 deletions src/main/scala/org/apache/spark/sql/extra/PostgreSQLExtensions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.extra

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount
import org.apache.spark.sql.catalyst.expressions.postgresql.{Age, ArrayAppend, ArrayLength, IntervalJustifyLike, Scale, SplitPart, StringToArray, UnNest}

class PostgreSQLExtensions extends Extensions {
@@ -29,6 +30,7 @@ class PostgreSQLExtensions extends Extensions {
ext.injectFunction(IntervalJustifyLike.justifyDays)
ext.injectFunction(IntervalJustifyLike.justifyHours)
ext.injectFunction(IntervalJustifyLike.justifyInterval)
ext.injectFunction(RegrCount.fd)
ext.injectFunction(Scale.fd)
ext.injectFunction(SplitPart.fd)
ext.injectFunction(StringToArray.fd)
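
Injection via `SparkSessionExtensions` avoids any manual registration call; a minimal sketch, assuming the itachi jar is on the driver classpath:

```scala
import org.apache.spark.sql.SparkSession

// The extensions config wires PostgreSQLExtensions, and with it regr_count,
// into every session built from this builder.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions", "org.apache.spark.sql.extra.PostgreSQLExtensions")
  .getOrCreate()

spark.sql("SELECT regr_count(1, null)").show()  // expect 0
```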
2 changes: 2 additions & 0 deletions src/main/scala/org/apache/spark/sql/extra/TeradataExtensions.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.extra

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.ansi.RegrCount
import org.apache.spark.sql.catalyst.expressions.teradata._

class TeradataExtensions extends Extensions {
@@ -32,6 +33,7 @@ class TeradataExtensions extends Extensions {
extensions.injectFunction(IsFinite.fd)
extensions.injectFunction(IsInfinite.fd)
extensions.injectFunction(NaN.fd)
extensions.injectFunction(RegrCount.fd)
extensions.injectFunction(TryExpression.fd)
}
}
24 changes: 0 additions & 24 deletions src/main/scala/yaooqinn/itachi/package.scala

This file was deleted.

54 changes: 54 additions & 0 deletions src/test/scala/org/apache/itachi/ItachiTest.scala
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.itachi

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.extra.SparkSessionHelper
import org.apache.spark.unsafe.types.CalendarInterval

class ItachiTest extends SparkSessionHelper {

override def beforeAll(): Unit = {
super.beforeAll()
org.apache.itachi.registerPostgresFunctions
}

def checkAnswer(df: DataFrame, expect: Seq[Row]): Unit = {
assert(df.collect() === expect)
}

test("age") {
checkAnswer(
spark.sql("select age(timestamp '2001-04-10', timestamp '1957-06-13')"),
Seq(Row(new CalendarInterval(525, 28, 0)))
)
}


test("regr_count") {
val query = "select k, count(*), regr_count(v, v2)" +
" from values(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25,null), (2, 30, 35) t(k, v, v2)" +
" group by k"
checkAnswer(sql(query), Seq(Row(1, 1, 0), Row(2, 4, 3)))

checkAnswer(sql("SELECT REGR_COUNT(1, 2)"), Seq(Row(1)))
checkAnswer(sql("SELECT REGR_COUNT(1, null)"), Seq(Row(0)))

}

}