Skip to content

Commit

Permalink
[FLINK-2673] [core] Add a comparator for Scala Option type
Browse files Browse the repository at this point in the history
This closes apache#2017.
  • Loading branch information
chiwanpark committed May 31, 2016
1 parent da23ee3 commit c60326f
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.typeutils

import org.apache.flink.annotation.Internal
import org.apache.flink.api.common.typeutils.TypeComparator
import org.apache.flink.core.memory.{DataInputView, DataOutputView, MemorySegment}

/**
* Comparator for [[Option]] values. Note that [[None]] is lesser than any [[Some]] values.
*/
@Internal
class OptionTypeComparator[A](
private val ascending: Boolean,
private val typeComparator: TypeComparator[A]
) extends TypeComparator[Option[A]] {
private var reference: Option[A] = _

override def hash(record: Option[A]) = record.hashCode()

override def compare(first: Option[A], second: Option[A]) = {
first match {
case Some(firstValue: A) =>
second match {
case Some(secondValue: A) => typeComparator.compare(firstValue, secondValue)
case None =>
if (ascending) {
1
} else {
-1
}
}
case None =>
second match {
case Some(secondValue) =>
if (ascending) {
-1
} else {
1
}
case None => 0
}
}
}

override def compareSerialized(firstSource: DataInputView, secondSource: DataInputView) = {
val firstSome = firstSource.readBoolean()
val secondSome = secondSource.readBoolean()

if (firstSome) {
if (secondSome) {
typeComparator.compareSerialized(firstSource, secondSource)
} else {
if (ascending) {
1
} else {
-1
}
}
} else {
if (secondSome) {
if (ascending) {
-1
} else {
1
}
} else {
0
}
}
}

override def extractKeys(record: AnyRef, target: Array[AnyRef], index: Int) = {
target(index) = record
1
}

override def setReference(toCompare: Option[A]) = {
reference = toCompare
}

override def equalToReference(candidate: Option[A]) = {
compare(reference, candidate) == 0
}

override def compareToReference(referencedComparator: TypeComparator[Option[A]]) = {
compare(referencedComparator.asInstanceOf[this.type].reference, reference)
}

override lazy val getFlatComparators = {
Array(this).asInstanceOf[Array[TypeComparator[_]]]
}

override def getNormalizeKeyLen = 1 + typeComparator.getNormalizeKeyLen

override def putNormalizedKey(
record: Option[A],
target: MemorySegment,
offset: Int,
numBytes: Int
) = {
if (numBytes >= 1) {
record match {
case Some(v: A) =>
target.put(offset, OptionTypeComparator.OneInByte)
typeComparator.putNormalizedKey(v, target, offset + 1, numBytes - 1)
case None =>
target.put(offset, OptionTypeComparator.ZeroInByte)
var i = 1
while (i < numBytes) {
target.put(offset + i, OptionTypeComparator.ZeroInByte)
i += 1
}
}
}
}

override def invertNormalizedKey() = !ascending

override def readWithKeyDenormalization(reuse: Option[A], source: DataInputView) = {
throw new UnsupportedOperationException
}

override def writeWithKeyNormalization(record: Option[A], target: DataOutputView) = {
throw new UnsupportedOperationException
}

override def isNormalizedKeyPrefixOnly(keyBytes: Int) = {
typeComparator.isNormalizedKeyPrefixOnly(keyBytes - 1)
}

override def supportsSerializationWithKeyNormalization() = false

override def supportsNormalizedKey() = typeComparator.supportsNormalizedKey()

override def duplicate() = new OptionTypeComparator[A](ascending, typeComparator)
}

object OptionTypeComparator {
val ZeroInByte = 0.asInstanceOf[Byte]
val OneInByte = 1.asInstanceOf[Byte]
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
*/
package org.apache.flink.api.scala.typeutils

import org.apache.flink.annotation.{PublicEvolving, Public}
import org.apache.flink.annotation.{Public, PublicEvolving}
import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}

import scala.collection.JavaConverters._

Expand All @@ -29,14 +29,14 @@ import scala.collection.JavaConverters._
*/
@Public
class OptionTypeInfo[A, T <: Option[A]](private val elemTypeInfo: TypeInformation[A])
extends TypeInformation[T] {
extends TypeInformation[T] with AtomicType[T] {

@PublicEvolving
override def isBasicType: Boolean = false
@PublicEvolving
override def isTupleType: Boolean = false
@PublicEvolving
override def isKeyType: Boolean = false
override def isKeyType: Boolean = elemTypeInfo.isKeyType
@PublicEvolving
override def getTotalFields: Int = 1
@PublicEvolving
Expand All @@ -46,6 +46,16 @@ class OptionTypeInfo[A, T <: Option[A]](private val elemTypeInfo: TypeInformatio
@PublicEvolving
override def getGenericParameters = List[TypeInformation[_]](elemTypeInfo).asJava

@PublicEvolving
override def createComparator(ascending: Boolean, executionConfig: ExecutionConfig) = {
if (isKeyType) {
val elemCompartor = elemTypeInfo.asInstanceOf[AtomicType[A]]
.createComparator(ascending, executionConfig)
new OptionTypeComparator[A](ascending, elemCompartor).asInstanceOf[TypeComparator[T]]
} else {
throw new UnsupportedOperationException("Element type that doesn't support ")
}
}

@PublicEvolving
def createSerializer(executionConfig: ExecutionConfig): TypeSerializer[T] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.typeutils

import org.apache.flink.api.common.typeutils.ComparatorTestBase
import org.apache.flink.api.common.typeutils.base.{StringComparator, StringSerializer}

class OptionTypeComparatorTest extends ComparatorTestBase[Option[String]] {
override protected def createComparator(ascending: Boolean) = {
new OptionTypeComparator[String](ascending, new StringComparator(ascending))
}

override protected def createSerializer() = new OptionSerializer[String](new StringSerializer)

override protected def getSortedTestData = Array(None, Some("a"), Some("b"), Some("c"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,15 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
env.execute()
expected = "1,(1,1,Hi)\n2,(2,2,Hello)"
}

@Test
def testWithScalaOptionValues(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(None, Some("a"), Some("b"))
val ds2 = env.fromElements(None, Some("a"))
val joinDs = ds1.join(ds2).where("_").equalTo("_")
joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expected = "None,None\nSome(a),Some(a)"
}
}

0 comments on commit c60326f

Please sign in to comment.