Skip to content

Commit

Permalink
Pass Session to compiled JoinFilterFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk authored and dain committed Sep 16, 2016
1 parent 5dc213b commit cdb67cf
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.JoinFilterFunction;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.relational.CallExpression;
Expand All @@ -42,9 +44,9 @@

import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.CompilerUtils.defineClass;
Expand Down Expand Up @@ -79,12 +81,12 @@ public Class<? extends JoinFilterFunction> load(JoinFilterCacheKey key)
}
});

public Supplier<JoinFilterFunction> compileJoinFilterFunction(RowExpression filter, int leftBlocksSize)
public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize)
{
Class<? extends JoinFilterFunction> joinFilterFunction = joinFilterFunctions.getUnchecked(new JoinFilterCacheKey(filter, leftBlocksSize));
return () -> {
return (session) -> {
try {
return joinFilterFunction.newInstance();
return joinFilterFunction.getConstructor(ConnectorSession.class).newInstance(session);
}
catch (ReflectiveOperationException e) {
throw Throwables.propagate(e);
Expand Down Expand Up @@ -121,11 +123,29 @@ private Class<? extends JoinFilterFunction> compileFilterFunctionInternal(RowExp
private void generateMethods(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, int leftBlocksSize)
{
CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);
generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter, leftBlocksSize);
classDefinition.declareDefaultConstructor(a(PUBLIC));

FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class);
generateConstructor(classDefinition, sessionField);
generateFilterMethod(classDefinition, callSiteBinder, cachedInstanceBinder, filter, leftBlocksSize, sessionField);
}

private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter, int leftBlocksSize)
private void generateConstructor(ClassDefinition classDefinition, FieldDefinition sessionField)
{
Parameter sessionParameter = arg("session", ConnectorSession.class);
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), sessionParameter);

BytecodeBlock body = constructorDefinition.getBody();
Variable thisVariable = constructorDefinition.getThis();

body.comment("super();")
.append(thisVariable)
.invokeConstructor(Object.class);

body.append(thisVariable.setField(sessionField, sessionParameter));
body.ret();
}

private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpression filter, int leftBlocksSize, FieldDefinition sessionField)
{
// todo handle TRY expression
Map<CallExpression, MethodDefinition> tryMethodMap = ImmutableMap.of();
Expand All @@ -152,6 +172,7 @@ private void generateFilterMethod(ClassDefinition classDefinition, CallSiteBinde

Scope scope = method.getScope();
Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
scope.declareVariable("session", body, method.getThis().getField(sessionField));

BytecodeExpressionVisitor visitor = new BytecodeExpressionVisitor(
callSiteBinder,
Expand Down Expand Up @@ -180,6 +201,11 @@ private static void generateToString(ClassDefinition classDefinition, CallSiteBi
.retObject();
}

@FunctionalInterface
public interface JoinFilterFunctionFactory {
JoinFilterFunction create(ConnectorSession session);
}

private static RowExpressionVisitor<Scope, BytecodeNode> fieldReferenceCompiler(
final CallSiteBinder callSiteBinder,
final Variable leftPosition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1571,7 +1571,7 @@ private JoinFilterFunction compileJoinFilterFunction(
emptyList() /* parameters have already been replaced */);

RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes);
return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()).get();
return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()).create(session.toConnectorSession());
}

private OperatorFactory createLookupJoin(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,14 @@ public void testNonEqualityLeftJoin()
"VALUES (1, 1, 2, 2)");
}

@Test
public void testNonEqalityJoinWithScalarRequiringSessionParameter()
throws Exception
{
assertQuery("SELECT * FROM (VALUES (1,1), (1,2)) t1(a,b) LEFT OUTER JOIN (VALUES (1,1), (1,2)) t2(c,d) ON a=c AND from_unixtime(b) > current_timestamp",
"VALUES (1, 1, NULL, NULL), (1, 2, NULL, NULL)");
}

@Test
public void testLeftJoinWithEmptyInnerTable()
throws Exception
Expand Down

0 comments on commit cdb67cf

Please sign in to comment.