Skip to content

Commit

Permalink
[GSCOLLECT-1648] Use the Kahan summation algorithm to handle double a…
Browse files Browse the repository at this point in the history
…nd float precision in sumBy methods.

git-svn-id: svn+ssh://gscollections.svn.services.gs.com/svnroot/gscollections-svn/trunk@903 d5c9223b-1aff-41ac-aadd-f810b4a99ac4
  • Loading branch information
itohro committed Oct 2, 2015
1 parent 1934119 commit aa0f835
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2014 Goldman Sachs.
* Copyright 2015 Goldman Sachs.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,8 @@
import com.gs.collections.api.block.function.primitive.IntFunction;
import com.gs.collections.api.block.function.primitive.LongFunction;
import com.gs.collections.api.block.function.primitive.ShortFunction;
import com.gs.collections.api.map.primitive.MutableObjectDoubleMap;
import com.gs.collections.impl.factory.primitive.ObjectDoubleMaps;
import com.gs.collections.impl.map.mutable.primitive.ObjectDoubleHashMap;
import com.gs.collections.impl.map.mutable.primitive.ObjectLongHashMap;

Expand Down Expand Up @@ -125,6 +127,7 @@ public static <T, V> Function2<ObjectLongHashMap<V>, T, ObjectLongHashMap<V>> su
{
return new Function2<ObjectLongHashMap<V>, T, ObjectLongHashMap<V>>()
{
private static final long serialVersionUID = 1L;
public ObjectLongHashMap<V> value(ObjectLongHashMap<V> map, T each)
{
map.addToValue(groupBy.valueOf(each), function.intValueOf(each));
Expand All @@ -137,9 +140,17 @@ public static <T, V> Function2<ObjectDoubleHashMap<V>, T, ObjectDoubleHashMap<V>
{
return new Function2<ObjectDoubleHashMap<V>, T, ObjectDoubleHashMap<V>>()
{
private static final long serialVersionUID = 1L;
private final MutableObjectDoubleMap<V> compensation = ObjectDoubleMaps.mutable.of();

public ObjectDoubleHashMap<V> value(ObjectDoubleHashMap<V> map, T each)
{
map.addToValue(groupBy.valueOf(each), function.floatValueOf(each));
V groupKey = groupBy.valueOf(each);
double compensation = this.compensation.getIfAbsent(groupKey, 0.0d);
double adjustedValue = function.floatValueOf(each) - compensation;
double nextSum = map.get(groupKey) + adjustedValue;
this.compensation.put(groupKey, nextSum - map.get(groupKey) - adjustedValue);
map.put(groupKey, nextSum);
return map;
}
};
Expand All @@ -149,6 +160,7 @@ public static <T, V> Function2<ObjectLongHashMap<V>, T, ObjectLongHashMap<V>> su
{
return new Function2<ObjectLongHashMap<V>, T, ObjectLongHashMap<V>>()
{
private static final long serialVersionUID = 1L;
public ObjectLongHashMap<V> value(ObjectLongHashMap<V> map, T each)
{
map.addToValue(groupBy.valueOf(each), function.longValueOf(each));
Expand All @@ -161,9 +173,17 @@ public static <T, V> Function2<ObjectDoubleHashMap<V>, T, ObjectDoubleHashMap<V>
{
return new Function2<ObjectDoubleHashMap<V>, T, ObjectDoubleHashMap<V>>()
{
private static final long serialVersionUID = 1L;
private final MutableObjectDoubleMap<V> compensation = ObjectDoubleMaps.mutable.of();

public ObjectDoubleHashMap<V> value(ObjectDoubleHashMap<V> map, T each)
{
map.addToValue(groupBy.valueOf(each), function.doubleValueOf(each));
V groupKey = groupBy.valueOf(each);
double compensation = this.compensation.getIfAbsent(groupKey, 0.0d);
double adjustedValue = function.doubleValueOf(each) - compensation;
double nextSum = map.get(groupKey) + adjustedValue;
this.compensation.put(groupKey, nextSum - map.get(groupKey) - adjustedValue);
map.put(groupKey, nextSum);
return map;
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2014 Goldman Sachs.
* Copyright 2015 Goldman Sachs.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,33 +31,36 @@
import com.gs.collections.api.block.function.Function0;
import com.gs.collections.api.block.function.Function2;
import com.gs.collections.api.block.function.primitive.DoubleFunction;
import com.gs.collections.api.block.function.primitive.DoubleFunction0;
import com.gs.collections.api.block.function.primitive.FloatFunction;
import com.gs.collections.api.block.function.primitive.IntFunction;
import com.gs.collections.api.block.function.primitive.LongFunction;
import com.gs.collections.api.block.predicate.Predicate;
import com.gs.collections.api.block.procedure.Procedure;
import com.gs.collections.api.block.procedure.Procedure2;
import com.gs.collections.api.block.procedure.primitive.ObjectDoubleProcedure;
import com.gs.collections.api.block.procedure.primitive.ObjectIntProcedure;
import com.gs.collections.api.block.procedure.primitive.ObjectLongProcedure;
import com.gs.collections.api.list.ListIterable;
import com.gs.collections.api.map.MutableMap;
import com.gs.collections.api.map.primitive.ObjectDoubleMap;
import com.gs.collections.api.map.primitive.ObjectLongMap;
import com.gs.collections.api.multimap.MutableMultimap;
import com.gs.collections.api.tuple.primitive.DoubleDoublePair;
import com.gs.collections.impl.block.factory.Functions0;
import com.gs.collections.impl.block.procedure.MultimapPutProcedure;
import com.gs.collections.impl.block.procedure.MutatingAggregationProcedure;
import com.gs.collections.impl.block.procedure.NonMutatingAggregationProcedure;
import com.gs.collections.impl.factory.Maps;
import com.gs.collections.impl.list.fixed.ArrayAdapter;
import com.gs.collections.impl.map.mutable.ConcurrentHashMap;
import com.gs.collections.impl.map.mutable.UnifiedMap;
import com.gs.collections.impl.map.mutable.primitive.ObjectDoubleHashMap;
import com.gs.collections.impl.map.mutable.primitive.ObjectLongHashMap;
import com.gs.collections.impl.multimap.list.SynchronizedPutFastListMultimap;
import com.gs.collections.impl.tuple.primitive.PrimitiveTuples;
import com.gs.collections.impl.utility.Iterate;

import static com.gs.collections.impl.factory.Iterables.*;
import static com.gs.collections.impl.factory.Iterables.iList;

/**
* The ParallelIterate class contains several parallel algorithms that work with Collections. All of the higher
Expand Down Expand Up @@ -1377,7 +1380,7 @@ public static int getTaskRatio()

private static final class SumByDoubleProcedure<T, V> implements Procedure<T>, ProcedureFactory<SumByDoubleProcedure<T, V>>
{
private final ObjectDoubleHashMap<V> map = ObjectDoubleHashMap.newMap();
private final MutableMap<V, DoubleDoublePair> map = Maps.mutable.of();
private final Function<T, V> groupBy;
private final DoubleFunction<? super T> function;

Expand All @@ -1389,10 +1392,24 @@ private SumByDoubleProcedure(Function<T, V> groupBy, DoubleFunction<? super T> f

public void value(T each)
{
this.map.addToValue(this.groupBy.valueOf(each), this.function.doubleValueOf(each));
V groupKey = this.groupBy.valueOf(each);
DoubleDoublePair sumCompensation = this.map.getIfAbsentPut(groupKey, new Function0<DoubleDoublePair>()
{
public DoubleDoublePair value()
{
return PrimitiveTuples.pair(0.0d, 0.0d);
}
});
double sum = sumCompensation.getOne();
double compensation = sumCompensation.getTwo();
double adjustedValue = this.function.doubleValueOf(each) - compensation;
double nextSum = sum + adjustedValue;
compensation = nextSum - sum - adjustedValue;
sum = nextSum;
this.map.put(groupKey, PrimitiveTuples.pair(sum, compensation));
}

public ObjectDoubleHashMap<V> getResult()
public MutableMap<V, DoubleDoublePair> getResult()
{
return this.map;
}
Expand All @@ -1406,6 +1423,7 @@ public SumByDoubleProcedure<T, V> create()
private static final class SumByDoubleCombiner<T, V> extends AbstractProcedureCombiner<SumByDoubleProcedure<T, V>>
{
private final ObjectDoubleHashMap<V> result;
private final ObjectDoubleHashMap<V> compensation = ObjectDoubleHashMap.newMap();

private SumByDoubleCombiner(ObjectDoubleHashMap<V> result)
{
Expand All @@ -1417,15 +1435,34 @@ public void combineOne(SumByDoubleProcedure<T, V> thingToCombine)
{
if (this.result.isEmpty())
{
this.result.putAll(thingToCombine.getResult());
thingToCombine.getResult().forEachKeyValue(new Procedure2<V, DoubleDoublePair>()
{
public void value(V each, DoubleDoublePair sumCompensation)
{
SumByDoubleCombiner.this.result.put(each, sumCompensation.getOne());
SumByDoubleCombiner.this.compensation.put(each, sumCompensation.getTwo());
}
});
}
else
{
thingToCombine.getResult().forEachKeyValue(new ObjectDoubleProcedure<V>()
thingToCombine.getResult().forEachKeyValue(new Procedure2<V, DoubleDoublePair>()
{
public void value(V each, double value)
public void value(V each, DoubleDoublePair sumCompensation)
{
SumByDoubleCombiner.this.result.addToValue(each, value);
double sum = SumByDoubleCombiner.this.result.get(each);
double currentCompensation = SumByDoubleCombiner.this.compensation.getIfAbsentPut(each, new DoubleFunction0()
{
public double value()
{
return 0.0d;
}
}) + sumCompensation.getTwo();

double adjustedValue = sumCompensation.getOne() - currentCompensation;
double nextSum = sum + adjustedValue;
SumByDoubleCombiner.this.compensation.put(each, nextSum - sum - adjustedValue);
SumByDoubleCombiner.this.result.put(each, nextSum);
}
});
}
Expand All @@ -1434,7 +1471,7 @@ public void value(V each, double value)

private static final class SumByFloatProcedure<T, V> implements Procedure<T>, ProcedureFactory<SumByFloatProcedure<T, V>>
{
private final ObjectDoubleHashMap<V> map = ObjectDoubleHashMap.newMap();
private final MutableMap<V, DoubleDoublePair> map = Maps.mutable.of();
private final Function<T, V> groupBy;
private final FloatFunction<? super T> function;

Expand All @@ -1446,10 +1483,24 @@ private SumByFloatProcedure(Function<T, V> groupBy, FloatFunction<? super T> fun

public void value(T each)
{
this.map.addToValue(this.groupBy.valueOf(each), (double) this.function.floatValueOf(each));
V groupKey = this.groupBy.valueOf(each);
DoubleDoublePair sumCompensation = this.map.getIfAbsentPut(groupKey, new Function0<DoubleDoublePair>()
{
public DoubleDoublePair value()
{
return PrimitiveTuples.pair(0.0d, 0.0d);
}
});
double sum = sumCompensation.getOne();
double compensation = sumCompensation.getTwo();
double adjustedValue = this.function.floatValueOf(each) - compensation;
double nextSum = sum + adjustedValue;
compensation = nextSum - sum - adjustedValue;
sum = nextSum;
this.map.put(groupKey, PrimitiveTuples.pair(sum, compensation));
}

public ObjectDoubleHashMap<V> getResult()
public MutableMap<V, DoubleDoublePair> getResult()
{
return this.map;
}
Expand All @@ -1463,6 +1514,7 @@ public SumByFloatProcedure<T, V> create()
private static final class SumByFloatCombiner<T, V> extends AbstractProcedureCombiner<SumByFloatProcedure<T, V>>
{
private final ObjectDoubleHashMap<V> result;
private final ObjectDoubleHashMap<V> compensation = ObjectDoubleHashMap.newMap();

private SumByFloatCombiner(ObjectDoubleHashMap<V> result)
{
Expand All @@ -1474,15 +1526,34 @@ public void combineOne(SumByFloatProcedure<T, V> thingToCombine)
{
if (this.result.isEmpty())
{
this.result.putAll(thingToCombine.getResult());
thingToCombine.getResult().forEachKeyValue(new Procedure2<V, DoubleDoublePair>()
{
public void value(V each, DoubleDoublePair sumCompensation)
{
SumByFloatCombiner.this.result.put(each, sumCompensation.getOne());
SumByFloatCombiner.this.compensation.put(each, sumCompensation.getTwo());
}
});
}
else
{
thingToCombine.getResult().forEachKeyValue(new ObjectDoubleProcedure<V>()
thingToCombine.getResult().forEachKeyValue(new Procedure2<V, DoubleDoublePair>()
{
public void value(V each, double value)
public void value(V each, DoubleDoublePair sumCompensation)
{
SumByFloatCombiner.this.result.addToValue(each, value);
double sum = SumByFloatCombiner.this.result.get(each);
double currentCompensation = SumByFloatCombiner.this.compensation.getIfAbsentPut(each, new DoubleFunction0()
{
public double value()
{
return 0.0d;
}
}) + sumCompensation.getTwo();

double adjustedValue = sumCompensation.getOne() - currentCompensation;
double nextSum = sum + adjustedValue;
SumByFloatCombiner.this.compensation.put(each, nextSum - sum - adjustedValue);
SumByFloatCombiner.this.result.put(each, nextSum);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import com.gs.collections.api.set.MutableSet;
import com.gs.collections.api.tuple.Twin;
import com.gs.collections.impl.block.factory.Comparators;
import com.gs.collections.impl.block.factory.HashingStrategies;
import com.gs.collections.impl.block.procedure.CountProcedure;
import com.gs.collections.impl.block.procedure.FastListCollectIfProcedure;
import com.gs.collections.impl.block.procedure.FastListCollectProcedure;
Expand Down Expand Up @@ -956,21 +955,33 @@ public static <V, T> ObjectLongMap<V> sumByLong(T[] array, int size, Function<T,
public static <V, T> ObjectDoubleMap<V> sumByFloat(T[] array, int size, Function<T, V> groupBy, FloatFunction<? super T> function)
{
ObjectDoubleHashMap<V> result = ObjectDoubleHashMap.newMap();
ObjectDoubleHashMap<V> groupKeyToCompensation = ObjectDoubleHashMap.newMap();
for (int i = 0; i < size; i++)
{
T item = array[i];
result.addToValue(groupBy.valueOf(item), function.floatValueOf(item));
V groupByKey = groupBy.valueOf(item);
double compensation = groupKeyToCompensation.getIfAbsentPut(groupByKey, 0.0d);
double adjustedValue = function.floatValueOf(item) - compensation;
double nextSum = result.get(groupByKey) + adjustedValue;
groupKeyToCompensation.put(groupByKey, nextSum - result.get(groupByKey) - adjustedValue);
result.put(groupByKey, nextSum);
}
return result;
}

public static <V, T> ObjectDoubleMap<V> sumByDouble(T[] array, int size, Function<T, V> groupBy, DoubleFunction<? super T> function)
{
ObjectDoubleHashMap<V> result = ObjectDoubleHashMap.newMap();
ObjectDoubleHashMap<V> groupKeyToCompensation = ObjectDoubleHashMap.newMap();
for (int i = 0; i < size; i++)
{
T item = array[i];
result.addToValue(groupBy.valueOf(item), function.doubleValueOf(item));
V groupByKey = groupBy.valueOf(item);
double compensation = groupKeyToCompensation.getIfAbsentPut(groupByKey, 0.0d);
double adjustedValue = function.doubleValueOf(item) - compensation;
double nextSum = result.get(groupByKey) + adjustedValue;
groupKeyToCompensation.put(groupByKey, nextSum - result.get(groupByKey) - adjustedValue);
result.put(groupByKey, nextSum);
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2011 Goldman Sachs.
* Copyright 2015 Goldman Sachs.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -161,4 +161,58 @@ public void unboxFloatToFloat()
+ "bmN0aW9ucyRVbmJveEZsb2F0VG9GbG9hdAAAAAAAAAABAgAAeHA=",
PrimitiveFunctions.unboxFloatToFloat());
}

@Test
public void sumByInt()
{
Verify.assertSerializedForm(
1L,
"rO0ABXNyADpjb20uZ3MuY29sbGVjdGlvbnMuaW1wbC5ibG9jay5mYWN0b3J5LlByaW1pdGl2ZUZ1\n"
+ "bmN0aW9ucyQxAAAAAAAAAAECAAJMAAx2YWwkZnVuY3Rpb250AD1MY29tL2dzL2NvbGxlY3Rpb25z\n"
+ "L2FwaS9ibG9jay9mdW5jdGlvbi9wcmltaXRpdmUvSW50RnVuY3Rpb247TAALdmFsJGdyb3VwQnl0\n"
+ "ADBMY29tL2dzL2NvbGxlY3Rpb25zL2FwaS9ibG9jay9mdW5jdGlvbi9GdW5jdGlvbjt4cHBw",
PrimitiveFunctions.sumByIntFunction(null, null));
}

@Test
public void sumByLong()
{
Verify.assertSerializedForm(
1L,
"rO0ABXNyADpjb20uZ3MuY29sbGVjdGlvbnMuaW1wbC5ibG9jay5mYWN0b3J5LlByaW1pdGl2ZUZ1\n"
+ "bmN0aW9ucyQzAAAAAAAAAAECAAJMAAx2YWwkZnVuY3Rpb250AD5MY29tL2dzL2NvbGxlY3Rpb25z\n"
+ "L2FwaS9ibG9jay9mdW5jdGlvbi9wcmltaXRpdmUvTG9uZ0Z1bmN0aW9uO0wAC3ZhbCRncm91cEJ5\n"
+ "dAAwTGNvbS9ncy9jb2xsZWN0aW9ucy9hcGkvYmxvY2svZnVuY3Rpb24vRnVuY3Rpb247eHBwcA==\n",
PrimitiveFunctions.sumByLongFunction(null, null));
}

@Test
public void sumByFloat()
{
Verify.assertSerializedForm(
1L,
"rO0ABXNyADpjb20uZ3MuY29sbGVjdGlvbnMuaW1wbC5ibG9jay5mYWN0b3J5LlByaW1pdGl2ZUZ1\n"
+ "bmN0aW9ucyQyAAAAAAAAAAECAANMAAxjb21wZW5zYXRpb250AD1MY29tL2dzL2NvbGxlY3Rpb25z\n"
+ "L2FwaS9tYXAvcHJpbWl0aXZlL011dGFibGVPYmplY3REb3VibGVNYXA7TAAMdmFsJGZ1bmN0aW9u\n"
+ "dAA/TGNvbS9ncy9jb2xsZWN0aW9ucy9hcGkvYmxvY2svZnVuY3Rpb24vcHJpbWl0aXZlL0Zsb2F0\n"
+ "RnVuY3Rpb247TAALdmFsJGdyb3VwQnl0ADBMY29tL2dzL2NvbGxlY3Rpb25zL2FwaS9ibG9jay9m\n"
+ "dW5jdGlvbi9GdW5jdGlvbjt4cHNyAEFjb20uZ3MuY29sbGVjdGlvbnMuaW1wbC5tYXAubXV0YWJs\n"
+ "ZS5wcmltaXRpdmUuT2JqZWN0RG91YmxlSGFzaE1hcAAAAAAAAAABDAAAeHB3BAAAAAB4cHA=",
PrimitiveFunctions.sumByFloatFunction(null, null));
}

@Test
public void sumByDouble()
{
Verify.assertSerializedForm(
1L,
"rO0ABXNyADpjb20uZ3MuY29sbGVjdGlvbnMuaW1wbC5ibG9jay5mYWN0b3J5LlByaW1pdGl2ZUZ1\n"
+ "bmN0aW9ucyQ0AAAAAAAAAAECAANMAAxjb21wZW5zYXRpb250AD1MY29tL2dzL2NvbGxlY3Rpb25z\n"
+ "L2FwaS9tYXAvcHJpbWl0aXZlL011dGFibGVPYmplY3REb3VibGVNYXA7TAAMdmFsJGZ1bmN0aW9u\n"
+ "dABATGNvbS9ncy9jb2xsZWN0aW9ucy9hcGkvYmxvY2svZnVuY3Rpb24vcHJpbWl0aXZlL0RvdWJs\n"
+ "ZUZ1bmN0aW9uO0wAC3ZhbCRncm91cEJ5dAAwTGNvbS9ncy9jb2xsZWN0aW9ucy9hcGkvYmxvY2sv\n"
+ "ZnVuY3Rpb24vRnVuY3Rpb247eHBzcgBBY29tLmdzLmNvbGxlY3Rpb25zLmltcGwubWFwLm11dGFi\n"
+ "bGUucHJpbWl0aXZlLk9iamVjdERvdWJsZUhhc2hNYXAAAAAAAAAAAQwAAHhwdwQAAAAAeHBw",
PrimitiveFunctions.sumByDoubleFunction(null, null));
}
}
Loading

0 comments on commit aa0f835

Please sign in to comment.