Skip to content

Commit

Permalink
[FLINK-22726][hive] Support grouping__id in hive prior to 2.3.0
Browse files Browse the repository at this point in the history
This closes apache#15983
  • Loading branch information
lirui-apache committed Jun 9, 2021
1 parent feedfb2 commit 6d8c02f
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.module.Module;
import org.apache.flink.table.module.hive.udf.generic.GenericUDFLegacyGroupingID;
import org.apache.flink.table.module.hive.udf.generic.HiveGenericUDFGrouping;
import org.apache.flink.util.StringUtils;

Expand Down Expand Up @@ -101,6 +102,7 @@ public Set<String> listFunctions() {
functionNames = hiveShim.listBuiltInFunctions();
functionNames.removeAll(BUILT_IN_FUNC_BLACKLIST);
functionNames.add("grouping");
functionNames.add(GenericUDFLegacyGroupingID.NAME);
}
return functionNames;
}
Expand All @@ -117,6 +119,13 @@ public Optional<FunctionDefinition> getFunctionDefinition(String name) {
name, HiveGenericUDFGrouping.class.getName()));
}

// this function is used to generate legacy GROUPING__ID value for old hive versions
if (name.equalsIgnoreCase(GenericUDFLegacyGroupingID.NAME)) {
return Optional.of(
factory.createFunctionDefinitionFromHiveFunction(
name, GenericUDFLegacyGroupingID.class.getName()));
}

Optional<FunctionInfo> info = hiveShim.getBuiltInFunctionInfo(name);

return info.map(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.module.hive.udf.generic;

import org.apache.flink.table.planner.delegation.hive.HiveParserUtils;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.LongWritable;

import java.util.BitSet;

/**
* Hive's GROUPING__ID changed since 2.3.0. This function is to convert the new GROUPING__ID to the
* legacy value for older Hive versions. See https://issues.apache.org/jira/browse/HIVE-16102
*/
public class GenericUDFLegacyGroupingID extends GenericUDF {

    public static final String NAME = "_legacy_grouping__id";

    /** Inspector for the incoming (new-style, Hive >= 2.3.0) GROUPING__ID value. */
    private transient PrimitiveObjectInspector groupingIdInspector;
    /** Number of GROUP BY expressions; a compile-time constant of the query. */
    private int numGroupByExprs;
    /** Reused output holder to avoid a per-row allocation. */
    private final LongWritable output = new LongWritable();

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // Exactly two arguments are required: the new GROUPING__ID value and the
        // (constant) number of GROUP BY expressions.
        if (arguments.length != 2) {
            throw new UDFArgumentLengthException(
                    "Expect 2 arguments but actually got " + arguments.length);
        }
        if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "First argument should be primitive type");
        }
        if (arguments[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "Second argument should be primitive type");
        }
        groupingIdInspector = (PrimitiveObjectInspector) arguments[0];
        if (groupingIdInspector.getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.LONG) {
            throw new UDFArgumentTypeException(0, "First argument should be a long");
        }
        PrimitiveObjectInspector countInspector = (PrimitiveObjectInspector) arguments[1];
        if (countInspector.getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.INT) {
            throw new UDFArgumentTypeException(1, "Second argument should be an int");
        }
        if (!(countInspector instanceof ConstantObjectInspector)) {
            throw new UDFArgumentTypeException(1, "Second argument should be a constant");
        }
        numGroupByExprs =
                PrimitiveObjectInspectorUtils.getInt(
                        ((ConstantObjectInspector) countInspector).getWritableConstantValue(),
                        countInspector);
        // GROUPING__ID is encoded in a 64-bit long, one bit per GROUP BY expression.
        if (numGroupByExprs < 1 || numGroupByExprs > 64) {
            throw new UDFArgumentException(
                    "Number of GROUP BY expressions out of range: " + numGroupByExprs);
        }
        return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        long newGroupingId =
                PrimitiveObjectInspectorUtils.getLong(arguments[0].get(), groupingIdInspector);
        // Legacy value = complement-and-reverse of the low numGroupByExprs bits of the
        // new value: legacy bit i is the inverse of new bit (numGroupByExprs - 1 - i).
        // Bits at or above numGroupByExprs are carried over untouched (they are expected
        // to be zero for a valid new-style GROUPING__ID).
        long lowMask = numGroupByExprs == 64 ? -1L : (1L << numGroupByExprs) - 1L;
        long legacy = newGroupingId & ~lowMask;
        for (int bit = 0; bit < numGroupByExprs; bit++) {
            long sourceBit = (newGroupingId >>> (numGroupByExprs - 1 - bit)) & 1L;
            if (sourceBit == 0L) {
                legacy |= 1L << bit;
            }
        }
        output.set(legacy);
        return output;
    }

    @Override
    public String getDisplayString(String[] children) {
        return HiveParserUtils.getStandardDisplayString("grouping", children);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1244,33 +1244,6 @@ private RelNode genGBLogicalPlan(HiveParserQB qb, RelNode srcRel) throws Semanti
&& qbp.getDistinctFuncExprsForClause(detsClauseName).size() > 1) {
throw new SemanticException(ErrorMsg.UNSUPPORTED_MULTIPLE_DISTINCTS.getMsg());
}
if (cubeRollupGrpSetPresent) {
if (!HiveConf.getBoolVar(
semanticAnalyzer.getConf(), HiveConf.ConfVars.HIVEMAPSIDEAGGREGATE)) {
throw new SemanticException(ErrorMsg.HIVE_GROUPING_SETS_AGGR_NOMAPAGGR.getMsg());
}

if (semanticAnalyzer.getConf().getBoolVar(HiveConf.ConfVars.HIVEGROUPBYSKEW)) {
semanticAnalyzer.checkExpressionsForGroupingSet(
gbAstExprs,
qb.getParseInfo().getDistinctFuncExprsForClause(detsClauseName),
aggregationTrees,
this.relToRowResolver.get(srcRel));

if (qbp.getDestGroupingSets().size()
> semanticAnalyzer
.getConf()
.getIntVar(
HiveConf.ConfVars.HIVE_NEW_JOB_GROUPING_SET_CARDINALITY)) {
String errorMsg =
"The number of rows per input row due to grouping sets is "
+ qbp.getDestGroupingSets().size();
throw new SemanticException(
ErrorMsg.HIVE_GROUPING_SETS_THRESHOLD_NOT_ALLOWED_WITH_SKEW.getMsg(
errorMsg));
}
}
}

if (hasGrpByAstExprs || hasAggregationTrees) {
ArrayList<ExprNodeDesc> gbExprNodeDescs = new ArrayList<>();
Expand Down Expand Up @@ -2325,7 +2298,7 @@ private RelNode genSelectLogicalPlan(
tabAlias,
false);
colInfo.setSkewedCol(
(exprDesc instanceof ExprNodeColumnDesc)
exprDesc instanceof ExprNodeColumnDesc
&& ((ExprNodeColumnDesc) exprDesc).isSkewedCol());
// Hive errors out in case of duplication. We allow it and see what happens.
outRR.put(tabAlias, colAlias, colInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.functions.hive.HiveGenericUDAF;
import org.apache.flink.table.functions.hive.HiveGenericUDTF;
import org.apache.flink.table.module.hive.udf.generic.GenericUDFLegacyGroupingID;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.delegation.hive.copy.HiveASTParseDriver;
import org.apache.flink.table.planner.delegation.hive.copy.HiveASTParseUtils;
Expand Down Expand Up @@ -99,6 +100,7 @@
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Pair;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
Expand Down Expand Up @@ -714,6 +716,7 @@ public static HiveParserASTNode rewriteGroupingFunctionAST(
throws SemanticException {
final MutableBoolean visited = new MutableBoolean(false);
final MutableBoolean found = new MutableBoolean(false);
final boolean legacyGrouping = legacyGrouping();

TreeVisitorAction action =
new TreeVisitorAction() {
Expand All @@ -725,58 +728,24 @@ public Object pre(Object t) {

@Override
public Object post(Object t) {
HiveParserASTNode root = (HiveParserASTNode) t;
if (root.getType() == HiveASTParser.TOK_FUNCTION
&& root.getChildCount() == 2) {
HiveParserASTNode func =
(HiveParserASTNode)
HiveASTParseDriver.ADAPTOR.getChild(root, 0);
HiveParserASTNode current = (HiveParserASTNode) t;
// rewrite grouping function
if (current.getType() == HiveASTParser.TOK_FUNCTION
&& current.getChildCount() >= 2) {
HiveParserASTNode func = (HiveParserASTNode) current.getChild(0);
if (func.getText().equals("grouping")) {
HiveParserASTNode c =
(HiveParserASTNode)
HiveASTParseDriver.ADAPTOR.getChild(root, 1);
visited.setValue(true);
for (int i = 0; i < grpByAstExprs.size(); i++) {
HiveParserASTNode grpByExpr = grpByAstExprs.get(i);
if (grpByExpr.toStringTree().equals(c.toStringTree())) {
HiveParserASTNode child1;
if (noneSet) {
// Query does not contain CUBE, ROLLUP, or GROUPING
// SETS, and thus, grouping should return 0
child1 =
(HiveParserASTNode)
HiveASTParseDriver.ADAPTOR.create(
HiveASTParser.IntegralLiteral,
String.valueOf(0));
} else {
// We refer to grouping_id column
child1 =
(HiveParserASTNode)
HiveASTParseDriver.ADAPTOR.create(
HiveASTParser.TOK_TABLE_OR_COL,
"TOK_TABLE_OR_COL");
HiveASTParseDriver.ADAPTOR.addChild(
child1,
HiveASTParseDriver.ADAPTOR.create(
HiveASTParser.Identifier,
VirtualColumn.GROUPINGID.getName()));
}
HiveParserASTNode child2 =
(HiveParserASTNode)
HiveASTParseDriver.ADAPTOR.create(
HiveASTParser.IntegralLiteral,
String.valueOf(
com.google.common.math
.IntMath.mod(
-i - 1,
grpByAstExprs
.size())));
root.setChild(1, child1);
root.addChild(child2);
found.setValue(true);
break;
}
}
convertGrouping(
current, grpByAstExprs, noneSet, legacyGrouping, found);
}
} else if (legacyGrouping
&& current.getType() == HiveASTParser.TOK_TABLE_OR_COL
&& current.getChildCount() == 1) {
// rewrite grouping__id
HiveParserASTNode child = (HiveParserASTNode) current.getChild(0);
if (child.getText()
.equalsIgnoreCase(VirtualColumn.GROUPINGID.getName())) {
return convertToLegacyGroupingId(current, grpByAstExprs.size());
}
}
return t;
Expand All @@ -791,6 +760,92 @@ public Object post(Object t) {
return newTargetNode;
}

private static HiveParserASTNode convertToLegacyGroupingId(
        HiveParserASTNode groupingId, int numGBExprs) {
    // Wrap the new-style grouping__id reference in a call to _legacy_grouping__id so
    // that Hive versions prior to 2.3.0 observe the value they expect (HIVE-16102).
    HiveParserASTNode legacyCall =
            (HiveParserASTNode)
                    HiveASTParseDriver.ADAPTOR.create(
                            HiveASTParser.TOK_FUNCTION, "TOK_FUNCTION");
    Tree functionName =
            (Tree)
                    HiveASTParseDriver.ADAPTOR.create(
                            HiveASTParser.StringLiteral, GenericUDFLegacyGroupingID.NAME);
    Tree exprCountLiteral =
            (Tree)
                    HiveASTParseDriver.ADAPTOR.create(
                            HiveASTParser.IntegralLiteral, String.valueOf(numGBExprs));
    // Children, in order: the UDF name, the original grouping__id node, and the
    // number of GROUP BY expressions the UDF needs for the conversion.
    legacyCall.addChild(functionName);
    legacyCall.addChild(groupingId);
    legacyCall.addChild(exprCountLiteral);
    return legacyCall;
}

private static void convertGrouping(
        HiveParserASTNode function,
        List<HiveParserASTNode> grpByAstExprs,
        boolean noneSet,
        boolean legacyGrouping,
        MutableBoolean found) {
    // A grouping(col) call keeps the column as its second child. Find that column
    // among the GROUP BY expressions and rewrite the call in place into
    // grouping(<grouping-id-source>, <bit-index>).
    HiveParserASTNode targetCol = (HiveParserASTNode) function.getChild(1);
    String targetTree = targetCol.toStringTree();
    int numExprs = grpByAstExprs.size();
    for (int idx = 0; idx < numExprs; idx++) {
        if (!grpByAstExprs.get(idx).toStringTree().equals(targetTree)) {
            continue;
        }
        HiveParserASTNode groupingSource;
        if (noneSet) {
            // No CUBE / ROLLUP / GROUPING SETS in the query, so grouping() is the
            // constant 0.
            groupingSource =
                    (HiveParserASTNode)
                            HiveASTParseDriver.ADAPTOR.create(
                                    HiveASTParser.IntegralLiteral, String.valueOf(0));
        } else {
            // Refer to the virtual grouping__id column.
            groupingSource =
                    (HiveParserASTNode)
                            HiveASTParseDriver.ADAPTOR.create(
                                    HiveASTParser.TOK_TABLE_OR_COL, "TOK_TABLE_OR_COL");
            HiveASTParseDriver.ADAPTOR.addChild(
                    groupingSource,
                    HiveASTParseDriver.ADAPTOR.create(
                            HiveASTParser.Identifier, VirtualColumn.GROUPINGID.getName()));
            if (legacyGrouping) {
                // Old Hive versions use the legacy GROUPING__ID encoding.
                groupingSource = convertToLegacyGroupingId(groupingSource, numExprs);
            }
        }
        // Bit position differs between encodings: legacy counts from the LSB in
        // GROUP BY order, the new encoding from the other end.
        HiveParserASTNode bitIndex =
                (HiveParserASTNode)
                        HiveASTParseDriver.ADAPTOR.create(
                                HiveASTParser.IntegralLiteral,
                                String.valueOf(
                                        nonNegativeMod(
                                                legacyGrouping ? idx : -idx - 1, numExprs)));
        function.setChild(1, groupingSource);
        function.addChild(bitIndex);
        found.setValue(true);
        return;
    }
}

/**
 * Whether the configured Hive version uses the legacy GROUPING__ID semantics, i.e. is
 * older than 2.3.0 (the semantics changed in HIVE-16102).
 *
 * @param conf configuration holding the {@code hive-version} option
 * @return true if a Hive version is configured and it is older than 2.3.0
 */
public static boolean legacyGrouping(Configuration conf) {
    String hiveVersion = conf.get(HiveCatalogFactoryOptions.HIVE_VERSION.key());
    // Compare numerically, segment by segment. A plain String.compareTo would order
    // dotted versions lexicographically and misclassify e.g. "2.10.0" as < "2.3.0".
    return hiveVersion != null && compareVersions(hiveVersion, "2.3.0") < 0;
}

/** Compares two dotted version strings segment by segment, numerically. */
private static int compareVersions(String v1, String v2) {
    String[] left = v1.split("\\.");
    String[] right = v2.split("\\.");
    int len = Math.max(left.length, right.length);
    for (int i = 0; i < len; i++) {
        // Missing segments compare as 0, so "2.3" == "2.3.0".
        long a = i < left.length ? numericPrefix(left[i]) : 0L;
        long b = i < right.length ? numericPrefix(right[i]) : 0L;
        if (a != b) {
            return Long.compare(a, b);
        }
    }
    return 0;
}

/** Parses the leading digits of a version segment; non-numeric segments count as 0. */
private static long numericPrefix(String segment) {
    int end = 0;
    while (end < segment.length() && Character.isDigit(segment.charAt(end))) {
        end++;
    }
    return end == 0 ? 0L : Long.parseLong(segment.substring(0, end));
}

private static boolean legacyGrouping() {
    // Delegate to legacyGrouping(Configuration) using the current session's Hive conf.
    Configuration sessionConf = SessionState.get().getConf();
    return legacyGrouping(sessionConf);
}

/**
 * Returns {@code x mod m} mapped into {@code [0, m)}, i.e. a non-negative remainder even
 * when {@code x} is negative.
 *
 * @param x dividend, may be negative
 * @param m modulus, must be strictly positive
 * @return the value in {@code [0, m)} congruent to {@code x} modulo {@code m}
 * @throws ArithmeticException if {@code m <= 0}
 */
static int nonNegativeMod(int x, int m) {
    if (m <= 0) {
        throw new ArithmeticException("Modulus " + m + " must be > 0");
    }
    // Math.floorMod yields a result with the divisor's sign; m > 0 here, so the result
    // is guaranteed non-negative — identical to the previous manual "% then + m" fixup.
    return Math.floorMod(x, m);
}

public static SqlOperator getAnySqlOperator(String funcName, SqlOperatorTable opTable) {
SqlOperator sqlOperator =
getSqlOperator(funcName, opTable, SqlFunctionCategory.USER_DEFINED_FUNCTION);
Expand Down
Loading

0 comments on commit 6d8c02f

Please sign in to comment.