[FLINK-5265] Introduce state handle replication mode for CheckpointCoordinator
StefanRRichter authored and aljoscha committed Jan 13, 2017
1 parent d8d9d76 commit 29fbc49
Showing 18 changed files with 787 additions and 223 deletions.
@@ -30,8 +30,22 @@
public interface OperatorStateStore {

/**
* Creates a state descriptor of the given name that uses Java serialization to persist the
* state.
* Creates (or restores) a list state. Each state is registered under a unique name.
* The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
*
* The items in the list are repartitionable by the system in case of changed operator parallelism.
*
* @param stateDescriptor The descriptor for this state, providing a name and serializer.
* @param <S> The generic type of the state
*
* @return A list for all state partitions.
* @throws Exception
*/
<S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;

/**
* Creates a state of the given name that uses Java serialization to persist the state. The items in the list
* are repartitionable by the system in case of changed operator parallelism.
*
* <p>This is a simple convenience method. For more flexibility on how state serialization
* should happen, use the {@link #getOperatorState(ListStateDescriptor)} method.
@@ -46,13 +60,28 @@ public interface OperatorStateStore {
* Creates (or restores) a list state. Each state is registered under a unique name.
* The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
*
* On restore, all items in the list are broadcast to all parallel operator instances.
*
* @param stateDescriptor The descriptor for this state, providing a name and serializer.
* @param <S> The generic type of the state
*
* @return A list for all state partitions.
* @throws Exception
*/
<S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
<S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;

/**
* Creates a state of the given name that uses Java serialization to persist the state. On restore, all items
* in the list are broadcast to all parallel operator instances.
*
* <p>This is a simple convenience method. For more flexibility on how state serialization
* should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method.
*
* @param stateName The name of the state to create
* @return A list state using Java serialization to serialize state objects.
* @throws Exception
*/
<T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception;

/**
* Returns a set with the names of all currently registered states.
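The two accessors above differ only in redistribution mode: getOperatorState registers state whose list items are split across subtasks when parallelism changes (SPLIT_DISTRIBUTE), while getBroadcastOperatorState registers state whose full list every subtask receives on restore (BROADCAST). Below is a minimal usage sketch against the Flink 1.2-era CheckpointedFunction API; the class, field, and state names are illustrative, and the imports assume the package layout of that release.

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

public class ReplicationModeExample implements CheckpointedFunction {

    // Items are repartitioned (split) across subtasks on rescaling.
    private transient ListState<Long> splitState;

    // Every subtask receives the complete list on restore.
    private transient ListState<Long> broadcastState;

    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        OperatorStateStore store = context.getOperatorStateStore();

        // Registered in SPLIT_DISTRIBUTE mode.
        splitState = store.getOperatorState(
                new ListStateDescriptor<>("buffered-offsets", LongSerializer.INSTANCE));

        // Registered in BROADCAST mode.
        broadcastState = store.getBroadcastOperatorState(
                new ListStateDescriptor<>("replicated-config", LongSerializer.INSTANCE));
    }

    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        // Both lists are written as part of the operator's snapshot;
        // the distribution mode only matters on restore.
    }
}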
@@ -26,6 +26,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -47,8 +48,7 @@ public List<Collection<OperatorStateHandle>> repartitionState(
Preconditions.checkArgument(parallelism > 0);

// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState =
groupByStateName(previousParallelSubtaskStates);
GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates);

if (OPTIMIZE_MEMORY_USE) {
previousParallelSubtaskStates.clear(); // free for GC, at the cost that old handles are no longer available
@@ -59,7 +59,7 @@ public List<Collection<OperatorStateHandle>> repartitionState(

// Do the actual repartitioning for all named states
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
repartition(nameToState, parallelism);
repartition(nameToStateByMode, parallelism);

for (int i = 0; i < mergeMapList.size(); ++i) {
result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
@@ -71,16 +71,33 @@ public List<Collection<OperatorStateHandle>> repartitionState(
/**
* Group by the different named states.
*/
private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName(
@SuppressWarnings({"unchecked", "rawtypes"})
private GroupByStateNameResults groupByStateName(
List<OperatorStateHandle> previousParallelSubtaskStates) {

//Reorganize: group by (State Name -> StreamStateHandle + Offsets)
Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = new HashMap<>();
//Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
new EnumMap<>(OperatorStateHandle.Mode.class);

for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
nameToStateByMode.put(
mode,
new HashMap<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>());
}

for (OperatorStateHandle psh : previousParallelSubtaskStates) {

for (Map.Entry<String, long[]> e : psh.getStateNameToPartitionOffsets().entrySet()) {
for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
psh.getStateNameToPartitionOffsets().entrySet()) {
OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();

Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
nameToStateByMode.get(metaInfo.getDistributionMode());

List<Tuple2<StreamStateHandle, long[]>> stateLocations = nameToState.get(e.getKey());
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
nameToState.get(e.getKey());

if (stateLocations == null) {
stateLocations = new ArrayList<>();
@@ -90,32 +107,40 @@ private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName(
stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
}
}
return nameToState;

return new GroupByStateNameResults(nameToStateByMode);
}

/**
* Repartition all named states.
*/
private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState, int parallelism) {
GroupByStateNameResults nameToStateByMode,
int parallelism) {

// We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);

// Initialize
for (int i = 0; i < parallelism; ++i) {
mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
}

int startParallelOP = 0;
// Start with the state handles we distribute round robin by splitting by offsets
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState =
nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);

int startParallelOp = 0;
// Iterate all named states and repartition one named state at a time per iteration
for (Map.Entry<String, List<Tuple2<StreamStateHandle, long[]>>> e : nameToState.entrySet()) {
for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
distributeNameToState.entrySet()) {

List<Tuple2<StreamStateHandle, long[]>> current = e.getValue();
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();

// Determine actual number of partitions for this named state
int totalPartitions = 0;
for (Tuple2<StreamStateHandle, long[]> offsets : current) {
totalPartitions += offsets.f1.length;
for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
totalPartitions += offsets.f1.getOffsets().length;
}

// Repartition the state across the parallel operator instances
@@ -124,12 +149,12 @@ private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
int baseFraction = totalPartitions / parallelism;
int remainder = totalPartitions % parallelism;

int newStartParallelOp = startParallelOP;
int newStartParallelOp = startParallelOp;

for (int i = 0; i < parallelism; ++i) {

// Preparation: calculate the actual index considering wrap around
int parallelOpIdx = (i + startParallelOP) % parallelism;
int parallelOpIdx = (i + startParallelOp) % parallelism;

// Now calculate the number of partitions we will assign to the parallel instance in this round ...
int numberOfPartitionsToAssign = baseFraction;
@@ -146,11 +171,14 @@ private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
}

// Now start collecting the partitions for the parallel instance into this list
List<Tuple2<StreamStateHandle, long[]>> parallelOperatorState = new ArrayList<>();
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> parallelOperatorState =
new ArrayList<>();

while (numberOfPartitionsToAssign > 0) {
Tuple2<StreamStateHandle, long[]> handleWithOffsets = current.get(lstIdx);
long[] offsets = handleWithOffsets.f1;
Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
current.get(lstIdx);

long[] offsets = handleWithOffsets.f1.getOffsets();
int remaining = offsets.length - offsetIdx;
// Repartition offsets
long[] offs;
@@ -166,25 +194,74 @@ private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
++lstIdx;
}

parallelOperatorState.add(
new Tuple2<>(handleWithOffsets.f0, offs));
parallelOperatorState.add(new Tuple2<>(
handleWithOffsets.f0,
new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)));

numberOfPartitionsToAssign -= remaining;

// As a last step we merge partitions that use the same StreamStateHandle into a single
// OperatorStateHandle
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
if (psh == null) {
psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0);
mergeMap.put(handleWithOffsets.f0, psh);
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(
new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
handleWithOffsets.f0);

mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
}
psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);
operatorStateHandle.getStateNameToPartitionOffsets().put(
e.getKey(),
new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
}
}
startParallelOP = newStartParallelOp;
startParallelOp = newStartParallelOp;
e.setValue(null);
}

// Now we also add the state handles marked for broadcast to all parallel instances
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState =
nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);

for (int i = 0; i < parallelism; ++i) {

Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);

for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
broadcastNameToState.entrySet()) {

List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();

for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) {
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(
new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
handleWithMetaInfo.f0);

mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
}
operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
}
}
}
return mergeMapList;
}

private static final class GroupByStateNameResults {
private final EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode;

public GroupByStateNameResults(
EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) {
this.byMode = Preconditions.checkNotNull(byMode);
}

public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode(
OperatorStateHandle.Mode mode) {
return byMode.get(mode);
}
}
}
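The split-distribute loop above implements a simple counting scheme: each named state contributes totalPartitions partitions, every subtask gets totalPartitions / parallelism of them, the first totalPartitions % parallelism subtasks get one extra, and the start index rotates between named states (with wrap-around) so the remainder does not always land on subtask 0. Below is a standalone sketch of just that arithmetic, not Flink code; assignCounts is a hypothetical helper.

import java.util.Arrays;

public final class RoundRobinSplitSketch {

    // Returns how many partitions each of the `parallelism` subtasks receives,
    // starting the round-robin assignment at subtask `startOp`.
    static int[] assignCounts(int totalPartitions, int parallelism, int startOp) {
        int baseFraction = totalPartitions / parallelism;
        int remainder = totalPartitions % parallelism;

        int[] counts = new int[parallelism];
        for (int i = 0; i < parallelism; ++i) {
            // Wrap-around index, like parallelOpIdx in the repartitioner above.
            int parallelOpIdx = (i + startOp) % parallelism;
            counts[parallelOpIdx] = baseFraction + (i < remainder ? 1 : 0);
        }
        return counts;
    }

    public static void main(String[] args) {
        // 7 partitions over 3 subtasks, starting at subtask 1 -> [2, 3, 2]:
        // subtask 1 absorbs the remainder because the round starts there.
        System.out.println(Arrays.toString(assignCounts(7, 3, 1)));
    }
}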
@@ -338,9 +338,22 @@ private static List<Collection<OperatorStateHandle>> applyRepartitioner(
chainOpParallelStates,
newParallelism);
} else {

List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {

Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
operatorStateHandle.getStateNameToPartitionOffsets();

for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {

// if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning
if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
return opStateRepartitioner.repartitionState(
chainOpParallelStates,
newParallelism);
}
}

repackStream.add(Collections.singletonList(operatorStateHandle));
}
return repackStream;
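In other words, the parallelism-unchanged shortcut is only sound when every registered state uses SPLIT_DISTRIBUTE; a single BROADCAST entry bails out into full repartitioning so that each subtask ends up with the complete list. A standalone paraphrase of that test (hypothetical Handle and Mode types, not the Flink classes):

import java.util.List;
import java.util.Map;

final class RepartitionShortcutSketch {

    enum Mode { SPLIT_DISTRIBUTE, BROADCAST }

    // Stand-in for OperatorStateHandle#getStateNameToPartitionOffsets().
    interface Handle {
        Map<String, Mode> stateNameToMode();
    }

    // True only if no named state is registered in BROADCAST mode,
    // i.e. each subtask's handle can be repacked one-to-one.
    static boolean canRepackDirectly(List<? extends Handle> handles) {
        for (Handle handle : handles) {
            for (Mode mode : handle.stateNameToMode().values()) {
                if (mode == Mode.BROADCAST) {
                    return false; // broadcast state forces full repartitioning
                }
            }
        }
        return true;
    }
}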
@@ -250,11 +250,18 @@ private static void serializeOperatorStateHandle(

if (stateHandle != null) {
dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE);
Map<String, long[]> partitionOffsetsMap = stateHandle.getStateNameToPartitionOffsets();
Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsetsMap =
stateHandle.getStateNameToPartitionOffsets();
dos.writeInt(partitionOffsetsMap.size());
for (Map.Entry<String, long[]> entry : partitionOffsetsMap.entrySet()) {
for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : partitionOffsetsMap.entrySet()) {
dos.writeUTF(entry.getKey());
long[] offsets = entry.getValue();

OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue();

int mode = stateMetaInfo.getDistributionMode().ordinal();
dos.writeByte(mode);

long[] offsets = stateMetaInfo.getOffsets();
dos.writeInt(offsets.length);
for (long offset : offsets) {
dos.writeLong(offset);
@@ -274,14 +281,21 @@ private static OperatorStateHandle deserializeOperatorStateHandle(
return null;
} else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) {
int mapSize = dis.readInt();
Map<String, long[]> offsetsMap = new HashMap<>(mapSize);
Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(mapSize);
for (int i = 0; i < mapSize; ++i) {
String key = dis.readUTF();

int modeOrdinal = dis.readByte();
OperatorStateHandle.Mode mode = OperatorStateHandle.Mode.values()[modeOrdinal];

long[] offsets = new long[dis.readInt()];
for (int j = 0; j < offsets.length; ++j) {
offsets[j] = dis.readLong();
}
offsetsMap.put(key, offsets);

OperatorStateHandle.StateMetaInfo metaInfo =
new OperatorStateHandle.StateMetaInfo(offsets, mode);
offsetsMap.put(key, metaInfo);
}
StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
return new OperatorStateHandle(offsetsMap, stateHandle);
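Per named state, the serializer above writes: the UTF-encoded state name, one byte carrying the distribution-mode ordinal, an int partition count, and then one long offset per partition. A self-contained round-trip sketch of that entry layout follows (illustrative demo class with its own Mode enum, not the Flink serializer itself):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public final class StateMetaInfoWireFormatDemo {

    enum Mode { SPLIT_DISTRIBUTE, BROADCAST }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();

        try (DataOutputStream dos = new DataOutputStream(buffer)) {
            dos.writeUTF("buffered-elements");               // state name
            dos.writeByte(Mode.SPLIT_DISTRIBUTE.ordinal());  // distribution mode
            long[] offsets = {0L, 1024L, 2048L};
            dos.writeInt(offsets.length);                    // partition count
            for (long offset : offsets) {
                dos.writeLong(offset);                       // partition offsets
            }
        }

        try (DataInputStream dis = new DataInputStream(
                new ByteArrayInputStream(buffer.toByteArray()))) {
            String name = dis.readUTF();
            Mode mode = Mode.values()[dis.readByte()];       // ordinal -> enum, as in the reader
            long[] offsets = new long[dis.readInt()];
            for (int i = 0; i < offsets.length; ++i) {
                offsets[i] = dis.readLong();
            }
            System.out.println(name + " " + mode + " " + Arrays.toString(offsets));
        }
    }
}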