[WIP] Supporting CTAS queries for Hive to Spark query translations #324

Open: wants to merge 11 commits into master
fixing ide lint unnecessary changes
nimesh1601 committed Nov 3, 2022
commit a5a207954012ebcd352ab65a319d12a8d13de3c8
@@ -5,17 +5,43 @@
*/
package com.linkedin.coral.hive.hive2rel.parsetree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import com.linkedin.coral.common.calcite.sql.SqlCreateTable;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.sql.JoinConditionType;
import org.apache.calcite.sql.JoinType;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLateralOperator;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSelectKeyword;
import org.apache.calcite.sql.SqlTypeNameSpec;
import org.apache.calcite.sql.SqlWindow;
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.SqlWithItem;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.hadoop.hive.metastore.api.Table;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
@@ -160,10 +186,10 @@ private SqlNode visitLateralViewInternal(ASTNode node, ParseContext ctx, boolean
*/
private SqlNode visitLateralViewUDTF(List<SqlNode> sqlNodes, List<SqlNode> aliasOperands, SqlCall tableFunctionCall) {
SqlNode lateralCall = SqlStdOperatorTable.LATERAL.createCall(ZERO,
new SqlLateralOperator(SqlKind.COLLECTION_TABLE).createCall(ZERO, tableFunctionCall));
final String functionName = tableFunctionCall.getOperator().getName();
ImmutableList<String> fieldNames =
StaticHiveFunctionRegistry.UDTF_RETURN_FIELD_NAME_MAP.getOrDefault(functionName, null);
if (fieldNames == null) {
throw new RuntimeException("User defined table function " + functionName + " is not registered.");
}
@@ -173,11 +199,11 @@ private SqlNode visitLateralViewUDTF(List<SqlNode> sqlNodes, List<SqlNode> alias
fieldNames.forEach(name -> asOperands.add(new SqlIdentifier(name, ZERO)));
SqlCall aliasCall = SqlStdOperatorTable.AS.createCall(ZERO, asOperands);
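// Note: lateralCall built above is currently unused; the AS-wrapped table function
// call is joined directly (see the commented-out operand below).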
return new SqlJoin(ZERO, sqlNodes.get(1), SqlLiteral.createBoolean(false, ZERO), JoinType.COMMA.symbol(ZERO),
aliasCall/*lateralCall*/, JoinConditionType.NONE.symbol(ZERO), null);
}

private SqlNode visitLateralViewExplode(List<SqlNode> sqlNodes, List<SqlNode> aliasOperands,
SqlCall tableFunctionCall, boolean isOuter) {
final int operandCount = aliasOperands.size();
// explode array if operandCount == 3: LATERAL VIEW EXPLODE(op0) op1 AS op2
// explode map if operandCount == 4: LATERAL VIEW EXPLODE(op0) op1 AS op2, op3
@@ -189,33 +215,33 @@ private SqlNode visitLateralViewExplode(List<SqlNode> sqlNodes, List<SqlNode> al
// Note that `operandCount == 2 && isOuter` is not supported yet due to the lack of type information needed
// to derive the correct IF function parameters.
checkState(operandCount == 2 || operandCount == 3 || operandCount == 4,
format("Unsupported LATERAL VIEW EXPLODE operand number: %d", operandCount));
format("Unsupported LATERAL VIEW EXPLODE operand number: %d", operandCount));
// TODO The code below assumes LATERAL VIEW is used with UNNEST EXPLODE/POSEXPLODE only. It should be made more generic.
SqlCall unnestCall = tableFunctionCall;
SqlNode unnestOperand = unnestCall.operand(0);
final SqlOperator operator = unnestCall.getOperator();

if (isOuter) {
checkState(operandCount > 2,
"LATERAL VIEW OUTER EXPLODE without column aliases is not supported. Add 'AS col' or 'AS key, value' to fix it");
"LATERAL VIEW OUTER EXPLODE without column aliases is not supported. Add 'AS col' or 'AS key, value' to fix it");
// transforms unnest(b) to unnest( if(b is null or cardinality(b) = 0, ARRAY(null)/MAP(null, null), b))
SqlNode operandIsNull = SqlStdOperatorTable.IS_NOT_NULL.createCall(ZERO, unnestOperand);
SqlNode emptyArray = SqlStdOperatorTable.GREATER_THAN.createCall(ZERO,
SqlStdOperatorTable.CARDINALITY.createCall(ZERO, unnestOperand), SqlLiteral.createExactNumeric("0", ZERO));
SqlNode ifCondition = SqlStdOperatorTable.AND.createCall(ZERO, operandIsNull, emptyArray);
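// The condition built here is the logical inverse of the comment above, with branches swapped:
// if(b IS NOT NULL AND cardinality(b) > 0, b, ARRAY(null)/MAP(null, null))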
// An array of [null] or a map of (null, null) should be the 3rd parameter to the if function. With our type
// inference, Calcite determines the return type of unnest(array[null]) or unnest(map(null, null)) to be null
SqlNode arrayOrMapOfNull;
if (operandCount == 3
|| (operator instanceof CoralSqlUnnestOperator && ((CoralSqlUnnestOperator) operator).withOrdinality)) {
arrayOrMapOfNull = SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR.createCall(ZERO, SqlLiteral.createNull(ZERO));
} else {
arrayOrMapOfNull = SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(ZERO, SqlLiteral.createNull(ZERO),
SqlLiteral.createNull(ZERO));
}
Function hiveIfFunction = functionResolver.tryResolve("if", null, 1);
unnestOperand = hiveIfFunction.createCall(SqlLiteral.createCharString("if", ZERO),
ImmutableList.of(ifCondition, unnestOperand, arrayOrMapOfNull), null);
}
unnestCall = operator.createCall(ZERO, unnestOperand);
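// Rebuild the UNNEST call over the operand, which the OUTER branch above may have wrapped in Hive's if().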

@@ -241,7 +267,7 @@ private SqlNode visitLateralViewExplode(List<SqlNode> sqlNodes, List<SqlNode> al
SqlNode as = SqlStdOperatorTable.AS.createCall(ZERO, asOperands);

return new SqlJoin(ZERO, sqlNodes.get(1), SqlLiteral.createBoolean(false, ZERO), JoinType.COMMA.symbol(ZERO), as,
JoinConditionType.NONE.symbol(ZERO), null);
}

private SqlNode visitLateralViewJsonTuple(List<SqlNode> sqlNodes, List<SqlNode> aliasOperands, SqlCall sqlCall) {
@@ -270,22 +296,22 @@ LATERAL VIEW json_tuple(json, p1, p2) jt AS a, b

// '$["jsonKey"]'
SqlCall jsonPath = SqlStdOperatorTable.CONCAT.createCall(ZERO,
SqlStdOperatorTable.CONCAT.createCall(ZERO, SqlLiteral.createCharString("$[\"", ZERO), jsonKey),
SqlLiteral.createCharString("\"]", ZERO));

SqlCall getJsonObjectCall =
getJsonObjectFunction.createCall(SqlLiteral.createCharString(getJsonObjectFunction.getFunctionName(), ZERO),
ImmutableList.of(jsonInput, jsonPath), null);
// TODO Hive get_json_object returns a string, but currently is mapped in Trino to json_extract which returns a json. Once fixed, remove the CAST
SqlCall castToString = SqlStdOperatorTable.CAST.createCall(ZERO, getJsonObjectCall,
// TODO This results in CAST to VARCHAR(65535), which may be too short, but there seems to be no way to avoid that.
// even `new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, Integer.MAX_VALUE - 1, ZERO), ZERO)` results in a limited VARCHAR precision.
createBasicTypeSpec(SqlTypeName.VARCHAR));
// TODO support jsonKey containing a quotation mark (") or backslash (\)
SqlCall ifCondition =
HiveRLikeOperator.RLIKE.createCall(ZERO, jsonKey, SqlLiteral.createCharString("^[^\\\"]*$", ZERO));
SqlCall ifFunctionCall = ifFunction.createCall(SqlLiteral.createCharString(ifFunction.getFunctionName(), ZERO),
ImmutableList.of(ifCondition, castToString, SqlLiteral.createNull(ZERO)), null);
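// Resulting per-key projection, roughly:
// IF(jsonKey RLIKE '^[^\"]*$', CAST(get_json_object(json, '$["jsonKey"]') AS VARCHAR), NULL)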
SqlNode projection = ifFunctionCall;
// Currently only explicit aliasing is supported. Implicit alias would be c0, c1, etc.
projections.add(SqlStdOperatorTable.AS.createCall(ZERO, projection, keyAlias));
@@ -295,9 +321,9 @@ LATERAL VIEW json_tuple(json, p1, p2) jt AS a, b
new SqlSelect(ZERO, null, new SqlNodeList(projections, ZERO), null, null, null, null, null, null, null, null);
SqlNode lateral = SqlStdOperatorTable.LATERAL.createCall(ZERO, select);
SqlCall lateralAlias = SqlStdOperatorTable.AS.createCall(ZERO,
ImmutableList.<SqlNode> builder().add(lateral).addAll(aliasOperands.subList(1, aliasOperands.size())).build());
SqlNode joinNode = new SqlJoin(ZERO, sqlNodes.get(1), SqlLiteral.createBoolean(false, ZERO),
JoinType.COMMA.symbol(ZERO), lateralAlias, JoinConditionType.NONE.symbol(ZERO), null);
return joinNode;
}

@@ -344,7 +370,7 @@ private SqlNode processJoin(ASTNode node, ParseContext ctx, JoinType joinType) {
}

return new SqlJoin(ZERO, children.get(0), SqlLiteral.createBoolean(false, ZERO), joinType.symbol(ZERO),
children.get(1), conditionType.symbol(ZERO), condition);
}

@Override
@@ -458,7 +484,7 @@ protected SqlNode visitOperator(ASTNode node, ParseContext ctx) {
return visitBinaryOperator(node, ctx);
} else {
throw new RuntimeException(
String.format("Unhandled AST operator: %s with > 2 children, tree: %s", node.getText(), node.dump()));
String.format("Unhandled AST operator: %s with > 2 children, tree: %s", node.getText(), node.dump()));
}
}

@@ -500,7 +526,7 @@ protected SqlNode visitLParen(ASTNode node, ParseContext ctx) {
protected SqlNode visitFunctionStar(ASTNode node, ParseContext ctx) {
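// Handles star arguments such as COUNT(*): resolve the matching built-in operator by name
// and pass an empty identifier to stand in for '*'.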
ASTNode functionNode = (ASTNode) node.getChildren().get(0);
List<SqlOperator> functions = SqlStdOperatorTable.instance().getOperatorList().stream()
.filter(f -> functionNode.getText().equalsIgnoreCase(f.getName())).collect(Collectors.toList());
checkState(functions.size() == 1);
return new SqlBasicCall(functions.get(0), new SqlNode[] { new SqlIdentifier("", ZERO) }, ZERO);
}
@@ -522,8 +548,8 @@ private SqlNode visitFunctionInternal(ASTNode node, ParseContext ctx, SqlLiteral
String functionName = functionNode.getText();
List<SqlNode> sqlOperands = visitChildren(children, ctx);
Function hiveFunction = functionResolver.tryResolve(functionName, ctx.hiveTable.orElse(null),
// The first element of sqlOperands is the operator itself. The actual # of operands is sqlOperands.size() - 1
sqlOperands.size() - 1);

// Special treatment for Window Function
SqlNode lastSqlOperand = sqlOperands.get(sqlOperands.size() - 1);
@@ -536,7 +562,7 @@ private SqlNode visitFunctionInternal(ASTNode node, ParseContext ctx, SqlLiteral
// SqlBasicCall("OVER") will have 2 children: "func" and SqlWindow
/** See {@link #visitWindowSpec(ASTNode, ParseContext)} for SQL, AST Tree and SqlNode Tree examples */
SqlNode func =
hiveFunction.createCall(sqlOperands.get(0), sqlOperands.subList(1, sqlOperands.size() - 1), quantifier);
SqlNode window = lastSqlOperand;
return new SqlBasicCall(SqlStdOperatorTable.OVER, new SqlNode[] { func, window }, ZERO);
}
@@ -693,7 +719,7 @@ protected SqlNode visitTabRefNode(ASTNode node, ParseContext ctx) {
protected SqlNode visitTabnameNode(ASTNode node, ParseContext ctx) {
List<SqlNode> sqlNodes = visitChildren(node, ctx);
List<String> names =
sqlNodes.stream().map(s -> ((SqlIdentifier) s).names).flatMap(List::stream).collect(Collectors.toList());

return new SqlIdentifier(names, ZERO);
}
@@ -797,7 +823,7 @@ protected SqlNode visitQueryNode(ASTNode node, ParseContext ctx) {
}
}
SqlSelect select = new SqlSelect(ZERO, qc.keywords, qc.selects, qc.from, qc.where, qc.grpBy, qc.having, null,
qc.orderBy, null, qc.fetch);
if (cte != null) {
// Calcite uses "SqlWith(SqlNodeList of SqlWithItem, SqlSelect)" to represent queries with WITH
/** See {@link #visitCTE(ASTNode, ParseContext) visitCTE} for details */
@@ -871,8 +897,8 @@ protected SqlNode visitDecimal(ASTNode node, ParseContext ctx) {
if (node.getChildCount() == 2) {
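// DECIMAL with explicit precision and scale: the two child nodes carry the precision and scale
// literals; fall back to a plain DECIMAL if they fail to parse.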
try {
final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(SqlTypeName.DECIMAL,
Integer.parseInt(((ASTNode) node.getChildren().get(0)).getText()),
Integer.parseInt(((ASTNode) node.getChildren().get(1)).getText()), ZERO);
return new SqlDataTypeSpec(typeNameSpec, ZERO);
} catch (NumberFormatException e) {
return createBasicTypeSpec(SqlTypeName.DECIMAL);
@@ -990,17 +1016,17 @@ protected SqlNode visitWindowSpec(ASTNode node, ParseContext ctx) {
SqlWindow window = windowRange != null ? windowRange : windowValues;
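// The isRows flag passed to SqlWindow below is derived from whether a windowRange (ROWS-style frame) is present.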

return new SqlWindow(ZERO, null, null, partitionSpec == null ? SqlNodeList.EMPTY : partitionSpec.getPartitionList(),
partitionSpec == null ? SqlNodeList.EMPTY : partitionSpec.getOrderList(),
SqlLiteral.createBoolean(windowRange != null, ZERO), window == null ? null : window.getLowerBound(),
window == null ? null : window.getUpperBound(), null);
}

@Override
protected SqlNode visitPartitioningSpec(ASTNode node, ParseContext ctx) {
SqlNode partitionList = visitOptionalChildByType(node, ctx, HiveParser.TOK_DISTRIBUTEBY);
SqlNode orderList = visitOptionalChildByType(node, ctx, HiveParser.TOK_ORDERBY);
return new SqlWindow(ZERO, null, null, partitionList != null ? (SqlNodeList) partitionList : SqlNodeList.EMPTY,
orderList != null ? (SqlNodeList) orderList : SqlNodeList.EMPTY, null, null, null, null);
}

@Override