Skip to content

Commit

Permalink
relocate method in BufferAggregator. (apache#4071)
Browse files Browse the repository at this point in the history
*  relocate method in BufferAggregator.

* Unused import.

* Detailed javadoc.

* using Int2ObjectMap.

* batch relocate.

* Revert batch relocate.

* Unused import.

* code comments.

* code comment.
  • Loading branch information
akashdw authored and gianm committed Mar 23, 2017
1 parent f68ba41 commit ff7f90b
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,19 @@
import io.druid.query.aggregation.BufferAggregator;
import io.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import io.druid.segment.ObjectColumnSelector;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;

import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.IdentityHashMap;

public class SketchBufferAggregator implements BufferAggregator
{
private final ObjectColumnSelector selector;
private final int size;
private final int maxIntermediateSize;

private NativeMemory nm;

private final Map<Integer, Union> unions = new HashMap<>(); //position in BB -> Union Object
private final IdentityHashMap<ByteBuffer, Int2ObjectMap<Union>> unions = new IdentityHashMap<>();
private final IdentityHashMap<ByteBuffer, NativeMemory> nmCache = new IdentityHashMap<>();

public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIntermediateSize)
{
Expand All @@ -53,12 +52,7 @@ public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIn
@Override
public void init(ByteBuffer buf, int position)
{
if (nm == null) {
nm = new NativeMemory(buf);
}

Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
unions.put(position, (Union) SetOperation.builder().initMemory(mem).build(size, Family.UNION));
createNewUnion(buf, position, false);
}

@Override
Expand Down Expand Up @@ -87,12 +81,27 @@ public Object get(ByteBuffer buf, int position)
//Note that this is not threadsafe and I don't think it needs to be
private Union getUnion(ByteBuffer buf, int position)
{
Union union = unions.get(position);
if (union == null) {
Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
union = (Union) SetOperation.wrap(mem);
unions.put(position, union);
Int2ObjectMap<Union> unionMap = unions.get(buf);
Union union = unionMap != null ? unionMap.get(position) : null;
if (union != null) {
return union;
}
return createNewUnion(buf, position, true);
}

/**
 * Creates a Union over the {@code maxIntermediateSize}-byte region of {@code buf} that starts at
 * {@code position}, caches it under (buf, position) and returns it.
 *
 * @param buf       aggregation buffer the Union's memory region lives in
 * @param position  offset of the region inside {@code buf}
 * @param isWrapped when true the Union is obtained through {@code SetOperation.wrap} on the region;
 *                  otherwise a new Union is built through the builder with the region passed to
 *                  {@code initMemory}
 * @return the Union that was just created and cached
 */
private Union createNewUnion(ByteBuffer buf, int position, boolean isWrapped)
{
  final Memory region = new MemoryRegion(getNativeMemory(buf), position, maxIntermediateSize);

  final Union created;
  if (isWrapped) {
    created = (Union) SetOperation.wrap(region);
  } else {
    created = (Union) SetOperation.builder().initMemory(region).build(size, Family.UNION);
  }

  // The per-buffer map is only registered once the Union itself exists.
  Int2ObjectMap<Union> unionsByPosition = unions.get(buf);
  if (unionsByPosition == null) {
    unionsByPosition = new Int2ObjectOpenHashMap<>();
    unions.put(buf, unionsByPosition);
  }
  unionsByPosition.put(position, created);
  return created;
}

Expand All @@ -119,4 +128,29 @@ public void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
inspector.visit("selector", selector);
}

/**
 * Moves the cached Union for (oldBuffer, oldPosition) to (newBuffer, newPosition).
 *
 * A Union is first created over the new location via {@code createNewUnion(..., true)}; the cache entry
 * for the old location is then dropped. Once no cached Unions remain for {@code oldBuffer}, its
 * per-buffer map and its NativeMemory cache entry are removed as well.
 */
@Override
public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
  createNewUnion(newBuffer, newPosition, true);

  final Int2ObjectMap<Union> staleUnions = unions.get(oldBuffer);
  if (staleUnions == null) {
    // Nothing was ever cached for the old buffer.
    return;
  }

  staleUnions.remove(oldPosition);
  if (!staleUnions.isEmpty()) {
    return;
  }

  // Last cached Union of the old buffer is gone: forget the buffer entirely.
  unions.remove(oldBuffer);
  nmCache.remove(oldBuffer);
}

/**
 * Returns the NativeMemory cached for {@code buffer}, creating and caching one on first use.
 *
 * @param buffer buffer to look up in {@code nmCache}
 * @return the cached or newly created NativeMemory for {@code buffer}
 */
private NativeMemory getNativeMemory(ByteBuffer buffer)
{
  final NativeMemory cached = nmCache.get(buffer);
  if (cached != null) {
    return cached;
  }

  final NativeMemory fresh = new NativeMemory(buffer);
  nmCache.put(buffer, fresh);
  return fresh;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,16 @@ public static SketchHolder sketchSetOperation(Func func, int sketchSize, Object.
throw new IllegalArgumentException("Unknown sketch operation " + func);
}
}

/**
 * Two holders are equal when they are the same instance, or when they have exactly the same class and
 * their {@code getSketch()} results are equal according to that sketch's own {@code equals}.
 *
 * NOTE(review): no matching {@code hashCode()} is visible in this excerpt - confirm one exists.
 */
@Override
public boolean equals(Object o)
{
  if (o == this) {
    return true;
  }
  if (o != null && o.getClass() == getClass()) {
    final SketchHolder other = (SketchHolder) o;
    return this.getSketch().equals(other.getSketch());
  }
  return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.druid.query.aggregation.datasketches.theta;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.groupby.epinephelinae.BufferGrouper;
import io.druid.query.groupby.epinephelinae.Grouper;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.junit.Assert;
import org.junit.Test;

import java.nio.ByteBuffer;

/**
 * Exercises a {@link BufferGrouper} configured with a {@link SketchMergeAggregatorFactory} and a
 * {@link CountAggregatorFactory}, starting from a small number of initial buckets.
 */
public class BufferGrouperUsingSketchMergeAggregatorFactoryTest
{
  /**
   * Builds and initializes a grouper over a freshly allocated heap buffer of {@code bufferSize} bytes.
   */
  private static BufferGrouper<Integer> makeGrouper(
      TestColumnSelectorFactory columnSelectorFactory,
      int bufferSize,
      int initialBuckets
  )
  {
    final AggregatorFactory[] aggregatorFactories = {
        new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2),
        new CountAggregatorFactory("count")
    };

    final BufferGrouper<Integer> created = new BufferGrouper<>(
        Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)),
        GrouperTestUtil.intKeySerde(),
        columnSelectorFactory,
        aggregatorFactories,
        Integer.MAX_VALUE,
        0.75f,
        initialBuckets
    );
    created.init();
    return created;
  }

  /**
   * Wraps the holder in a row whose only column is "sketch".
   */
  private static MapBasedRow rowWithSketch(SketchHolder holder)
  {
    return new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", holder));
  }

  /**
   * Aggregates the current row under keys 0 (inclusive) to {@code keyCount} (exclusive), asserting that
   * every aggregate call succeeds.
   */
  private static void aggregateKeys(Grouper<Integer> grouper, int keyCount)
  {
    for (int key = 0; key < keyCount; key++) {
      Assert.assertTrue(String.valueOf(key), grouper.aggregate(key));
    }
  }

  /**
   * Aggregates the same sketch under five keys twice (adding a second value to the sketch in between)
   * into a grouper created with two initial buckets, then checks that the first result row's sketch
   * estimate is 2. Presumably the five keys force the table to grow past its initial buckets - confirm
   * against BufferGrouper.
   */
  @Test
  public void testGrowingBufferGrouper()
  {
    final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
    final Grouper<Integer> grouper = makeGrouper(columnSelectorFactory, 100000, 2);
    try {
      final int keyCount = 5;

      final SketchHolder holder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
      final UpdateSketch sketch = (UpdateSketch) holder.getSketch();

      // First pass: the sketch contains a single value.
      sketch.update(1);
      columnSelectorFactory.setRow(rowWithSketch(holder));
      aggregateKeys(grouper, keyCount);

      // Second pass: the same sketch now contains a second value.
      sketch.update(3);
      columnSelectorFactory.setRow(rowWithSketch(holder));
      aggregateKeys(grouper, keyCount);

      final Object[] firstRowValues = Lists.newArrayList(grouper.iterator(true)).get(0).getValues();
      final SketchHolder aggregated = (SketchHolder) firstRowValues[0];

      Assert.assertEquals(2.0d, aggregated.getEstimate(), 0);
    }
    finally {
      grouper.close();
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.java.util.common.granularity.Granularities;
Expand All @@ -39,6 +40,8 @@
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
Expand Down Expand Up @@ -389,6 +392,23 @@ public void testSketchAggregatorFactoryComparator()
Assert.assertEquals(1, comparator.compare(SketchHolder.of(union2), SketchHolder.of(sketch1)));
}

/**
 * Feeds a single-value sketch through the helper's relocate verification for a
 * {@link SketchMergeAggregatorFactory} and checks that the two returned holders report the same estimate.
 */
@Test
public void testRelocation()
{
  final TestColumnSelectorFactory selectorFactory = GrouperTestUtil.newColumnSelectorFactory();

  final SketchHolder input = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
  ((UpdateSketch) input.getSketch()).update(1);
  selectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", input)));

  final SketchMergeAggregatorFactory factory =
      new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2);
  final SketchHolder[] results = helper.runRelocateVerificationTest(
      factory,
      selectorFactory,
      SketchHolder.class
  );

  final SketchHolder first = results[0];
  final SketchHolder second = results[1];
  Assert.assertEquals(first.getEstimate(), second.getEstimate(), 0);
}

private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences;
import io.druid.query.aggregation.AggregationTestHelper;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.datasketches.theta.SketchHolder;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
Expand Down Expand Up @@ -194,6 +199,23 @@ public void testSketchSetPostAggregatorSerde() throws Exception
);
}

/**
 * Feeds a single-value sketch through the helper's relocate verification for an
 * {@link OldSketchMergeAggregatorFactory} and checks that the two returned holders report the same
 * estimate.
 */
@Test
public void testRelocation()
{
  final TestColumnSelectorFactory selectorFactory = GrouperTestUtil.newColumnSelectorFactory();

  final SketchHolder input = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
  ((UpdateSketch) input.getSketch()).update(1);
  selectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", input)));

  final OldSketchMergeAggregatorFactory factory =
      new OldSketchMergeAggregatorFactory("sketch", "sketch", 16, false);
  final SketchHolder[] results = helper.runRelocateVerificationTest(
      factory,
      selectorFactory,
      SketchHolder.class
  );

  final SketchHolder first = results[0];
  final SketchHolder second = results[1];
  Assert.assertEquals(first.getEstimate(), second.getEstimate(), 0);
}

private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,28 @@ public interface BufferAggregator extends HotLoopCallee
default void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
  // Intentionally empty: the default implementation passes nothing to the inspector.
}

/**
 * Relocates any cached objects.
 * If the underlying ByteBuffer used as the aggregation buffer is relocated to a new ByteBuffer, positional caches
 * (if any) built on top of the old ByteBuffer cannot be used for further
 * {@link BufferAggregator#aggregate(ByteBuffer, int)} calls. This method tells the BufferAggregator that the cached
 * object at a certain location has been relocated to a different location.
 *
 * Only used if there are any positional caches/objects in the BufferAggregator implementation.
 *
 * If the relocation spans multiple new ByteBuffers (say n ByteBuffers), this method should be called multiple
 * times (n times), given that all the new positions/old positions exist in newBuffer/oldBuffer.
 *
 * <b>Implementations must not change the position, limit or mark of the given buffer</b>
 *
 * @param oldPosition old position of a cached object before the aggregation buffer relocates to a new ByteBuffer.
 * @param newPosition new position of a cached object after the aggregation buffer relocates to a new ByteBuffer.
 * @param oldBuffer old aggregation buffer.
 * @param newBuffer new aggregation buffer.
 */
default void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
  // No-op by default.
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,9 @@ private void growIfPossible()

for (int oldBucket = 0; oldBucket < buckets; oldBucket++) {
if (isUsed(oldBucket)) {
int oldPosition = oldBucket * bucketSize;
entryBuffer.limit((oldBucket + 1) * bucketSize);
entryBuffer.position(oldBucket * bucketSize);
entryBuffer.position(oldPosition);
keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize);
keyBuffer.position(entryBuffer.position() + HASH_SIZE);

Expand All @@ -442,9 +443,19 @@ private void growIfPossible()
throw new ISE("WTF?! Couldn't find a bucket while resizing?!");
}

newTableBuffer.position(newBucket * bucketSize);
int newPosition = newBucket * bucketSize;
newTableBuffer.position(newPosition);
newTableBuffer.put(entryBuffer);

for (int i = 0; i < aggregators.length; i++) {
aggregators[i].relocate(
oldPosition + aggregatorOffsets[i],
newPosition + aggregatorOffsets[i],
tableBuffer,
newTableBuffer
);
}

buffer.putInt(tableArenaSize + newSize * Ints.BYTES, newBucket * bucketSize);
newSize++;
}
Expand Down
Loading

0 comments on commit ff7f90b

Please sign in to comment.