diff --git a/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchBufferAggregator.java b/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchBufferAggregator.java index c36f908d339a..4d02bea6fa60 100644 --- a/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchBufferAggregator.java +++ b/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchBufferAggregator.java @@ -28,20 +28,19 @@ import io.druid.query.aggregation.BufferAggregator; import io.druid.query.monomorphicprocessing.RuntimeShapeInspector; import io.druid.segment.ObjectColumnSelector; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.Map; +import java.util.IdentityHashMap; public class SketchBufferAggregator implements BufferAggregator { private final ObjectColumnSelector selector; private final int size; private final int maxIntermediateSize; - - private NativeMemory nm; - - private final Map unions = new HashMap<>(); //position in BB -> Union Object + private final IdentityHashMap> unions = new IdentityHashMap<>(); + private final IdentityHashMap nmCache = new IdentityHashMap<>(); public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIntermediateSize) { @@ -53,12 +52,7 @@ public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIn @Override public void init(ByteBuffer buf, int position) { - if (nm == null) { - nm = new NativeMemory(buf); - } - - Memory mem = new MemoryRegion(nm, position, maxIntermediateSize); - unions.put(position, (Union) SetOperation.builder().initMemory(mem).build(size, Family.UNION)); + createNewUnion(buf, position, false); } @Override @@ -87,12 +81,27 @@ public Object get(ByteBuffer buf, int position) //Note that this is not threadsafe and I don't think it needs to be private Union getUnion(ByteBuffer buf, int position) { - Union union = unions.get(position); - if (union == null) { - Memory mem = new MemoryRegion(nm, position, maxIntermediateSize); - union = (Union) SetOperation.wrap(mem); - unions.put(position, union); + Int2ObjectMap unionMap = unions.get(buf); + Union union = unionMap != null ? unionMap.get(position) : null; + if (union != null) { + return union; + } + return createNewUnion(buf, position, true); + } + + private Union createNewUnion(ByteBuffer buf, int position, boolean isWrapped) + { + NativeMemory nm = getNativeMemory(buf); + Memory mem = new MemoryRegion(nm, position, maxIntermediateSize); + Union union = isWrapped + ? (Union) SetOperation.wrap(mem) + : (Union) SetOperation.builder().initMemory(mem).build(size, Family.UNION); + Int2ObjectMap unionMap = unions.get(buf); + if (unionMap == null) { + unionMap = new Int2ObjectOpenHashMap<>(); + unions.put(buf, unionMap); } + unionMap.put(position, union); return union; } @@ -119,4 +128,29 @@ public void inspectRuntimeShape(RuntimeShapeInspector inspector) { inspector.visit("selector", selector); } + + @Override + public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer) + { + createNewUnion(newBuffer, newPosition, true); + Int2ObjectMap unionMap = unions.get(oldBuffer); + if (unionMap != null) { + unionMap.remove(oldPosition); + if (unionMap.isEmpty()) { + unions.remove(oldBuffer); + nmCache.remove(oldBuffer); + } + } + } + + private NativeMemory getNativeMemory(ByteBuffer buffer) + { + NativeMemory nm = nmCache.get(buffer); + if (nm == null) { + nm = new NativeMemory(buffer); + nmCache.put(buffer, nm); + } + return nm; + } + } diff --git a/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchHolder.java b/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchHolder.java index 54fa865d7970..b888335764fd 100644 --- a/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchHolder.java +++ b/extensions-core/datasketches/src/main/java/io/druid/query/aggregation/datasketches/theta/SketchHolder.java @@ -290,4 +290,16 @@ public static SketchHolder sketchSetOperation(Func func, int sketchSize, Object. throw new IllegalArgumentException("Unknown sketch operation " + func); } } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return this.getSketch().equals(((SketchHolder) o).getSketch()); + } } diff --git a/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/BufferGrouperUsingSketchMergeAggregatorFactoryTest.java b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/BufferGrouperUsingSketchMergeAggregatorFactoryTest.java new file mode 100644 index 000000000000..2d8f6e6ae7af --- /dev/null +++ b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/BufferGrouperUsingSketchMergeAggregatorFactoryTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.aggregation.datasketches.theta; + +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.yahoo.sketches.theta.Sketches; +import com.yahoo.sketches.theta.UpdateSketch; +import io.druid.data.input.MapBasedRow; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; +import io.druid.query.groupby.epinephelinae.BufferGrouper; +import io.druid.query.groupby.epinephelinae.Grouper; +import io.druid.query.groupby.epinephelinae.GrouperTestUtil; +import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class BufferGrouperUsingSketchMergeAggregatorFactoryTest +{ + private static BufferGrouper makeGrouper( + TestColumnSelectorFactory columnSelectorFactory, + int bufferSize, + int initialBuckets + ) + { + final BufferGrouper grouper = new BufferGrouper<>( + Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)), + GrouperTestUtil.intKeySerde(), + columnSelectorFactory, + new AggregatorFactory[]{ + new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2), + new CountAggregatorFactory("count") + }, + Integer.MAX_VALUE, + 0.75f, + initialBuckets + ); + grouper.init(); + return grouper; + } + + @Test + public void testGrowingBufferGrouper() + { + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final Grouper grouper = makeGrouper(columnSelectorFactory, 100000, 2); + try { + final int expectedMaxSize = 5; + + SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16)); + UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch(); + updateSketch.update(1); + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("sketch", sketchHolder))); + + for (int i = 0; i < expectedMaxSize; i++) { + Assert.assertTrue(String.valueOf(i), grouper.aggregate(i)); + } + + updateSketch.update(3); + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("sketch", sketchHolder))); + + for (int i = 0; i < expectedMaxSize; i++) { + Assert.assertTrue(String.valueOf(i), grouper.aggregate(i)); + } + + Object[] holders = Lists.newArrayList(grouper.iterator(true)).get(0).getValues(); + + Assert.assertEquals(2.0d, ((SketchHolder) holders[0]).getEstimate(), 0); + } + finally { + grouper.close(); + } + } + +} diff --git a/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/SketchAggregationTest.java b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/SketchAggregationTest.java index e6167c151497..3dd699bec0f0 100644 --- a/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/SketchAggregationTest.java +++ b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/SketchAggregationTest.java @@ -28,6 +28,7 @@ import com.yahoo.sketches.theta.Sketch; import com.yahoo.sketches.theta.Sketches; import com.yahoo.sketches.theta.Union; +import com.yahoo.sketches.theta.UpdateSketch; import io.druid.data.input.MapBasedRow; import io.druid.data.input.Row; import io.druid.java.util.common.granularity.Granularities; @@ -39,6 +40,8 @@ import io.druid.query.aggregation.post.FieldAccessPostAggregator; import io.druid.query.groupby.GroupByQueryConfig; import io.druid.query.groupby.GroupByQueryRunnerTest; +import io.druid.query.groupby.epinephelinae.GrouperTestUtil; +import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory; import org.joda.time.DateTime; import org.junit.Assert; import org.junit.Rule; @@ -389,6 +392,23 @@ public void testSketchAggregatorFactoryComparator() Assert.assertEquals(1, comparator.compare(SketchHolder.of(union2), SketchHolder.of(sketch1))); } + @Test + public void testRelocation() + { + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16)); + UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch(); + updateSketch.update(1); + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("sketch", sketchHolder))); + SketchHolder[] holders = helper.runRelocateVerificationTest( + new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2), + columnSelectorFactory, + SketchHolder.class + ); + Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0); + } + private void assertPostAggregatorSerde(PostAggregator agg) throws Exception { Assert.assertEquals( diff --git a/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/oldapi/OldApiSketchAggregationTest.java b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/oldapi/OldApiSketchAggregationTest.java index 7821876e0298..39f2812ab5c3 100644 --- a/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/oldapi/OldApiSketchAggregationTest.java +++ b/extensions-core/datasketches/src/test/java/io/druid/query/aggregation/datasketches/theta/oldapi/OldApiSketchAggregationTest.java @@ -22,6 +22,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.io.Files; +import com.yahoo.sketches.theta.Sketches; +import com.yahoo.sketches.theta.UpdateSketch; import io.druid.data.input.MapBasedRow; import io.druid.java.util.common.granularity.Granularities; import io.druid.java.util.common.guava.Sequence; @@ -29,9 +31,12 @@ import io.druid.query.aggregation.AggregationTestHelper; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.PostAggregator; +import io.druid.query.aggregation.datasketches.theta.SketchHolder; import io.druid.query.aggregation.post.FieldAccessPostAggregator; import io.druid.query.groupby.GroupByQueryConfig; import io.druid.query.groupby.GroupByQueryRunnerTest; +import io.druid.query.groupby.epinephelinae.GrouperTestUtil; +import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory; import org.joda.time.DateTime; import org.junit.Assert; import org.junit.Rule; @@ -194,6 +199,23 @@ public void testSketchSetPostAggregatorSerde() throws Exception ); } + @Test + public void testRelocation() + { + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16)); + UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch(); + updateSketch.update(1); + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("sketch", sketchHolder))); + SketchHolder[] holders = helper.runRelocateVerificationTest( + new OldSketchMergeAggregatorFactory("sketch", "sketch", 16, false), + columnSelectorFactory, + SketchHolder.class + ); + Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0); + } + private void assertPostAggregatorSerde(PostAggregator agg) throws Exception { Assert.assertEquals( diff --git a/processing/src/main/java/io/druid/query/aggregation/BufferAggregator.java b/processing/src/main/java/io/druid/query/aggregation/BufferAggregator.java index 3b565dcce2d2..1951a6df4ae8 100644 --- a/processing/src/main/java/io/druid/query/aggregation/BufferAggregator.java +++ b/processing/src/main/java/io/druid/query/aggregation/BufferAggregator.java @@ -126,4 +126,28 @@ public interface BufferAggregator extends HotLoopCallee default void inspectRuntimeShape(RuntimeShapeInspector inspector) { } + + /* + * Relocates any cached objects. + * If underlying ByteBuffer used for aggregation buffer relocates to a new ByteBuffer, positional caches(if any) + * built on top of old ByteBuffer can not be used for further {@link BufferAggregator#aggregate(ByteBuffer, int)} + * calls. This method tells the BufferAggregator that the cached objects at a certain location has been relocated to + * a different location. + * + * Only used if there is any positional caches/objects in the BufferAggregator implementation. + * + * If relocate happens to be across multiple new ByteBuffers (say n ByteBuffers), this method should be called + * multiple times(n times) given all the new positions/old positions should exist in newBuffer/OldBuffer. + * + * Implementations must not change the position, limit or mark of the given buffer + * + * @param oldPosition old position of a cached object before aggregation buffer relocates to a new ByteBuffer. + * @param newPosition new position of a cached object after aggregation buffer relocates to a new ByteBuffer. + * @param oldBuffer old aggregation buffer. + * @param newBuffer new aggregation buffer. + */ + default void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer) + { + } + } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java index 4987d34b6b14..bef3ba226826 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java @@ -430,8 +430,9 @@ private void growIfPossible() for (int oldBucket = 0; oldBucket < buckets; oldBucket++) { if (isUsed(oldBucket)) { + int oldPosition = oldBucket * bucketSize; entryBuffer.limit((oldBucket + 1) * bucketSize); - entryBuffer.position(oldBucket * bucketSize); + entryBuffer.position(oldPosition); keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize); keyBuffer.position(entryBuffer.position() + HASH_SIZE); @@ -442,9 +443,19 @@ private void growIfPossible() throw new ISE("WTF?! Couldn't find a bucket while resizing?!"); } - newTableBuffer.position(newBucket * bucketSize); + int newPosition = newBucket * bucketSize; + newTableBuffer.position(newPosition); newTableBuffer.put(entryBuffer); + for (int i = 0; i < aggregators.length; i++) { + aggregators[i].relocate( + oldPosition + aggregatorOffsets[i], + newPosition + aggregatorOffsets[i], + tableBuffer, + newTableBuffer + ); + } + buffer.putInt(tableArenaSize + newSize * Ints.BYTES, newBucket * bucketSize); newSize++; } diff --git a/processing/src/test/java/io/druid/query/aggregation/AggregationTestHelper.java b/processing/src/test/java/io/druid/query/aggregation/AggregationTestHelper.java index e3eac6e43bf4..df73615e9911 100644 --- a/processing/src/test/java/io/druid/query/aggregation/AggregationTestHelper.java +++ b/processing/src/test/java/io/druid/query/aggregation/AggregationTestHelper.java @@ -67,6 +67,7 @@ import io.druid.query.topn.TopNQueryConfig; import io.druid.query.topn.TopNQueryQueryToolChest; import io.druid.query.topn.TopNQueryRunnerFactory; +import io.druid.segment.ColumnSelectorFactory; import io.druid.segment.IndexIO; import io.druid.segment.IndexMerger; import io.druid.segment.IndexSpec; @@ -84,6 +85,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Iterator; @@ -591,5 +593,30 @@ public ObjectMapper getObjectMapper() { return mapper; } + + public T[] runRelocateVerificationTest( + AggregatorFactory factory, + ColumnSelectorFactory selector, + Class clazz + ) + { + T[] results = (T[]) Array.newInstance(clazz, 2); + BufferAggregator agg = factory.factorizeBuffered(selector); + ByteBuffer myBuf = ByteBuffer.allocate(10040902); + agg.init(myBuf, 0); + agg.aggregate(myBuf, 0); + results[0] = (T) agg.get(myBuf, 0); + + byte[] theBytes = new byte[factory.getMaxIntermediateSize()]; + myBuf.get(theBytes); + + ByteBuffer newBuf = ByteBuffer.allocate(941209); + newBuf.position(7574); + newBuf.put(theBytes); + newBuf.position(0); + agg.relocate(0, 7574, myBuf, newBuf); + results[1] = (T) agg.get(newBuf, 7574); + return results; + } }