Skip to content

Commit

Permalink
[FLINK-11884][table] JoinTableOperation construction & transformation …
Browse files Browse the repository at this point in the history
…to RelNodes

This closes apache#8062
  • Loading branch information
dawidwys committed Apr 24, 2019
1 parent 665e3f4 commit df59135
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 158 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.operations;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.expressions.Expression;

import java.util.Arrays;
import java.util.List;

/**
 * Table operation that joins two relational operations based on a given condition.
 *
 * <p>The schema of the join result is the concatenation of the fields of the left input
 * followed by the fields of the right input.
 */
@Internal
public class JoinTableOperation implements TableOperation {
	private final TableOperation left;
	private final TableOperation right;
	private final JoinType joinType;
	private final Expression condition;
	private final boolean correlated;
	private final TableSchema tableSchema;

	/**
	 * Specifies how the two Tables should be joined.
	 */
	public enum JoinType {
		INNER,
		LEFT_OUTER,
		RIGHT_OUTER,
		FULL_OUTER
	}

	/**
	 * Creates a join of the two given operations. The resulting schema is computed eagerly
	 * from the schemas of both inputs.
	 *
	 * @param left left input of the join
	 * @param right right input of the join
	 * @param joinType kind of join to perform
	 * @param condition join condition
	 * @param correlated whether the join is a correlated join
	 */
	public JoinTableOperation(
			TableOperation left,
			TableOperation right,
			JoinType joinType,
			Expression condition,
			boolean correlated) {
		this.left = left;
		this.right = right;
		this.joinType = joinType;
		this.condition = condition;
		this.correlated = correlated;

		this.tableSchema = calculateResultingSchema(left, right);
	}

	/**
	 * Builds the schema of the join result: all fields of {@code left} followed by all fields
	 * of {@code right}, keeping names and types as they are.
	 *
	 * <p>Static on purpose: it reads no instance state, so it is safe to call while the
	 * instance is still being constructed.
	 */
	private static TableSchema calculateResultingSchema(TableOperation left, TableOperation right) {
		final TableSchema leftSchema = left.getTableSchema();
		final TableSchema rightSchema = right.getTableSchema();
		final int leftFieldCount = leftSchema.getFieldCount();
		final int rightFieldCount = rightSchema.getFieldCount();
		final int resultingSchemaSize = leftFieldCount + rightFieldCount;

		final String[] newFieldNames = new String[resultingSchemaSize];
		System.arraycopy(leftSchema.getFieldNames(), 0, newFieldNames, 0, leftFieldCount);
		System.arraycopy(rightSchema.getFieldNames(), 0, newFieldNames, leftFieldCount, rightFieldCount);

		// Unbounded wildcard instead of a raw type: the element types are heterogeneous.
		final TypeInformation<?>[] newFieldTypes = new TypeInformation<?>[resultingSchemaSize];
		System.arraycopy(leftSchema.getFieldTypes(), 0, newFieldTypes, 0, leftFieldCount);
		System.arraycopy(rightSchema.getFieldTypes(), 0, newFieldTypes, leftFieldCount, rightFieldCount);

		return new TableSchema(newFieldNames, newFieldTypes);
	}

	/**
	 * Returns the kind of join (inner, left/right/full outer).
	 */
	public JoinType getJoinType() {
		return joinType;
	}

	/**
	 * Returns the join condition this operation was created with.
	 */
	public Expression getCondition() {
		return condition;
	}

	/**
	 * Returns the {@code correlated} flag this operation was created with.
	 */
	public boolean isCorrelated() {
		return correlated;
	}

	@Override
	public TableSchema getTableSchema() {
		return tableSchema;
	}

	/**
	 * Returns the two join inputs, left first and right second.
	 */
	@Override
	public List<TableOperation> getChildren() {
		return Arrays.asList(left, right);
	}

	@Override
	public <T> T accept(TableOperationVisitor<T> visitor) {
		return visitor.visitJoin(this);
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public T visitWindowAggregate(WindowAggregateTableOperation windowAggregate) {
return defaultMethod(windowAggregate);
}

/**
 * Handles a {@link JoinTableOperation} by delegating to {@code defaultMethod}.
 */
@Override
public T visitJoin(JoinTableOperation join) {
return defaultMethod(join);
}

@Override
public T visitSetOperation(SetTableOperation setOperation) {
return defaultMethod(setOperation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public interface TableOperationVisitor<T> {

/** Visits a {@link WindowAggregateTableOperation}. */
T visitWindowAggregate(WindowAggregateTableOperation windowAggregate);

/** Visits a {@link JoinTableOperation}. */
T visitJoin(JoinTableOperation join);

/** Visits a {@link SetTableOperation}. */
T visitSetOperation(SetTableOperation setOperation);

/** Visits a {@link FilterTableOperation}. */
T visitFilter(FilterTableOperation filter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,15 @@
import org.apache.flink.table.expressions.ExpressionUtils;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.PlannerExpression;
import org.apache.flink.table.plan.logical.Join;
import org.apache.flink.table.operations.JoinTableOperation.JoinType;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import static java.util.Arrays.asList;

/**
* Utility class for creating a valid {@link Join} operation.
* Utility class for creating a valid {@link JoinTableOperation} operation.
*/
@Internal
public class JoinOperationFactory {
Expand All @@ -52,17 +51,7 @@ public JoinOperationFactory(ExpressionBridge<PlannerExpression> expressionBridge
}

/**
* Specifies how the two Tables should be joined.
*/
public enum JoinType {
INNER,
LEFT_OUTER,
RIGHT_OUTER,
FULL_OUTER
}

/**
* Creates a valid {@link Join} operation.
* Creates a valid {@link JoinTableOperation} operation.
*
* <p>It performs validations such as:
* <ul>
Expand All @@ -88,9 +77,7 @@ public TableOperation create(
verifyConditionType(condition);
validateNamesAmbiguity(left, right);
validateCondition(right, joinType, condition, correlated);

PlannerExpression plannerExpression = expressionBridge.bridge(condition);
return new Join(left, right, joinType, Optional.of(plannerExpression), correlated);
return new JoinTableOperation(left, right, joinType, condition, correlated);
}

private void validateCondition(TableOperation right, JoinType joinType, Expression condition, boolean correlated) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionBridge;
import org.apache.flink.table.expressions.ExpressionDefaultVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.PlannerExpression;
import org.apache.flink.table.expressions.RexPlannerExpression;
import org.apache.flink.table.expressions.WindowReference;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.table.functions.utils.TableSqlFunction;
Expand All @@ -38,6 +40,8 @@
import org.apache.flink.table.operations.CatalogTableOperation;
import org.apache.flink.table.operations.DistinctTableOperation;
import org.apache.flink.table.operations.FilterTableOperation;
import org.apache.flink.table.operations.JoinTableOperation;
import org.apache.flink.table.operations.JoinTableOperation.JoinType;
import org.apache.flink.table.operations.PlannerTableOperation;
import org.apache.flink.table.operations.ProjectTableOperation;
import org.apache.flink.table.operations.SetTableOperation;
Expand All @@ -47,14 +51,15 @@
import org.apache.flink.table.operations.TableOperationVisitor;
import org.apache.flink.table.operations.WindowAggregateTableOperation;
import org.apache.flink.table.operations.WindowAggregateTableOperation.ResolvedGroupWindow;
import org.apache.flink.table.plan.logical.LogicalNode;
import org.apache.flink.table.plan.logical.LogicalWindow;
import org.apache.flink.table.plan.logical.SessionGroupWindow;
import org.apache.flink.table.plan.logical.SlidingGroupWindow;
import org.apache.flink.table.plan.logical.TumblingGroupWindow;
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
Expand All @@ -63,6 +68,7 @@

import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;

import scala.Some;
Expand Down Expand Up @@ -101,6 +107,7 @@ public TableOperationConverter get(RelBuilder relBuilder) {
private final SingleRelVisitor singleRelVisitor = new SingleRelVisitor();
private final ExpressionBridge<PlannerExpression> expressionBridge;
private final AggregateVisitor aggregateVisitor = new AggregateVisitor();
private final JoinExpressionVisitor joinExpressionVisitor = new JoinExpressionVisitor();

public TableOperationConverter(
RelBuilder relBuilder,
Expand Down Expand Up @@ -155,6 +162,22 @@ public RelNode visitWindowAggregate(WindowAggregateTableOperation windowAggregat
return flinkRelBuilder.aggregate(logicalWindow, groupKey, windowProperties, aggregations).build();
}

@Override
public RelNode visitJoin(JoinTableOperation join) {
	// A correlated join gets one fresh correlation id, created through the cluster of the
	// node currently on top of the builder; a regular join carries no correlation ids.
	final Set<CorrelationId> correlationIds = join.isCorrelated()
		? Collections.singleton(relBuilder.peek().getCluster().createCorrel())
		: Collections.<CorrelationId>emptySet();

	final JoinRelType relJoinType = convertJoinType(join.getJoinType());
	// The condition is translated by the dedicated join expression visitor.
	final RexNode joinCondition = join.getCondition().accept(joinExpressionVisitor);

	return relBuilder.join(relJoinType, joinCondition, correlationIds).build();
}

@Override
public RelNode visitSetOperation(SetTableOperation setOperation) {
switch (setOperation.getType()) {
Expand Down Expand Up @@ -228,9 +251,7 @@ public RelNode visitCatalogTable(CatalogTableOperation catalogTable) {

@Override
public RelNode visitOther(TableOperation other) {
if (other instanceof LogicalNode) {
return ((LogicalNode) other).toRelNode(relBuilder);
} else if (other instanceof PlannerTableOperation) {
if (other instanceof PlannerTableOperation) {
return ((PlannerTableOperation) other).getCalciteTree();
}

Expand Down Expand Up @@ -276,6 +297,47 @@ private LogicalWindow toLogicalWindow(ResolvedGroupWindow window) {
throw new TableException("Unknown window type");
}
}

/**
 * Maps the Table API join type onto the corresponding Calcite {@link JoinRelType}.
 *
 * @throws TableException if the join type has no known mapping
 */
private JoinRelType convertJoinType(JoinType joinType) {
	final JoinRelType relType;
	switch (joinType) {
		case INNER:
			relType = JoinRelType.INNER;
			break;
		case LEFT_OUTER:
			relType = JoinRelType.LEFT;
			break;
		case RIGHT_OUTER:
			relType = JoinRelType.RIGHT;
			break;
		case FULL_OUTER:
			relType = JoinRelType.FULL;
			break;
		default:
			throw new TableException("Unknown join type: " + joinType);
	}
	return relType;
}
}

/**
 * Translates a join condition into a {@link RexNode}. Field references are resolved against
 * the two inputs of the join; calls have their arguments translated first and are then
 * bridged as a whole; every other expression goes straight through the expression bridge.
 */
private class JoinExpressionVisitor extends ExpressionDefaultVisitor<RexNode> {

	/** A join always has exactly two inputs. */
	private static final int JOIN_INPUT_COUNT = 2;

	@Override
	public RexNode visitCall(CallExpression call) {
		// Translate the arguments with this visitor and wrap each result so that it can be
		// placed back into a CallExpression.
		List<Expression> convertedArguments = call.getChildren()
			.stream()
			.map(this::convertAndWrap)
			.collect(toList());

		CallExpression rewrittenCall = new CallExpression(call.getFunctionDefinition(), convertedArguments);
		PlannerExpression plannerCall = expressionBridge.bridge(rewrittenCall);
		return plannerCall.toRexNode(relBuilder);
	}

	@Override
	public RexNode visitFieldReference(FieldReferenceExpression fieldReference) {
		final int inputIndex = fieldReference.getInputIndex();
		final int fieldIndex = fieldReference.getFieldIndex();
		return relBuilder.field(JOIN_INPUT_COUNT, inputIndex, fieldIndex);
	}

	@Override
	protected RexNode defaultMethod(Expression expression) {
		PlannerExpression plannerExpression = expressionBridge.bridge(expression);
		return plannerExpression.toRexNode(relBuilder);
	}

	/** Converts a single call argument and wraps the resulting RexNode as an expression. */
	private Expression convertAndWrap(Expression argument) {
		RexNode converted = argument.accept(this);
		return new RexPlannerExpression(converted);
	}
}

private class AggregateVisitor extends ExpressionDefaultVisitor<AggCall> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import _root_.java.util.function.Supplier
import org.apache.calcite.rel.RelNode
import org.apache.flink.table.expressions.{Expression, ExpressionParser, LookupCallResolver}
import org.apache.flink.table.functions.{TemporalTableFunction, TemporalTableFunctionImpl}
import org.apache.flink.table.operations.JoinOperationFactory.JoinType
import org.apache.flink.table.operations.JoinTableOperation.JoinType
import org.apache.flink.table.operations.OperationExpressionsUtils.extractAggregationsAndProperties
import org.apache.flink.table.operations.{OperationTreeBuilder, TableOperation}
import org.apache.flink.table.util.JavaScalaConversionUtil.toJava
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.api._
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.calcite.FlinkTypeFactory._
import org.apache.flink.table.functions.sql.StreamRecordTimestampSqlFunction
import org.apache.flink.table.operations.TableOperation
Expand All @@ -39,6 +40,21 @@ abstract class Attribute extends LeafExpression with NamedExpression {
private[flink] def withName(newName: String): Attribute
}

/**
  * Leaf expression that wraps an already converted [[RexNode]]. Converting it to a RexNode
  * returns the wrapped node unchanged, and its result type is derived from the node's type.
  */
case class RexPlannerExpression(private[flink] val rexNode: RexNode) extends LeafExpression {

  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = rexNode

  override private[flink] def resultType: TypeInformation[_] = {
    val relDataType = rexNode.getType
    FlinkTypeFactory.toTypeInfo(relDataType)
  }
}

case class UnresolvedFieldReference(name: String) extends Attribute {

override def toString = s"'$name"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.flink.table.expressions.catalog.FunctionDefinitionCatalog
import org.apache.flink.table.expressions.lookups.TableReferenceLookup
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.operations.AliasOperationUtils.createAliasList
import org.apache.flink.table.operations.JoinOperationFactory.JoinType
import org.apache.flink.table.operations.JoinTableOperation.JoinType
import org.apache.flink.table.operations.SetTableOperation.SetTableOperationType._
import org.apache.flink.table.util.JavaScalaConversionUtil
import org.apache.flink.table.util.JavaScalaConversionUtil.toScala
Expand Down
Loading

0 comments on commit df59135

Please sign in to comment.