Skip to content

Commit

Permalink
[FLINK-3657] [dataSet] Change access of DataSetUtils.countElements() …
Browse files Browse the repository at this point in the history
…to 'public'

This closes apache#1829
  • Loading branch information
smarthi authored and fhueske committed Apr 15, 2016
1 parent d938c5f commit 5f993c6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,14 @@ public final class DataSetUtils {
* @param input the DataSet received as input
* @return a data set containing (subtask index, number of elements) tuples.
*/
// NOTE(review): this is a diff view — the next line appears to be the signature before the commit
// (private, named countElements) and the line after it the replacement (public, renamed).
private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
// Apply a partition-wise function that emits one (subtask index, element count) tuple per invocation.
return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
// The values are iterated only to be counted; their contents are never inspected.
long counter = 0;
for (T value : values) {
counter++;
}

// Tag the count with the index of the subtask that processed these values.
out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
}
});
Expand All @@ -83,7 +82,7 @@ public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> ou
*/
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

DataSet<Tuple2<Integer, Long>> elementCount = countElements(input);
DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);

return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ package object utils {
@PublicEvolving
implicit class DataSetUtils[T: TypeInformation : ClassTag](val self: DataSet[T]) {

/**
 * Counts the elements held by each partition of this data set by delegating
 * to the Java `DataSetUtils.countElementsPerPartition` utility.
 *
 * @return a data set of (subtask index, number of elements) pairs
 */
def countElementsPerPartition: DataSet[(Int, Long)] = {
  // Type information for the Scala (Int, Long) result, picked up implicitly by `map` below.
  implicit val typeInfo = createTuple2TypeInformation[Int, Long](
    BasicTypeInfo.INT_TYPE_INFO.asInstanceOf[TypeInformation[Int]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
  )
  val javaCounts = jutils.countElementsPerPartition(self.javaSet)
  // Convert the Java Tuple2 of boxed values into a Scala tuple of primitives.
  wrap(javaCounts).map { counts => (counts.f0.toInt, counts.f1.toLong) }
}

/**
* Method that takes a set of subtask index, total number of elements mappings
* and assigns ids to all the elements from the input data set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

package org.apache.flink.test.util;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
Expand All @@ -32,8 +30,10 @@
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand All @@ -44,13 +44,25 @@ public DataSetUtilsITCase(TestExecutionMode mode) {
super(mode);
}

/**
 * Checks that countElementsPerPartition yields one tuple per parallel subtask
 * and that the per-partition counts add up to the size of the input.
 */
@Test
public void testCountElementsPerPartition() throws Exception {
    final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    final long expectedSize = 100L;
    final DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);

    final DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(numbers);

    // one (subtask index, count) tuple is expected for every parallel subtask
    Assert.assertEquals(env.getParallelism(), counts.count());

    // summing field 1 (the counts) must give back the number of generated elements
    final long total = counts.sum(1).collect().get(0).f1.longValue();
    Assert.assertEquals(expectedSize, total);
}

@Test
public void testZipWithIndex() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
long expectedSize = 100L;
DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);

List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
List<Tuple2<Long, Long>> result = new ArrayList<>(DataSetUtils.zipWithIndex(numbers).collect());

Assert.assertEquals(expectedSize, result.size());
// sort result by created index
Expand Down Expand Up @@ -79,7 +91,7 @@ public Long map(Tuple2<Long, Long> value) throws Exception {
}
});

Set<Long> result = Sets.newHashSet(ids.collect());
Set<Long> result = new HashSet<>(ids.collect());

Assert.assertEquals(expectedSize, result.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DataSetUtilsITCase (
@Test
@throws(classOf[Exception])
def testZipWithIndex(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val env = ExecutionEnvironment.getExecutionEnvironment

val expectedSize = 100L

Expand All @@ -52,7 +52,7 @@ class DataSetUtilsITCase (
@Test
@throws(classOf[Exception])
def testZipWithUniqueId(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val env = ExecutionEnvironment.getExecutionEnvironment

val expectedSize = 100L

Expand All @@ -73,4 +73,19 @@ class DataSetUtilsITCase (
Assert.assertEquals(checksum.getCount, 15)
Assert.assertEquals(checksum.getChecksum, 55)
}

/**
 * Checks that countElementsPerPartition yields one pair per parallel subtask
 * and that the per-partition counts add up to the size of the input.
 */
@Test
@throws(classOf[Exception])
def testCountElementsPerPartition(): Unit = {
  val env = ExecutionEnvironment.getExecutionEnvironment
  val expectedSize = 100L

  val numbers = env.generateSequence(0, expectedSize - 1)
  val counts = numbers.countElementsPerPartition

  // one (subtask index, count) pair is expected for every parallel subtask
  val parallelism = env.getParallelism
  val partitionCount = counts.collect().size
  Assert.assertEquals(parallelism, partitionCount)

  // summing field 1 (the counts) must give back the number of generated elements
  val total = counts.sum(1).collect().head._2
  Assert.assertEquals(expectedSize, total)
}
}

0 comments on commit 5f993c6

Please sign in to comment.