Skip to content

Commit

Permalink
[FLINK-6242] [table] Add code generation for DataSet Aggregates
Browse files Browse the repository at this point in the history
This closes apache#3735.
  • Loading branch information
shaoxuan-wang authored and fhueske committed Apr 21, 2017
1 parent 4024aff commit 3b4542b
Show file tree
Hide file tree
Showing 19 changed files with 878 additions and 1,106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,57 +250,88 @@ class CodeGenerator(
* @param aggregates All aggregate functions
* @param aggFields Indexes of the input fields for all aggregate functions
* @param aggMapping The mapping of aggregates to output fields
* @param partialResults A flag defining whether final or partial results (accumulators) are set
* to the output row.
* @param fwdMapping The mapping of input fields to output fields
* @param mergeMapping An optional mapping to specify the accumulators to merge. If not set, we
* assume that both rows have the accumulators at the same position.
* @param constantFlags An optional parameter to define where to set constant boolean flags in
* the output row.
* @param outputArity The number of fields in the output row.
*
* @return A GeneratedAggregationsFunction
*/
def generateAggregations(
name: String,
generator: CodeGenerator,
inputType: RelDataType,
aggregates: Array[AggregateFunction[_ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
fwdMapping: Array[(Int, Int)],
outputArity: Int)
name: String,
generator: CodeGenerator,
inputType: RelDataType,
aggregates: Array[AggregateFunction[_ <: Any]],
aggFields: Array[Array[Int]],
aggMapping: Array[Int],
partialResults: Boolean,
fwdMapping: Array[Int],
mergeMapping: Option[Array[Int]],
constantFlags: Option[Array[(Int, Boolean)]],
outputArity: Int)
: GeneratedAggregationsFunction = {

def genSetAggregationResults(
accTypes: Array[String],
aggs: Array[String],
aggMapping: Array[Int]): String = {
// get unique function name
val funcName = newName(name)
// register UDAGGs
val aggs = aggregates.map(a => generator.addReusableFunction(a))
// get java types of accumulators
val accTypes = aggregates.map { a =>
a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
}

// get java types of input fields
val javaTypes = inputType.getFieldList
.map(f => FlinkTypeFactory.toTypeInfo(f.getType))
.map(t => t.getTypeClass.getCanonicalName)
// get parameter lists for aggregation functions
val parameters = aggFields.map {inFields =>
val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
fields.mkString(", ")
}

def genSetAggregationResults: String = {

val sig: String =
j"""
| public void setAggregationResults(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row output)""".stripMargin
| public final void setAggregationResults(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row output)""".stripMargin

val setAggs: String = {
for (i <- aggs.indices) yield
j"""
| org.apache.flink.table.functions.AggregateFunction baseClass$i =
| (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
|
| output.setField(
| ${aggMapping(i)},
| baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin

if (partialResults) {
j"""
| output.setField(
| ${aggMapping(i)},
| (${accTypes(i)}) accs.getField($i));""".stripMargin
} else {
j"""
| org.apache.flink.table.functions.AggregateFunction baseClass$i =
| (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
|
| output.setField(
| ${aggMapping(i)},
| baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin
}
}.mkString("\n")

j"""$sig {
j"""
|$sig {
|$setAggs
| }""".stripMargin
}

def genAccumulate(
accTypes: Array[String],
aggs: Array[String],
parameters: Array[String]): String = {
def genAccumulate: String = {

val sig: String =
j"""
| public void accumulate(
| public final void accumulate(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row input)""".stripMargin

Expand All @@ -317,14 +348,11 @@ class CodeGenerator(
| }""".stripMargin
}

def genRetract(
accTypes: Array[String],
aggs: Array[String],
parameters: Array[String]): String = {
def genRetract: String = {

val sig: String =
j"""
| public void retract(
| public final void retract(
| org.apache.flink.types.Row accs,
| org.apache.flink.types.Row input)""".stripMargin

Expand All @@ -341,12 +369,11 @@ class CodeGenerator(
| }""".stripMargin
}

def genCreateAccumulators(
aggs: Array[String]): String = {
def genCreateAccumulators: String = {

val sig: String =
j"""
| public org.apache.flink.types.Row createAccumulators()
| public final org.apache.flink.types.Row createAccumulators()
| """.stripMargin
val init: String =
j"""
Expand All @@ -373,51 +400,77 @@ class CodeGenerator(
| }""".stripMargin
}

def genSetForwardedFields(
forwardMapping: Array[(Int, Int)]): String = {
def genSetForwardedFields: String = {

val sig: String =
j"""
| public void setForwardedFields(
| public final void setForwardedFields(
| org.apache.flink.types.Row input,
| org.apache.flink.types.Row output)
| """.stripMargin

val forward: String = {
for (i <- forwardMapping.indices) yield
j"""
| output.setField(
| ${forwardMapping(i)._1},
| input.getField(${forwardMapping(i)._2}));"""
.stripMargin
for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield
{
j"""
| output.setField(
| $i,
| input.getField(${fwdMapping(i)}));"""
.stripMargin
}
}.mkString("\n")

j"""$sig {
|$forward
| }""".stripMargin
}

def genCreateOutputRow(outputArity: Int): String = {
def genSetConstantFlags: String = {

// Java signature of the generated method: writes constant boolean flags
// into the supplied output row (part of the generated aggregations class).
val sig: String =
j"""
| public final void setConstantFlags(org.apache.flink.types.Row output)
| """.stripMargin

// Emit one "output.setField(index, true/false);" statement per
// (fieldIndex, flagValue) pair in constantFlags. When constantFlags is
// None, the generated method body is left empty (a no-op method).
val setFlags: String = if (constantFlags.isDefined) {
{
for (cf <- constantFlags.get) yield {
j"""
| output.setField(${cf._1}, ${if (cf._2) "true" else "false"});"""
.stripMargin
}
}.mkString("\n")
} else {
""
}

// Concatenate signature and generated statements into the complete
// Java method source text.
j"""$sig {
|$setFlags
| }""".stripMargin
}

def genCreateOutputRow: String = {
j"""
| public org.apache.flink.types.Row createOutputRow() {
| public final org.apache.flink.types.Row createOutputRow() {
| return new org.apache.flink.types.Row($outputArity);
| }""".stripMargin
}

def genMergeAccumulatorsPair(
accTypes: Array[String],
aggs: Array[String]): String = {
def genMergeAccumulatorsPair: String = {

val mapping = mergeMapping.getOrElse(aggs.indices.toArray)

val sig: String =
j"""
| public org.apache.flink.types.Row mergeAccumulatorsPair(
| public final org.apache.flink.types.Row mergeAccumulatorsPair(
| org.apache.flink.types.Row a,
| org.apache.flink.types.Row b)
""".stripMargin
val merge: String = {
for (i <- aggs.indices) yield
j"""
| ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
| ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField($i);
| ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
| accList$i.set(0, aAcc$i);
| accList$i.set(1, bAcc$i);
| a.setField(
Expand All @@ -430,75 +483,76 @@ class CodeGenerator(
| return a;
""".stripMargin

j"""$sig {
j"""
|$sig {
|$merge
|$ret
| }""".stripMargin
}

def genMergeList(accTypes: Array[String]): String = {
def genMergeList: String = {
{
for (i <- accTypes.indices) yield
j"""
| java.util.ArrayList<${accTypes(i)}> accList$i;
| private final java.util.ArrayList<${accTypes(i)}> accList$i =
| new java.util.ArrayList<${accTypes(i)}>(2);
""".stripMargin
}.mkString("\n")
}

def initMergeList(
accTypes: Array[String],
aggs: Array[String]): String = {
def initMergeList: String = {
{
for (i <- accTypes.indices) yield
j"""
| accList$i = new java.util.ArrayList<${accTypes(i)}>(2);
| accList$i.add(${aggs(i)}.createAccumulator());
| accList$i.add(${aggs(i)}.createAccumulator());
""".stripMargin
}.mkString("\n")
}

// get unique function name
val funcName = newName(name)
// register UDAGGs
val aggs = aggregates.map(a => generator.addReusableFunction(a))
// get java types of accumulators
val accTypes = aggregates.map { a =>
a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
}
def genResetAccumulator: String = {

// get java types of input fields
val javaTypes = inputType.getFieldList
.map(f => FlinkTypeFactory.toTypeInfo(f.getType))
.map(t => t.getTypeClass.getCanonicalName)
// get parameter lists for aggregation functions
val parameters = aggFields.map {inFields =>
val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
fields.mkString(", ")
val sig: String =
j"""
| public final void resetAccumulator(
| org.apache.flink.types.Row accs)""".stripMargin

val reset: String = {
for (i <- aggs.indices) yield
j"""
| ${aggs(i)}.resetAccumulator(
| ((${accTypes(i)}) accs.getField($i)));""".stripMargin
}.mkString("\n")

j"""$sig {
|$reset
| }""".stripMargin
}

var funcCode =
j"""
|public class $funcName
|public final class $funcName
| extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
|
| ${reuseMemberCode()}
| ${genMergeList(accTypes)}
| $genMergeList
| public $funcName() throws Exception {
| ${reuseInitCode()}
| ${initMergeList(accTypes, aggs)}
| $initMergeList
| }
| ${reuseConstructorCode(funcName)}
|
""".stripMargin

funcCode += genSetAggregationResults(accTypes, aggs, aggMapping) + "\n"
funcCode += genAccumulate(accTypes, aggs, parameters) + "\n"
funcCode += genRetract(accTypes, aggs, parameters) + "\n"
funcCode += genCreateAccumulators(aggs) + "\n"
funcCode += genSetForwardedFields(fwdMapping) + "\n"
funcCode += genCreateOutputRow(outputArity) + "\n"
funcCode += genMergeAccumulatorsPair(accTypes, aggs) + "\n"
funcCode += genSetAggregationResults + "\n"
funcCode += genAccumulate + "\n"
funcCode += genRetract + "\n"
funcCode += genCreateAccumulators + "\n"
funcCode += genSetForwardedFields + "\n"
funcCode += genSetConstantFlags + "\n"
funcCode += genCreateOutputRow + "\n"
funcCode += genMergeAccumulatorsPair + "\n"
funcCode += genResetAccumulator + "\n"
funcCode += "}"

GeneratedAggregationsFunction(funcName, funcCode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.BatchTableEnvironment
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenerator
import org.apache.flink.table.plan.nodes.CommonAggregate
import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction}
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
Expand Down Expand Up @@ -89,19 +90,25 @@ class DataSetAggregate(

override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {

val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)

val generator = new CodeGenerator(
tableEnv.getConfig,
false,
inputDS.getType)

val (
preAgg: Option[DataSetPreAggFunction],
preAggType: Option[TypeInformation[Row]],
finalAgg: GroupReduceFunction[Row, Row]
) = AggregateUtil.createDataSetAggregateFunctions(
generator,
namedAggregates,
inputType,
rowRelDataType,
grouping,
inGroupingSet)

val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)

val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)

val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
Expand Down
Loading

0 comments on commit 3b4542b

Please sign in to comment.