Skip to content

Commit

Permalink
[FLINK-13094][state-processor-api] Add registered*TimeTimers methods …
Browse files Browse the repository at this point in the history
…to KeyedStateReaderFunction#Context for querying the registered timers for a given key

This closes apache#9094.
  • Loading branch information
sjwiesman authored and tzulitai committed Jul 22, 2019
1 parent 93f1bcc commit f2300e0
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

import java.util.Set;

/**
* A function that processes keys from a restored operator
*
Expand Down Expand Up @@ -79,6 +81,16 @@ public abstract class KeyedStateReaderFunction<K, OUT> extends AbstractRichFunct
* afterwards!
*/
public interface Context {

/**
* Returns the event time timers registered for the key currently being read.
*
* @return All registered event time timers for the current key.
* @throws Exception if the implementation fails to retrieve the timers.
*/
Set<Long> registeredEventTimeTimers() throws Exception;

/**
* Returns the processing time timers registered for the key currently being read.
*
* @return All registered processing time timers for the current key.
* @throws Exception if the implementation fails to retrieve the timers.
*/
Set<Long> registeredProcessingTimeTimers() throws Exception;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
import org.apache.flink.api.common.io.DefaultInputSplitAssigner;
import org.apache.flink.api.common.io.RichInputFormat;
import org.apache.flink.api.common.io.statistics.BaseStatistics;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.core.io.InputSplitAssigner;
Expand All @@ -37,24 +41,34 @@
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit;
import org.apache.flink.state.api.runtime.NeverFireProcessingTimeService;
import org.apache.flink.state.api.runtime.SavepointEnvironment;
import org.apache.flink.state.api.runtime.SavepointRuntimeContext;
import org.apache.flink.state.api.runtime.VoidTriggerable;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
import org.apache.flink.streaming.api.operators.TimerSerializer;
import org.apache.flink.util.CollectionUtil;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nonnull;

import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* Input format for reading partitioned state.
Expand All @@ -67,6 +81,8 @@ public class KeyedStateInputFormat<K, OUT> extends RichInputFormat<OUT, KeyGroup

private static final long serialVersionUID = 8230460226049597182L;

private static final String USER_TIMERS_NAME = "user-timers";

private final OperatorState operatorState;

private final StateBackend stateBackend;
Expand Down Expand Up @@ -167,7 +183,20 @@ public void open(KeyGroupRangeInputSplit split) throws IOException {
FunctionUtils.setFunctionRuntimeContext(userFunction, ctx);

keys = getKeyIterator(ctx);
this.ctx = new Context();

final InternalTimerService<VoidNamespace> timerService = restoreTimerService(context);
try {
this.ctx = new Context(keyedStateBackend, timerService);
} catch (Exception e) {
throw new IOException("Failed to restore timer state", e);
}
}

/**
 * Obtains the timer service registered under {@code USER_TIMERS_NAME} from the restored
 * operator state context.
 *
 * <p>The service is requested with a {@link TimerSerializer} built from this format's key
 * serializer and the {@link VoidNamespaceSerializer}, and with {@code VoidTriggerable.instance()}
 * as its triggerable.
 *
 * @param context the restored operator state context that provides the timer service manager.
 * @return the timer service for timers in the {@link VoidNamespace}.
 */
@SuppressWarnings("unchecked")
private InternalTimerService<VoidNamespace> restoreTimerService(StreamOperatorStateContext context) {
    // Unchecked cast: the manager is narrowed to this format's key type K.
    final InternalTimeServiceManager<K> manager =
        (InternalTimeServiceManager<K>) context.internalTimerServiceManager();

    final TimerSerializer<K, VoidNamespace> serializer =
        new TimerSerializer<>(keySerializer, VoidNamespaceSerializer.INSTANCE);

    return manager.getInternalTimerService(USER_TIMERS_NAME, serializer, VoidTriggerable.instance());
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -266,5 +295,64 @@ private static List<KeyGroupRange> sortedKeyGroupRanges(int minNumSplits, int ma
return keyGroups;
}

private static class Context implements KeyedStateReaderFunction.Context {}
/**
 * {@link KeyedStateReaderFunction.Context} implementation backed by two {@link ListState}s
 * that are filled once, at construction time, from the restored timer service.
 *
 * <p>The constructor walks every event time and processing time timer of the timer service and
 * appends the timestamps of those in the {@link VoidNamespace} to the matching list state. The
 * accessor methods then only read those list states.
 *
 * @param <K> the key type of the keyed state backend.
 */
private static class Context<K> implements KeyedStateReaderFunction.Context {

    private static final String EVENT_TIMER_STATE = "event-time-timers";

    private static final String PROC_TIMER_STATE = "proc-time-timers";

    /** Event time timer timestamps copied out of the timer service. */
    private final ListState<Long> eventTimers;

    /** Processing time timer timestamps copied out of the timer service. */
    private final ListState<Long> procTimers;

    /**
     * Copies all {@link VoidNamespace} timers of the given timer service into list state.
     *
     * @param keyedStateBackend the backend from which the two list states are obtained.
     * @param timerService the restored timer service whose timers are copied.
     * @throws Exception if the list states cannot be obtained or a timer cannot be added.
     */
    private Context(AbstractKeyedStateBackend<K> keyedStateBackend, InternalTimerService<VoidNamespace> timerService) throws Exception {
        final ListState<Long> events = keyedStateBackend.getPartitionedState(
            USER_TIMERS_NAME,
            StringSerializer.INSTANCE,
            new ListStateDescriptor<>(EVENT_TIMER_STATE, Types.LONG)
        );

        timerService.forEachEventTimeTimer((namespace, timer) -> {
            if (namespace.equals(VoidNamespace.INSTANCE)) {
                events.add(timer);
            }
        });

        final ListState<Long> procs = keyedStateBackend.getPartitionedState(
            USER_TIMERS_NAME,
            StringSerializer.INSTANCE,
            new ListStateDescriptor<>(PROC_TIMER_STATE, Types.LONG)
        );

        timerService.forEachProcessingTimeTimer((namespace, timer) -> {
            if (namespace.equals(VoidNamespace.INSTANCE)) {
                procs.add(timer);
            }
        });

        this.eventTimers = events;
        this.procTimers = procs;
    }

    /**
     * @return All registered event time timers for the current key; empty if none were stored.
     * @throws Exception if the underlying list state cannot be read.
     */
    @Override
    public Set<Long> registeredEventTimeTimers() throws Exception {
        return readTimers(eventTimers);
    }

    /**
     * @return All registered processing time timers for the current key; empty if none were stored.
     * @throws Exception if the underlying list state cannot be read.
     */
    @Override
    public Set<Long> registeredProcessingTimeTimers() throws Exception {
        return readTimers(procTimers);
    }

    /**
     * Reads the given list state into a set, mapping a {@code null} state value to an empty set.
     */
    private static Set<Long> readTimers(ListState<Long> timerState) throws Exception {
        final Iterable<Long> timers = timerState.get();
        if (timers == null) {
            return Collections.emptySet();
        }

        return StreamSupport
            .stream(timers.spliterator(), false)
            .collect(Collectors.toSet());
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@

import java.time.Duration;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -77,7 +80,7 @@ private void runKeyedState(StateBackend backend) throws Exception {
streamEnv
.addSource(new SavepointSource())
.rebalance()
.keyBy(id -> id)
.keyBy(id -> id.key)
.process(new KeyedStatefulOperator())
.uid(uid)
.addSink(new DiscardingSink<>());
Expand All @@ -89,15 +92,13 @@ private void runKeyedState(StateBackend backend) throws Exception {
ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment();
ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, backend);

List<Integer> results = savepoint
List<Pojo> results = savepoint
.readKeyedState(uid, new Reader())
.collect();

results.sort(Comparator.naturalOrder());
Set<Pojo> expected = SavepointSource.getElements();

List<Integer> expected = SavepointSource.getElements();

Assert.assertEquals("Unexpected results from keyed state", expected, results);
Assert.assertEquals("Unexpected results from keyed state", expected, new HashSet<>(results));
}

private String takeSavepoint(JobGraph jobGraph) throws Exception {
Expand Down Expand Up @@ -136,17 +137,20 @@ private String takeSavepoint(JobGraph jobGraph) throws Exception {
}
}

private static class SavepointSource implements SourceFunction<Integer> {
private static class SavepointSource implements SourceFunction<Pojo> {
private static volatile boolean finished;

private volatile boolean running = true;

private static final Integer[] elements = {1, 2, 3};
private static final Pojo[] elements = {
Pojo.of(1, 1),
Pojo.of(2, 2),
Pojo.of(3, 3)};

@Override
public void run(SourceContext<Integer> ctx) {
public void run(SourceContext<Pojo> ctx) {
synchronized (ctx.getCheckpointLock()) {
for (Integer element : elements) {
for (Pojo element : elements) {
ctx.collect(element);
}

Expand Down Expand Up @@ -175,12 +179,12 @@ private static boolean isFinished() {
return finished;
}

private static List<Integer> getElements() {
return Arrays.asList(elements);
private static Set<Pojo> getElements() {
return new HashSet<>(Arrays.asList(elements));
}
}

private static class KeyedStatefulOperator extends KeyedProcessFunction<Integer, Integer, Void> {
private static class KeyedStatefulOperator extends KeyedProcessFunction<Integer, Pojo, Void> {

private transient ValueState<Integer> state;

Expand All @@ -190,12 +194,15 @@ public void open(Configuration parameters) {
}

@Override
public void processElement(Integer value, Context ctx, Collector<Void> out) throws Exception {
state.update(value);
public void processElement(Pojo value, Context ctx, Collector<Void> out) throws Exception {
state.update(value.state);

value.eventTimeTimer.forEach(timer -> ctx.timerService().registerEventTimeTimer(timer));
value.processingTimeTimer.forEach(timer -> ctx.timerService().registerProcessingTimeTimer(timer));
}
}

private static class Reader extends KeyedStateReaderFunction<Integer, Integer> {
private static class Reader extends KeyedStateReaderFunction<Integer, Pojo> {

private transient ValueState<Integer> state;

Expand All @@ -205,8 +212,56 @@ public void open(Configuration parameters) {
}

@Override
public void readKey(Integer key, Context ctx, Collector<Integer> out) throws Exception {
out.collect(state.value());
public void readKey(Integer key, Context ctx, Collector<Pojo> out) throws Exception {
Pojo pojo = new Pojo();
pojo.key = key;
pojo.state = state.value();
pojo.eventTimeTimer = ctx.registeredEventTimeTimers();
pojo.processingTimeTimer = ctx.registeredProcessingTimeTimers();

out.collect(pojo);
}
}

/**
 * A simple pojo type holding a key, the state value stored for that key and the timers
 * registered for it.
 *
 * <p>Instances are compared field by field, and {@link #toString()} prints every field so that
 * a failed equality assertion shows what actually differed.
 */
public static class Pojo {

    Integer key;

    Integer state;

    Set<Long> eventTimeTimer;

    Set<Long> processingTimeTimer;

    /**
     * Creates a pojo with the given key and state, one event time timer at
     * {@code Long.MAX_VALUE - 1} and one processing time timer at {@code Long.MAX_VALUE - 2}.
     *
     * @param key the key of the element.
     * @param state the state value of the element.
     * @return the populated pojo.
     */
    public static Pojo of(Integer key, Integer state) {
        Pojo wrapper = new Pojo();
        wrapper.key = key;
        wrapper.state = state;
        // NOTE(review): the timestamps are presumably far in the future so the timers
        // never fire before the savepoint is taken — confirm.
        wrapper.eventTimeTimer = Collections.singleton(Long.MAX_VALUE - 1);
        wrapper.processingTimeTimer = Collections.singleton(Long.MAX_VALUE - 2);

        return wrapper;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        } else if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Pojo pojo = (Pojo) o;
        return Objects.equals(key, pojo.key) &&
            Objects.equals(state, pojo.state) &&
            Objects.equals(eventTimeTimer, pojo.eventTimeTimer) &&
            Objects.equals(processingTimeTimer, pojo.processingTimeTimer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(key, state, eventTimeTimer, processingTimeTimer);
    }

    @Override
    public String toString() {
        return "Pojo{" +
            "key=" + key +
            ", state=" + state +
            ", eventTimeTimer=" + eventTimeTimer +
            ", processingTimeTimer=" + processingTimeTimer +
            '}';
    }
}
}
Loading

0 comments on commit f2300e0

Please sign in to comment.