Skip to content

Commit

Permalink
Reduce number of equality checks in join hash
Browse files Browse the repository at this point in the history
Save one byte per entry of the raw hash to reduce the total number of equality
checks in the join hash.
  • Loading branch information
pnowojski authored and dain committed Jun 20, 2016
1 parent ae50750 commit 0255e72
Showing 1 changed file with 64 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.facebook.presto.spi.block.Block;
import com.google.common.primitives.Ints;
import io.airlift.slice.XxHash64;
import io.airlift.units.DataSize;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.longs.LongArrayList;

Expand All @@ -26,14 +27,14 @@
import static com.facebook.presto.operator.SyntheticAddress.decodePosition;
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.slice.SizeOf.sizeOfBooleanArray;
import static io.airlift.slice.SizeOf.sizeOfIntArray;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static java.util.Objects.requireNonNull;

// This implementation assumes arrays used in the hash are always a power of 2
public final class InMemoryJoinHash
implements LookupSource
{
private static final DataSize CACHE_SIZE = new DataSize(128, KILOBYTE);
private final LongArrayList addresses;
private final PagesHashStrategy pagesHashStrategy;

Expand All @@ -44,6 +45,11 @@ public final class InMemoryJoinHash
private final long size;
private final boolean filterFunctionPresent;

// Native array of hashes for faster collisions resolution compared
// to accessing values in blocks. We use bytes to reduce memory foot print
// and there is no performance gain from storing full hashes
private final byte[] positionToHashes;

public InMemoryJoinHash(LongArrayList addresses, PagesHashStrategy pagesHashStrategy)
{
this.addresses = requireNonNull(addresses, "addresses is null");
Expand All @@ -53,8 +59,6 @@ public InMemoryJoinHash(LongArrayList addresses, PagesHashStrategy pagesHashStra

// reserve memory for the arrays
int hashSize = HashCommon.arraySize(addresses.size(), 0.75f);
size = sizeOfIntArray(hashSize) + sizeOfBooleanArray(hashSize) + sizeOfIntArray(addresses.size())
+ sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes();

mask = hashSize - 1;
key = new int[hashSize];
Expand All @@ -63,31 +67,58 @@ public InMemoryJoinHash(LongArrayList addresses, PagesHashStrategy pagesHashStra
this.positionLinks = new int[addresses.size()];
Arrays.fill(positionLinks, -1);

// index pages
for (int position = 0; position < addresses.size(); position++) {
if (isPositionNull(position)) {
continue;
positionToHashes = new byte[addresses.size()];

// We will process addresses in batches, to save memory on array of hashes.
int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE);
long[] positionToFullHashes = new long[positionsInStep];

for (int step = 0; step * positionsInStep <= addresses.size(); step++) {
int stepBeginPosition = step * positionsInStep;
int stepEndPosition = Math.min((step + 1) * positionsInStep, addresses.size());
int stepSize = stepEndPosition - stepBeginPosition;

// First extract all hashes from blocks to native array.
// Somehow having this as a separate loop is much faster compared
// to extracting hashes on the fly in the loop below.
for (int position = 0; position < stepSize; position++) {
int realPosition = position + stepBeginPosition;
long hash = readHashPosition(realPosition);
positionToFullHashes[position] = hash;
positionToHashes[realPosition] = (byte) hash;
}

int pos = (int) getHashPosition(hashPosition(position), mask);

// look for an empty slot or a slot containing this key
while (key[pos] != -1) {
int currentKey = key[pos];
if (positionEqualsPositionIgnoreNulls(currentKey, position)) {
// found a slot for this key
// link the new key position to the current key position
positionLinks[position] = currentKey;
// index pages
for (int position = 0; position < stepSize; position++) {
int realPosition = position + stepBeginPosition;
if (isPositionNull(realPosition)) {
continue;
}

// key[pos] updated outside of this loop
break;
long hash = positionToFullHashes[position];
int pos = getHashPosition(hash, mask);

// look for an empty slot or a slot containing this key
while (key[pos] != -1) {
int currentKey = key[pos];
if (((byte) hash) == positionToHashes[currentKey] && positionEqualsPositionIgnoreNulls(currentKey, realPosition)) {
// found a slot for this key
// link the new key position to the current key position
positionLinks[realPosition] = currentKey;

// key[pos] updated outside of this loop
break;
}
// increment position and mask to handler wrap around
pos = (pos + 1) & mask;
}
// increment position and mask to handler wrap around
pos = (pos + 1) & mask;
}

key[pos] = position;
key[pos] = realPosition;
}
}

size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() +
sizeOf(key) + sizeOf(positionLinks) + sizeOf(positionToHashes);
}

@Override
Expand Down Expand Up @@ -117,10 +148,10 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel
@Override
public long getJoinPosition(int rightPosition, Page hashChannelsPage, Page allChannelsPage, long rawHash)
{
int pos = (int) getHashPosition(rawHash, mask);
int pos = getHashPosition(rawHash, mask);

while (key[pos] != -1) {
if (positionEqualsCurrentRowIgnoreNulls(key[pos], rightPosition, hashChannelsPage)) {
if (positionEqualsCurrentRowIgnoreNulls(key[pos], (byte) rawHash, rightPosition, hashChannelsPage)) {
return getNextJoinPositionFrom(key[pos], rightPosition, allChannelsPage);
}
// increment position and mask to handler wrap around
Expand Down Expand Up @@ -168,7 +199,7 @@ private boolean isPositionNull(int position)
return pagesHashStrategy.isPositionNull(blockIndex, blockPosition);
}

private long hashPosition(int position)
private long readHashPosition(int position)
{
long pageAddress = addresses.getLong(position);
int blockIndex = decodeSliceIndex(pageAddress);
Expand All @@ -177,8 +208,12 @@ private long hashPosition(int position)
return pagesHashStrategy.hashPosition(blockIndex, blockPosition);
}

private boolean positionEqualsCurrentRowIgnoreNulls(int leftPosition, int rightPosition, Page rightPage)
private boolean positionEqualsCurrentRowIgnoreNulls(int leftPosition, byte rawHash, int rightPosition, Page rightPage)
{
if (positionToHashes[leftPosition] != rawHash) {
return false;
}

long pageAddress = addresses.getLong(leftPosition);
int blockIndex = decodeSliceIndex(pageAddress);
int blockPosition = decodePosition(pageAddress);
Expand Down Expand Up @@ -212,8 +247,8 @@ private boolean positionEqualsPositionIgnoreNulls(int leftPosition, int rightPos
return pagesHashStrategy.positionEqualsPositionIgnoreNulls(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition);
}

private static long getHashPosition(long rawHash, long mask)
private static int getHashPosition(long rawHash, long mask)
{
return (XxHash64.hash(rawHash)) & mask;
return (int) ((XxHash64.hash(rawHash)) & mask);
}
}

0 comments on commit 0255e72

Please sign in to comment.