SPARK-1063 Add .sortBy(f) method on RDD
This never got merged from the apache/incubator-spark repo (which is now deleted), but there had been several rounds of code review on this PR there.

I think this is ready for merging.

Author: Andrew Ash <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <[email protected]>

Closes apache#369 from ash211/sortby and squashes the following commits:

d09147a [Andrew Ash] Fix Ordering import
43d0a53 [Andrew Ash] Fix missing .collect()
29a54ed [Andrew Ash] Re-enable test by converting to a closure
5a95348 [Andrew Ash] Add license for RDDSuiteUtils
64ed6e3 [Andrew Ash] Remove leaked diff
d4de69a [Andrew Ash] Remove scar tissue
63638b5 [Andrew Ash] Add Python version of .sortBy()
45e0fde [Andrew Ash] Add Java version of .sortBy()
adf84c5 [Andrew Ash] Re-indent to keep line lengths under 100 chars
9d9b9d8 [Andrew Ash] Use parentheses on .collect() calls
0457b69 [Andrew Ash] Ignore failing test
99f0baf [Andrew Ash] Merge branch 'master' into sortby
222ae97 [Andrew Ash] Try moving Ordering objects out to a different class
3fd0dd3 [Andrew Ash] Add (failing) test for sortByKey with explicit Ordering
b8b5bbc [Andrew Ash] Remove extra spaces that were used to align ='s in test code
8c53298 [Andrew Ash] Actually use ascending and numPartitions parameters
381eef2 [Andrew Ash] Correct silly typo
7db3e84 [Andrew Ash] Support ascending and numPartitions params in sortBy()
0f685fd [Andrew Ash] Merge remote-tracking branch 'origin/master' into sortby
ca4490d [Andrew Ash] Add .sortBy(f) method on RDD
ash211 authored and rxin committed Jun 17, 2014
1 parent e243c5f commit b92d16b
Showing 6 changed files with 159 additions and 4 deletions.
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,10 +17,13 @@

package org.apache.spark.api.java

import java.util.Comparator

import scala.language.implicitConversions
import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -172,6 +175,19 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
    rdd.setName(name)
    this
  }

  /**
   * Return this RDD sorted by the given key function.
   */
  def sortBy[S](f: JFunction[T, S], ascending: Boolean, numPartitions: Int): JavaRDD[T] = {
    import scala.collection.JavaConverters._
    def fn = (x: T) => f.call(x)
    import com.google.common.collect.Ordering // shadows scala.math.Ordering
    implicit val ordering = Ordering.natural().asInstanceOf[Ordering[S]]
    implicit val ctag: ClassTag[S] = fakeClassTag
    wrapRDD(rdd.sortBy(fn, ascending, numPartitions))
  }

}

object JavaRDD {
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -442,6 +442,18 @@
   */
  def ++(other: RDD[T]): RDD[T] = this.union(other)

  /**
   * Return this RDD sorted by the given key function.
   */
  def sortBy[K](
      f: (T) => K,
      ascending: Boolean = true,
      numPartitions: Int = this.partitions.size)
      (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] =
    this.keyBy[K](f)
        .sortByKey(ascending, numPartitions)
        .values

  /**
   * Return the intersection of this RDD and another one. The output will not contain any duplicate
   * elements, even if the input RDDs did.
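For reference, a minimal sketch of how the new Scala API reads (not part of the patch; the example data, app name, and local master are made up for illustration):

import org.apache.spark.{SparkConf, SparkContext}

object SortByExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sortBy-example").setMaster("local"))
    val words = sc.parallelize(Seq("banana", "kiwi", "apple", "fig"))

    // Sort by a derived key (string length); Ordering[Int] and the ClassTag
    // for the key are supplied implicitly.
    val byLength = words.sortBy(w => w.length).collect()
    // Array(fig, kiwi, apple, banana)

    // ascending defaults to true and numPartitions to this.partitions.size;
    // here both are overridden.
    val longestFirst = words.sortBy(w => w.length, ascending = false, numPartitions = 2).collect()
    // Array(banana, apple, kiwi, fig)

    sc.stop()
  }
}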
33 changes: 33 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -180,6 +180,39 @@ public void sortByKey() {
    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
  }

  @Test
  public void sortBy() {
    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
    pairs.add(new Tuple2<Integer, Integer>(0, 4));
    pairs.add(new Tuple2<Integer, Integer>(3, 2));
    pairs.add(new Tuple2<Integer, Integer>(-1, 1));

    JavaRDD<Tuple2<Integer, Integer>> rdd = sc.parallelize(pairs);

    // compare on first value
    JavaRDD<Tuple2<Integer, Integer>> sortedRDD = rdd.sortBy(new Function<Tuple2<Integer, Integer>, Integer>() {
      public Integer call(Tuple2<Integer, Integer> t) throws Exception {
        return t._1();
      }
    }, true, 2);

    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
    List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect();
    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1));
    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));

    // compare on second value
    sortedRDD = rdd.sortBy(new Function<Tuple2<Integer, Integer>, Integer>() {
      public Integer call(Tuple2<Integer, Integer> t) throws Exception {
        return t._2();
      }
    }, true, 2);
    Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first());
    sortedPairs = sortedRDD.collect();
    Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(1));
    Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(2));
  }

  @Test
  public void foreach() {
    final Accumulator<Integer> accum = sc.accumulator(0);
59 changes: 55 additions & 4 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.util.Utils

import org.apache.spark.rdd.RDDSuiteUtils._

class RDDSuite extends FunSuite with SharedSparkContext {

test("basic operations") {
@@ -585,23 +587,72 @@ class RDDSuite extends FunSuite with SharedSparkContext {
    }
  }

test("sortByKey") {
val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))

val col1 = Array("4|60|C", "5|50|A", "6|40|B")
val col2 = Array("6|40|B", "5|50|A", "4|60|C")
val col3 = Array("5|50|A", "6|40|B", "4|60|C")

assert(data.sortBy(_.split("\\|")(0)).collect() === col1)
assert(data.sortBy(_.split("\\|")(1)).collect() === col2)
assert(data.sortBy(_.split("\\|")(2)).collect() === col3)
}

test("sortByKey ascending parameter") {
val data = sc.parallelize(Seq("5|50|A","4|60|C", "6|40|B"))

val asc = Array("4|60|C", "5|50|A", "6|40|B")
val desc = Array("6|40|B", "5|50|A", "4|60|C")

assert(data.sortBy(_.split("\\|")(0), true).collect() === asc)
assert(data.sortBy(_.split("\\|")(0), false).collect() === desc)
}

test("sortByKey with explicit ordering") {
val data = sc.parallelize(Seq("Bob|Smith|50",
"Jane|Smith|40",
"Thomas|Williams|30",
"Karen|Williams|60"))

val ageOrdered = Array("Thomas|Williams|30",
"Jane|Smith|40",
"Bob|Smith|50",
"Karen|Williams|60")

// last name, then first name
val nameOrdered = Array("Bob|Smith|50",
"Jane|Smith|40",
"Karen|Williams|60",
"Thomas|Williams|30")

val parse = (s: String) => {
val split = s.split("\\|")
Person(split(0), split(1), split(2).toInt)
}

import scala.reflect.classTag
assert(data.sortBy(parse, true, 2)(AgeOrdering, classTag[Person]).collect() === ageOrdered)
assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
}

test("intersection") {
val all = sc.parallelize(1 to 10)
val evens = sc.parallelize(2 to 10 by 2)
val intersection = Array(2, 4, 6, 8, 10)

// intersection is commutative
assert(all.intersection(evens).collect.sorted === intersection)
assert(evens.intersection(all).collect.sorted === intersection)
assert(all.intersection(evens).collect().sorted === intersection)
assert(evens.intersection(all).collect().sorted === intersection)
}

test("intersection strips duplicates in an input") {
val a = sc.parallelize(Seq(1,2,3,3))
val b = sc.parallelize(Seq(1,1,2,3))
val intersection = Array(1,2,3)

assert(a.intersection(b).collect.sorted === intersection)
assert(b.intersection(a).collect.sorted === intersection)
assert(a.intersection(b).collect().sorted === intersection)
assert(b.intersection(a).collect().sorted === intersection)
}

test("zipWithIndex") {
31 changes: 31 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuiteUtils.scala
@@ -0,0 +1,31 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

object RDDSuiteUtils {
  case class Person(first: String, last: String, age: Int)

  object AgeOrdering extends Ordering[Person] {
    def compare(a: Person, b: Person) = a.age compare b.age
  }

  object NameOrdering extends Ordering[Person] {
    def compare(a: Person, b: Person) =
      implicitly[Ordering[Tuple2[String, String]]].compare((a.last, a.first), (b.last, b.first))
  }
}
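The "sortByKey with explicit ordering" test above passes AgeOrdering and classTag[Person] through sortBy's second parameter list by hand. A sketch of the implicit alternative (not part of the patch; assumes a live SparkContext named sc):

import org.apache.spark.rdd.RDDSuiteUtils.{AgeOrdering, Person}

// With an Ordering[Person] in implicit scope, the compiler fills in
// sortBy's (implicit ord, ctag) parameter list automatically.
implicit val byAge: Ordering[Person] = AgeOrdering

val people = sc.parallelize(Seq(
  Person("Bob", "Smith", 50),
  Person("Jane", "Smith", 40),
  Person("Thomas", "Williams", 30)))

// Same result as passing (AgeOrdering, classTag[Person]) explicitly, as the test does.
val youngestFirst = people.sortBy(p => p).collect()
// Array(Person(Thomas,Williams,30), Person(Jane,Smith,40), Person(Bob,Smith,50))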
12 changes: 12 additions & 0 deletions python/pyspark/rdd.py
@@ -549,6 +549,18 @@ def mapFunc(iterator):
                .mapPartitions(mapFunc, preservesPartitioning=True)
                .flatMap(lambda x: x, preservesPartitioning=True))

    def sortBy(self, keyfunc, ascending=True, numPartitions=None):
        """
        Sorts this RDD by the given keyfunc

        >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
        >>> sc.parallelize(tmp).sortBy(lambda x: x[0]).collect()
        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
        >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect()
        [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
        """
        return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values()

    def glom(self):
        """
        Return an RDD created by coalescing all elements within each partition
