Skip to content

Commit

Permalink
fix SketchMergeAggregatorFactory.finalizeResults, comparator and more…
Browse files Browse the repository at this point in the history
… UTs for timeseries, topN (apache#3613)
  • Loading branch information
himanshug authored and fjy committed Oct 28, 2016
1 parent 6a845e1 commit 23a8e22
Show file tree
Hide file tree
Showing 11 changed files with 624 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.Ordering;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import com.yahoo.sketches.Family;
Expand All @@ -30,7 +31,6 @@
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;

import io.druid.java.util.common.IAE;
import io.druid.query.aggregation.Aggregator;
import io.druid.query.aggregation.AggregatorFactory;
Expand All @@ -52,14 +52,18 @@ public abstract class SketchAggregatorFactory extends AggregatorFactory
protected final int size;
private final byte cacheId;

public static final Comparator<Sketch> COMPARATOR = new Comparator<Sketch>()
{
@Override
public int compare(Sketch o, Sketch o1)
{
return Doubles.compare(o.getEstimate(), o1.getEstimate());
}
};
public static final Comparator<Object> COMPARATOR = Ordering.from(
new Comparator()
{
@Override
public int compare(Object o1, Object o2)
{
Sketch s1 = SketchAggregatorFactory.toSketch(o1);
Sketch s2 = SketchAggregatorFactory.toSketch(o2);
return Doubles.compare(s1.getEstimate(), s2.getEstimate());
}
}
).nullsFirst();

public SketchAggregatorFactory(String name, String fieldName, Integer size, byte cacheId)
{
Expand Down Expand Up @@ -103,7 +107,7 @@ public Object deserialize(Object object)
}

@Override
public Comparator<Sketch> getComparator()
public Comparator<Object> getComparator()
{
return COMPARATOR;
}
Expand Down Expand Up @@ -191,6 +195,17 @@ public byte[] getCacheKey()
.array();
}

public final static Sketch toSketch(Object obj)
{
if (obj instanceof Sketch) {
return (Sketch) obj;
} else if (obj instanceof Union) {
return ((Union) obj).getResult(true, null);
} else {
throw new IAE("Can't convert to Sketch object [%s]", obj.getClass());
}
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public int compare(Object o1, Object o2)
@Override
public Object compute(Map<String, Object> combinedAggregators)
{
Sketch sketch = SketchSetPostAggregator.toSketch(field.compute(combinedAggregators));
Sketch sketch = SketchAggregatorFactory.toSketch(field.compute(combinedAggregators));
if (errorBoundsStdDev != null) {
SketchEstimateWithErrorBounds result = new SketchEstimateWithErrorBounds(
sketch.getEstimate(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public Integer getErrorBoundsStdDev()
public Object finalizeComputation(Object object)
{
if (shouldFinalize) {
Sketch sketch = (Sketch) object;
Sketch sketch = SketchAggregatorFactory.toSketch(object);
if (errorBoundsStdDev != null) {
SketchEstimateWithErrorBounds result = new SketchEstimateWithErrorBounds(
sketch.getEstimate(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;

import io.druid.java.util.common.IAE;
import io.druid.segment.data.ObjectStrategy;

Expand All @@ -41,13 +40,14 @@ public class SketchObjectStrategy implements ObjectStrategy
@Override
public int compare(Object s1, Object s2)
{
if (s1 instanceof Sketch) {
if (s2 instanceof Sketch) {
return SketchAggregatorFactory.COMPARATOR.compare((Sketch) s1, (Sketch) s2);
if (s1 instanceof Sketch || s1 instanceof Union) {
if (s2 instanceof Sketch || s2 instanceof Union) {
return SketchAggregatorFactory.COMPARATOR.compare(s1, s2);
} else {
return -1;
}
}

if (s1 instanceof Memory) {
if (s2 instanceof Memory) {
Memory s1Mem = (Memory) s1;
Expand All @@ -66,6 +66,7 @@ public int compare(Object s1, Object s2)
return 1;
}
}

throw new IAE("Unknwon class[%s], toString[%s]", s1.getClass(), s1);

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public static Sketch deserialize(Object serializedSketch)
return deserializeFromByteArray((byte[]) serializedSketch);
} else if (serializedSketch instanceof Sketch) {
return (Sketch) serializedSketch;
} else if (serializedSketch instanceof Union) {
return ((Union) serializedSketch).getResult(true, null);
}

throw new IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import com.google.common.collect.Sets;
import com.yahoo.sketches.Util;
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Union;

import io.druid.java.util.common.IAE;
import io.druid.java.util.common.logger.Logger;
import io.druid.query.aggregation.PostAggregator;
Expand Down Expand Up @@ -75,7 +73,7 @@ public Set<String> getDependentFields()
}

@Override
public Comparator<Sketch> getComparator()
public Comparator<Object> getComparator()
{
return SketchAggregatorFactory.COMPARATOR;
}
Expand All @@ -85,23 +83,12 @@ public Object compute(final Map<String, Object> combinedAggregators)
{
Sketch[] sketches = new Sketch[fields.size()];
for (int i = 0; i < sketches.length; i++) {
sketches[i] = toSketch(fields.get(i).compute(combinedAggregators));
sketches[i] = SketchAggregatorFactory.toSketch(fields.get(i).compute(combinedAggregators));
}

return SketchOperations.sketchSetOperation(func, maxSketchSize, sketches);
}

public final static Sketch toSketch(Object obj)
{
if (obj instanceof Sketch) {
return (Sketch) obj;
} else if (obj instanceof Union) {
return ((Union) obj).getResult(true, null);
} else {
throw new IAE("Can't convert to Sketch object [%s]", obj.getClass());
}
}

@Override
@JsonProperty
public String getName()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.yahoo.sketches.Family;
import com.yahoo.sketches.theta.SetOperation;
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.granularity.QueryGranularities;
import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences;
import io.druid.query.Result;
import io.druid.query.aggregation.AggregationTestHelper;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.select.SelectResultValue;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
Expand All @@ -47,6 +47,7 @@
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/**
Expand All @@ -65,119 +66,6 @@ public SketchAggregationTest()
helper = AggregationTestHelper.createGroupByQueryAggregationTestHelper(sm.getJacksonModules(), tempFolder);
}

@Test
public void testSimpleDataIngestAndGpByQuery() throws Exception
{
Sequence<Row> seq = helper.createIndexAndRunQueryOnSegment(
new File(this.getClass().getClassLoader().getResource("simple_test_data.tsv").getFile()),
readFileFromClasspathAsString("simple_test_data_record_parser.json"),
readFileFromClasspathAsString("simple_test_data_aggregators.json"),
0,
QueryGranularities.NONE,
5,
readFileFromClasspathAsString("simple_test_data_group_by_query.json")
);

List<Row> results = Sequences.toList(seq, Lists.<Row>newArrayList());
Assert.assertEquals(5, results.size());
Assert.assertEquals(
ImmutableList.of(
new MapBasedRow(
DateTime.parse("2014-10-19T00:00:00.000Z"),
ImmutableMap
.<String, Object>builder()
.put("product", "product_3")
.put("sketch_count", 38.0)
.put("sketchEstimatePostAgg", 38.0)
.put("sketchUnionPostAggEstimate", 38.0)
.put("sketchIntersectionPostAggEstimate", 38.0)
.put("sketchAnotBPostAggEstimate", 0.0)
.put("non_existing_col_validation", 0.0)
.build()
),
new MapBasedRow(
DateTime.parse("2014-10-19T00:00:00.000Z"),
ImmutableMap
.<String, Object>builder()
.put("product", "product_1")
.put("sketch_count", 42.0)
.put("sketchEstimatePostAgg", 42.0)
.put("sketchUnionPostAggEstimate", 42.0)
.put("sketchIntersectionPostAggEstimate", 42.0)
.put("sketchAnotBPostAggEstimate", 0.0)
.put("non_existing_col_validation", 0.0)
.build()
),
new MapBasedRow(
DateTime.parse("2014-10-19T00:00:00.000Z"),
ImmutableMap
.<String, Object>builder()
.put("product", "product_2")
.put("sketch_count", 42.0)
.put("sketchEstimatePostAgg", 42.0)
.put("sketchUnionPostAggEstimate", 42.0)
.put("sketchIntersectionPostAggEstimate", 42.0)
.put("sketchAnotBPostAggEstimate", 0.0)
.put("non_existing_col_validation", 0.0)
.build()
),
new MapBasedRow(
DateTime.parse("2014-10-19T00:00:00.000Z"),
ImmutableMap
.<String, Object>builder()
.put("product", "product_4")
.put("sketch_count", 42.0)
.put("sketchEstimatePostAgg", 42.0)
.put("sketchUnionPostAggEstimate", 42.0)
.put("sketchIntersectionPostAggEstimate", 42.0)
.put("sketchAnotBPostAggEstimate", 0.0)
.put("non_existing_col_validation", 0.0)
.build()
),
new MapBasedRow(
DateTime.parse("2014-10-19T00:00:00.000Z"),
ImmutableMap
.<String, Object>builder()
.put("product", "product_5")
.put("sketch_count", 42.0)
.put("sketchEstimatePostAgg", 42.0)
.put("sketchUnionPostAggEstimate", 42.0)
.put("sketchIntersectionPostAggEstimate", 42.0)
.put("sketchAnotBPostAggEstimate", 0.0)
.put("non_existing_col_validation", 0.0)
.build()
)
),
results
);
}

@Test
public void testSimpleDataIngestAndSelectQuery() throws Exception
{
SketchModule sm = new SketchModule();
sm.configure(null);
AggregationTestHelper selectQueryAggregationTestHelper = AggregationTestHelper.createSelectQueryAggregationTestHelper(
sm.getJacksonModules(),
tempFolder
);

Sequence seq = selectQueryAggregationTestHelper.createIndexAndRunQueryOnSegment(
new File(this.getClass().getClassLoader().getResource("simple_test_data.tsv").getFile()),
readFileFromClasspathAsString("simple_test_data_record_parser.json"),
readFileFromClasspathAsString("simple_test_data_aggregators.json"),
0,
QueryGranularities.NONE,
5000,
readFileFromClasspathAsString("select_query.json")
);

Result<SelectResultValue> result = (Result<SelectResultValue>) Iterables.getOnlyElement(Sequences.toList(seq, Lists.newArrayList()));
Assert.assertEquals(new DateTime("2014-10-20T00:00:00.000Z"), result.getTimestamp());
Assert.assertEquals(100, result.getValue().getEvents().size());
Assert.assertEquals("AgMDAAAazJMCAAAAAACAPzz9j7pWTMdROWGf15uY1nI=", result.getValue().getEvents().get(0).getEvent().get("pty_country"));
}

@Test
public void testSketchDataIngestAndGpByQuery() throws Exception
{
Expand Down Expand Up @@ -453,6 +341,34 @@ public void testRetentionDataIngestAndGpByQuery() throws Exception
);
}

@Test
public void testSketchAggregatorFactoryComparator()
{
Comparator<Object> comparator = SketchAggregatorFactory.COMPARATOR;
Assert.assertEquals(0, comparator.compare(null, null));

Union union1 = (Union) SetOperation.builder().build(1<<4, Family.UNION);
union1.update("a");
union1.update("b");
Sketch sketch1 = union1.getResult();

Assert.assertEquals(-1, comparator.compare(null, sketch1));
Assert.assertEquals(1, comparator.compare(sketch1, null));

Union union2 = (Union) SetOperation.builder().build(1<<4, Family.UNION);
union2.update("a");
union2.update("b");
union2.update("c");
Sketch sketch2 = union2.getResult();

Assert.assertEquals(-1, comparator.compare(sketch1, sketch2));
Assert.assertEquals(-1, comparator.compare(sketch1, union2));
Assert.assertEquals(1, comparator.compare(sketch2, sketch1));
Assert.assertEquals(1, comparator.compare(sketch2, union1));
Assert.assertEquals(1, comparator.compare(union2, union1));
Assert.assertEquals(1, comparator.compare(union2, sketch1));
}

private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(
Expand Down
Loading

0 comments on commit 23a8e22

Please sign in to comment.