Skip to content

Commit

Permalink
[FLINK-5874] Restrict key types in the DataStream API.
Browse files Browse the repository at this point in the history
Reject a type from being a key in keyBy() if it is:
1. it is a POJO type but does not override the hashCode() and
   relies on the Object.hashCode() implementation.
2. it is an array of any type.
  • Loading branch information
kl0u committed Mar 10, 2017
1 parent 70e78a6 commit f15a7d2
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 7 deletions.
9 changes: 9 additions & 0 deletions docs/dev/datastream_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ dataStream.filter(new FilterFunction<Integer>() {
dataStream.keyBy("someKey") // Key by field "someKey"
dataStream.keyBy(0) // Key by the first element of a Tuple
{% endhighlight %}
<p>
<span class="label label-danger">Attention</span>
A type <strong>cannot be a key</strong> if:
<ol>
<li> it is a POJO type but does not override the <em>hashCode()</em> method and
relies on the <em>Object.hashCode()</em> implementation.</li>
<li> it is an array of any type.</li>
</ol>
</p>
</td>
</tr>
<tr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ public KeyedStream<T, Tuple> keyBy(int... fields) {
}

/**
* Partitions the operator state of a {@link DataStream}using field expressions.
* Partitions the operator state of a {@link DataStream} using field expressions.
* A field expression is either the name of a public field or a getter method with parentheses
* of the {@link DataStream}S underlying type. A dot can be used to drill
* of the {@link DataStream}'s underlying type. A dot can be used to drill
* down into objects, as in {@code "field1.getInnerField2()" }.
*
* @param fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,25 @@

package org.apache.flink.streaming.api.datastream;

import org.apache.commons.lang3.StringUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.Public;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.FoldFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.functions.ProcessFunction;
Expand Down Expand Up @@ -61,6 +68,9 @@
import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;

import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import java.util.UUID;

/**
Expand Down Expand Up @@ -114,9 +124,72 @@ public KeyedStream(DataStream<T> dataStream, KeySelector<T, KEY> keySelector, Ty
dataStream.getTransformation(),
new KeyGroupStreamPartitioner<>(keySelector, StreamGraphGenerator.DEFAULT_LOWER_BOUND_MAX_PARALLELISM)));
this.keySelector = keySelector;
this.keyType = keyType;
this.keyType = validateKeyType(keyType);
}


/**
* Validates that a given type of element (as encoded by the provided {@link TypeInformation}) can be
* used as a key in the {@code DataStream.keyBy()} operation. This is done by searching depth-first the
* key type and checking if each of the composite types satisfies the required conditions
* (see {@link #validateKeyTypeIsHashable(TypeInformation)}).
*
* @param keyType The {@link TypeInformation} of the key.
*/
private TypeInformation<KEY> validateKeyType(TypeInformation<KEY> keyType) {
Stack<TypeInformation<?>> stack = new Stack<>();
stack.push(keyType);

List<TypeInformation<?>> unsupportedTypes = new ArrayList<>();

while (!stack.isEmpty()) {
TypeInformation<?> typeInfo = stack.pop();

if (!validateKeyTypeIsHashable(typeInfo)) {
unsupportedTypes.add(typeInfo);
}

if (typeInfo instanceof TupleTypeInfoBase) {
for (int i = 0; i < typeInfo.getArity(); i++) {
stack.push(((TupleTypeInfoBase) typeInfo).getTypeAt(i));
}
}
}

if (!unsupportedTypes.isEmpty()) {
throw new InvalidProgramException("Type " + keyType + " cannot be used as key. Contained " +
"UNSUPPORTED key types: " + StringUtils.join(unsupportedTypes, ", ") + ". Look " +
"at the keyBy() documentation for the conditions a type has to satisfy in order to be " +
"eligible for a key.");
}

return keyType;
}

/**
* Validates that a given type of element (as encoded by the provided {@link TypeInformation}) can be
* used as a key in the {@code DataStream.keyBy()} operation.
*
* @param type The {@link TypeInformation} of the type to check.
* @return {@code false} if:
* <ol>
* <li>it is a POJO type but does not override the {@link #hashCode()} method and relies on
* the {@link Object#hashCode()} implementation.</li>
* <li>it is an array of any type (see {@link PrimitiveArrayTypeInfo}, {@link BasicArrayTypeInfo},
* {@link ObjectArrayTypeInfo}).</li>
* </ol>,
* {@code true} otherwise.
*/
private boolean validateKeyTypeIsHashable(TypeInformation<?> type) {
try {
return (type instanceof PojoTypeInfo)
? !type.getTypeClass().getMethod("hashCode").getDeclaringClass().equals(Object.class)
: !(type instanceof PrimitiveArrayTypeInfo || type instanceof BasicArrayTypeInfo || type instanceof ObjectArrayTypeInfo);
} catch (NoSuchMethodException ignored) {
// this should never happen as we are just searching for the hashCode() method.
}
return false;
}

// ------------------------------------------------------------------------
// properties
// ------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ private Collection<Integer> transform(StreamTransformation<?> transform) {

Collection<Integer> transformedIds;
if (transform instanceof OneInputTransformation<?, ?>) {
transformedIds = transformOnInputTransform((OneInputTransformation<?, ?>) transform);
transformedIds = transformOneInputTransform((OneInputTransformation<?, ?>) transform);
} else if (transform instanceof TwoInputTransformation<?, ?, ?>) {
transformedIds = transformTwoInputTransform((TwoInputTransformation<?, ?, ?>) transform);
} else if (transform instanceof SourceTransformation<?>) {
Expand Down Expand Up @@ -496,10 +496,10 @@ private <T> Collection<Integer> transformSink(SinkTransformation<T> sink) {
* Transforms a {@code OneInputTransformation}.
*
* <p>
* This recusively transforms the inputs, creates a new {@code StreamNode} in the graph and
* This recursively transforms the inputs, creates a new {@code StreamNode} in the graph and
* wired the inputs to this new node.
*/
private <IN, OUT> Collection<Integer> transformOnInputTransform(OneInputTransformation<IN, OUT> transform) {
private <IN, OUT> Collection<Integer> transformOneInputTransform(OneInputTransformation<IN, OUT> transform) {

Collection<Integer> inputIds = transform(transform.getInput());

Expand Down
Loading

0 comments on commit f15a7d2

Please sign in to comment.