[MINOR] Clean up several build warnings, mostly due to internal use of old accumulators

Another PR to clean up recent build warnings. In particular, this cleans up several uses of the old accumulator API in tests that are straightforward to update. I think this qualifies as "minor".

Jenkins

Author: Sean Owen <[email protected]>

Closes apache#13642 from srowen/BuildWarnings.

(cherry picked from commit 6151d26)
Signed-off-by: Sean Owen <[email protected]>
srowen committed Jun 14, 2016
1 parent e03c251 · commit 2453922
Showing 6 changed files with 31 additions and 136 deletions.
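
The Scala changes below all follow the same migration: replace the deprecated Accumulator/AccumulatorParam API with AccumulatorV2 and its built-in implementations such as LongAccumulator. A rough before/after sketch of that pattern (a minimal illustration; the names and counts are invented, not taken from the diff):

import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

object AccumulatorMigrationSketch {
  def run(sc: SparkContext): Unit = {
    // Old, deprecated 1.x-style API: an Accumulator built from an implicit
    // AccumulatorParam and updated with `+=`:
    //   val counter = sc.accumulator(0L, "counter")
    //   sc.parallelize(1 to 10).foreach(_ => counter += 1)

    // New AccumulatorV2-based API: a built-in LongAccumulator, registered with
    // the SparkContext by the factory method and updated with `add`:
    val counter: LongAccumulator = sc.longAccumulator("counter")
    sc.parallelize(1 to 10).foreach(_ => counter.add(1))
    assert(counter.sum == 10)
  }
}
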
core/pom.xml: 6 changes (3 additions, 3 deletions)
@@ -356,12 +356,12 @@
<phase>generate-resources</phase>
<configuration>
<!-- Execute the shell script to generate the spark build information. -->
<tasks>
<target>
<exec executable="${project.basedir}/../build/spark-build-info">
<arg value="${project.build.directory}/extra-resources"/>
<arg value="${pom.version}"/>
<arg value="${project.version}"/>
</exec>
</tasks>
</target>
</configuration>
<goals>
<goal>run</goal>
DAGSchedulerSuite.scala
@@ -1593,13 +1593,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
}

test("misbehaved accumulator should not crash DAGScheduler and SparkContext") {
val acc = new Accumulator[Int](0, new AccumulatorParam[Int] {
override def addAccumulator(t1: Int, t2: Int): Int = t1 + t2
override def zero(initialValue: Int): Int = 0
override def addInPlace(r1: Int, r2: Int): Int = {
throw new DAGSchedulerSuiteDummyException
}
})
val acc = new LongAccumulator {
override def add(v: java.lang.Long): Unit = throw new DAGSchedulerSuiteDummyException
override def add(v: Long): Unit = throw new DAGSchedulerSuiteDummyException
}
sc.register(acc)

// Run this on executors
sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) }
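
A note on the new test accumulator above: LongAccumulator is an AccumulatorV2[java.lang.Long, java.lang.Long], so it has both a boxed add(java.lang.Long) (the AccumulatorV2 contract) and a primitive add(Long) convenience overload, and a misbehaving accumulator has to throw from both. A standalone sketch of the same idea, using a hypothetical exception in place of the suite's DAGSchedulerSuiteDummyException:

import org.apache.spark.util.LongAccumulator

object MisbehavingAccumulatorSketch {
  // A LongAccumulator whose every update fails, no matter which `add` overload
  // the caller hits; useful for testing that update failures are handled.
  def failingAccumulator(): LongAccumulator = new LongAccumulator {
    override def add(v: java.lang.Long): Unit = throw new IllegalStateException("boom")
    override def add(v: Long): Unit = throw new IllegalStateException("boom")
  }
}
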
TaskContextSuite.scala
@@ -146,14 +146,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("accumulators are updated on exception failures") {
// This means use 1 core and 4 max task failures
sc = new SparkContext("local[1,4]", "test")
val param = AccumulatorParam.LongAccumulatorParam
// Create 2 accumulators, one that counts failed values and another that doesn't
val acc1 = new Accumulator(0L, param, Some("x"), countFailedValues = true)
val acc2 = new Accumulator(0L, param, Some("y"), countFailedValues = false)
val acc1 = AccumulatorSuite.createLongAccum("x", true)
val acc2 = AccumulatorSuite.createLongAccum("y", false)
// Fail first 3 attempts of every task. This means each task should be run 4 times.
sc.parallelize(1 to 10, 10).map { i =>
acc1 += 1
acc2 += 1
acc1.add(1)
acc2.add(1)
if (TaskContext.get.attemptNumber() <= 2) {
throw new Exception("you did something wrong")
} else {
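
AccumulatorSuite.createLongAccum is a helper in Spark's test sources; the sketch below is only a hypothetical stand-in built from the public API (the real helper also sets the internal countFailedValues flag, which plain registration does not expose):

import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

object CreateLongAccumSketch {
  // Create a named LongAccumulator and register it with the SparkContext so its
  // updates are collected from tasks and reported under `name`.
  def createLongAccum(sc: SparkContext, name: String): LongAccumulator = {
    val acc = new LongAccumulator
    sc.register(acc, name)
    acc
  }
}
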
org.apache.spark.sql.execution package object debug
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution

import scala.collection.mutable.HashSet

import org.apache.spark.{Accumulator, AccumulatorParam}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.LongAccumulator
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}

/**
* Contains methods for debugging query execution.
@@ -108,26 +107,27 @@ package object debug {
private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
def output: Seq[Attribute] = child.output

implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
def zero(initialValue: HashSet[String]): HashSet[String] = {
initialValue.clear()
initialValue
}

def addInPlace(v1: HashSet[String], v2: HashSet[String]): HashSet[String] = {
v1 ++= v2
v1
class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
private val _set = new HashSet[T]()
override def isZero: Boolean = _set.isEmpty
override def copy(): AccumulatorV2[T, HashSet[T]] = {
val newAcc = new SetAccumulator[T]()
newAcc._set ++= _set
newAcc
}
override def reset(): Unit = _set.clear()
override def add(v: T): Unit = _set += v
override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value
override def value: HashSet[T] = _set
}

/**
* A collection of metrics for each column of output.
*
* @param elementTypes the actual runtime types for the output. Useful when there are bugs
* causing the wrong data to be projected.
*/
case class ColumnMetrics(
elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
case class ColumnMetrics() {
val elementTypes = new SetAccumulator[String]
sparkContext.register(elementTypes)
}

val tupleCount: LongAccumulator = sparkContext.longAccumulator

@@ -155,7 +155,7 @@
while (i < numColumns) {
val value = currentRow.get(i, output(i).dataType)
if (value != null) {
columnStats(i).elementTypes += HashSet(value.getClass.getName)
columnStats(i).elementTypes.add(value.getClass.getName)
}
i += 1
}
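
A rough usage sketch of the new SetAccumulator outside DebugExec, spark-shell style, assuming a SparkContext named sc and a top-level copy of the class above:

// Register the accumulator, add values from tasks running on executors, then
// read the merged set back on the driver.
val typeNames = new SetAccumulator[String]
sc.register(typeNames, "element types")
sc.parallelize(Seq[Any](1, "a", 2.0)).foreach(v => typeNames.add(v.getClass.getName))
// typeNames.value is now something like:
// HashSet(java.lang.Integer, java.lang.String, java.lang.Double)
println(typeNames.value)
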
SQLMetricsSuite.scala
@@ -17,48 +17,18 @@

package org.apache.spark.sql.execution.metric

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable

import org.apache.xbean.asm5._
import org.apache.xbean.asm5.Opcodes._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}

import org.apache.spark.util.{AccumulatorContext, JsonProtocol}

class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._

test("SQLMetric should not box Long") {
val l = SQLMetrics.createMetric(sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
}
val cl = BoxingFinder.getClassReader(f.getClass)
val boxingFinder = new BoxingFinder()
cl.accept(boxingFinder, 0)
assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}")
}

test("Normal accumulator should do boxing") {
// We need this test to make sure BoxingFinder works.
val l = sparkContext.accumulator(0L)
val f = () => { l += 1L }
val cl = BoxingFinder.getClassReader(f.getClass)
val boxingFinder = new BoxingFinder()
cl.accept(boxingFinder, 0)
assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test")
}

/**
* Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics".
*
@@ -323,76 +293,3 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}

}

private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)

/**
* If `method` is null, search all methods of this class recursively to find if they do some boxing.
* If `method` is specified, only search this method of the class to speed up the searching.
*
* This method will skip the methods in `visitedMethods` to avoid potential infinite cycles.
*/
private class BoxingFinder(
method: MethodIdentifier[_] = null,
val boxingInvokes: mutable.Set[String] = mutable.Set.empty,
visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty)
extends ClassVisitor(ASM5) {

private val primitiveBoxingClassName =
Set("java/lang/Long",
"java/lang/Double",
"java/lang/Integer",
"java/lang/Float",
"java/lang/Short",
"java/lang/Character",
"java/lang/Byte",
"java/lang/Boolean")

override def visitMethod(
access: Int, name: String, desc: String, sig: String, exceptions: Array[String]):
MethodVisitor = {
if (method != null && (method.name != name || method.desc != desc)) {
// If method is specified, skip other methods.
return new MethodVisitor(ASM5) {}
}

new MethodVisitor(ASM5) {
override def visitMethodInsn(
op: Int, owner: String, name: String, desc: String, itf: Boolean) {
if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") {
if (primitiveBoxingClassName.contains(owner)) {
// Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l)
boxingInvokes.add(s"$owner.$name")
}
} else {
// scalastyle:off classforname
val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false,
Thread.currentThread.getContextClassLoader)
// scalastyle:on classforname
val m = MethodIdentifier(classOfMethodOwner, name, desc)
if (!visitedMethods.contains(m)) {
// Keep track of visited methods to avoid potential infinite cycles
visitedMethods += m
val cl = BoxingFinder.getClassReader(classOfMethodOwner)
visitedMethods += m
cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0)
}
}
}
}
}
}

private object BoxingFinder {

def getClassReader(cls: Class[_]): ClassReader = {
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
val resourceStream = cls.getResourceAsStream(className)
val baos = new ByteArrayOutputStream(128)
// Copy data over, before delegating to ClassReader -
// else we can run out of open file handles.
Utils.copyStream(resourceStream, baos, true)
new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}

}
YarnAllocatorSuite.scala
@@ -111,6 +111,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
}

def createContainer(host: String): Container = {
// When YARN 2.6+ is required, avoid deprecation by using version with long second arg
val containerId = ContainerId.newInstance(appAttemptId, containerNum)
containerNum += 1
val nodeId = NodeId.newInstance(host, 1000)
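
A sketch of the non-deprecated replacement referenced in the comment above, usable once Hadoop/YARN 2.6+ is the minimum supported version (assuming YARN's ContainerId.newContainerId(ApplicationAttemptId, long) factory; the deprecated newInstance takes an int second argument):

import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ContainerId}

object ContainerIdSketch {
  // YARN 2.6+ factory method: container ids are longs rather than ints.
  def createContainerId(appAttemptId: ApplicationAttemptId, containerNum: Int): ContainerId =
    ContainerId.newContainerId(appAttemptId, containerNum.toLong)
}
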
