Commit 956ffa6
[FLINK-5163] Port the MessageAcknowledgingSourceBase to the new state abstractions.
kl0u authored and aljoscha committed Dec 13, 2016
1 parent d24833d commit 956ffa6
Showing 4 changed files with 120 additions and 40 deletions.
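At a glance, the diffs below move the source from the legacy Checkpointed interface to the new CheckpointedFunction abstraction. The method signatures involved, taken from the changes that follow (collected here only as a reading aid, not as part of the commit):

    // removed (legacy Checkpointed<SerializedCheckpointData[]>):
    public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception
    public void restoreState(SerializedCheckpointData[] state) throws Exception

    // added (CheckpointedFunction):
    public void initializeState(FunctionInitializationContext context) throws Exception
    public void snapshotState(FunctionSnapshotContext context) throws Exception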
16 changes: 16 additions & 0 deletions flink-connectors/flink-connector-rabbitmq/pom.xml
@@ -55,6 +55,22 @@ under the License.
<version>${rabbitmq.version}</version>
</dependency>

<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-streaming-java_2.10</artifactId>
<version>${project.version}</version>
<scope>test</scope>
<type>test-jar</type>
</dependency>

<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-runtime_2.10</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>

</dependencies>

</project>
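A plausible reason for the two new test-jar dependencies (an inference; the commit message does not say): the reworked RMQSourceTest below drives the source through Flink's operator test harness, whose classes ship in these test jars, for example:

    import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
    import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;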
RMQSourceTest.java (flink-connector-rabbitmq)
@@ -23,16 +23,19 @@
import com.rabbitmq.client.Envelope;
import com.rabbitmq.client.QueueingConsumer;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.SerializedCheckpointData;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.serialization.DeserializationSchema;
import org.junit.After;
import org.junit.Before;
@@ -53,6 +56,7 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;


/**
@@ -83,7 +87,13 @@ public class RMQSourceTest {
@Before
public void beforeTest() throws Exception {

OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);

source = new RMQTestSource();
source.initializeState(mockContext);
source.open(config);

messageId = 0;
Expand Down Expand Up @@ -128,6 +138,12 @@ public void throwExceptionIfConnectionFactoryReturnNull() throws Exception {
@Test
public void testCheckpointing() throws Exception {
source.autoAck = false;

StreamSource<String, RMQSource<String>> src = new StreamSource<>(source);
AbstractStreamOperatorTestHarness<String> testHarness =
new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
testHarness.open();

sourceThread.start();

Thread.sleep(5);
@@ -141,10 +157,10 @@ public void testCheckpointing() throws Exception {

for (int i=0; i < numSnapshots; i++) {
long snapshotId = random.nextLong();
SerializedCheckpointData[] data;
OperatorStateHandles data;

synchronized (DummySourceContext.lock) {
data = source.snapshotState(snapshotId, System.currentTimeMillis());
data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
previousSnapshotId = lastSnapshotId;
lastSnapshotId = messageId;
}
@@ -153,15 +169,25 @@

// check if the correct number of messages have been snapshotted
final long numIds = lastSnapshotId - previousSnapshotId;
assertEquals(numIds, data[0].getNumIds());
// deserialize and check if the last id equals the last snapshotted id
ArrayDeque<Tuple2<Long, List<String>>> deque = SerializedCheckpointData.toDeque(data, new StringSerializer());

RMQTestSource sourceCopy = new RMQTestSource();
StreamSource<String, RMQTestSource> srcCopy = new StreamSource<>(sourceCopy);
AbstractStreamOperatorTestHarness<String> testHarnessCopy =
new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);

testHarnessCopy.setup();
testHarnessCopy.initializeState(data);
testHarnessCopy.open();

ArrayDeque<Tuple2<Long, List<String>>> deque = sourceCopy.getRestoredState();
List<String> messageIds = deque.getLast().f1;

assertEquals(numIds, messageIds.size());
if (messageIds.size() > 0) {
assertEquals(lastSnapshotId, (long) Long.valueOf(messageIds.get(messageIds.size() - 1)));
}

// check if the messages are being acknowledged and the transaction comitted
// check if the messages are being acknowledged and the transaction committed
synchronized (DummySourceContext.lock) {
source.notifyCheckpointComplete(snapshotId);
}
@@ -313,12 +339,24 @@ public TypeInformation<String> getProducedType() {

private class RMQTestSource extends RMQSource<String> {

private ArrayDeque<Tuple2<Long, List<String>>> restoredState;

public RMQTestSource() {
super(new RMQConnectionConfig.Builder().setHost("hostTest")
.setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/").build()
, "queueDummy", true, new StringDeserializationScheme());
}

@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
super.initializeState(context);
this.restoredState = this.pendingCheckpoints;
}

public ArrayDeque<Tuple2<Long, List<String>>> getRestoredState() {
return this.restoredState;
}

@Override
public void open(Configuration config) throws Exception {
super.open(config);
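The snapshot-and-restore round trip exercised in testCheckpointing follows a reusable harness pattern. A condensed sketch, assuming a hypothetical MySource that implements SourceFunction<String> and CheckpointedFunction (the constructor arguments 1, 1, 0 are max parallelism, parallelism, and subtask index, as in the test above):

    // inside a @Test method that declares `throws Exception`
    MySource source = new MySource();                        // hypothetical source under test
    StreamSource<String, MySource> operator = new StreamSource<>(source);
    AbstractStreamOperatorTestHarness<String> harness =
            new AbstractStreamOperatorTestHarness<>(operator, 1, 1, 0);
    harness.open();

    // take a snapshot of the operator's state
    OperatorStateHandles snapshot = harness.snapshot(0L, System.currentTimeMillis());

    // bring up a fresh instance and restore it from that snapshot
    MySource restored = new MySource();
    AbstractStreamOperatorTestHarness<String> restoredHarness =
            new AbstractStreamOperatorTestHarness<>(new StreamSource<>(restored), 1, 1, 0);
    restoredHarness.setup();
    restoredHarness.initializeState(snapshot);
    restoredHarness.open();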
MessageAcknowledgingSourceBase.java
@@ -27,14 +27,17 @@

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.SerializedCheckpointData;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -73,14 +76,16 @@
* }
* }
* }</pre>
*
*
* <b>NOTE:</b> This source has a parallelism of {@code 1}.
*
* @param <Type> The type of the messages created by the source.
* @param <UId> The type of unique IDs which may be used to acknowledge elements.
*/
@PublicEvolving
public abstract class MessageAcknowledgingSourceBase<Type, UId>
extends RichSourceFunction<Type>
implements Checkpointed<SerializedCheckpointData[]>, CheckpointListener {
implements CheckpointedFunction, CheckpointListener {

private static final long serialVersionUID = -8689291992192955579L;

@@ -93,7 +98,7 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
private transient List<UId> idsForCurrentCheckpoint;

/** The list with IDs from checkpoints that were triggered, but not yet completed or notified of completion */
private transient ArrayDeque<Tuple2<Long, List<UId>>> pendingCheckpoints;
protected transient ArrayDeque<Tuple2<Long, List<UId>>> pendingCheckpoints;

/**
* Set which contain all processed ids. Ids are acknowledged after checkpoints. When restoring
@@ -102,6 +107,8 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
*/
private transient Set<UId> idsProcessedButNotAcknowledged;

private transient ListState<SerializedCheckpointData[]> checkpointedState;

// ------------------------------------------------------------------------

/**
@@ -123,13 +130,38 @@ protected MessageAcknowledgingSourceBase(TypeInformation<UId> idTypeInfo) {
}

@Override
public void open(Configuration parameters) throws Exception {
idsForCurrentCheckpoint = new ArrayList<>(64);
if (pendingCheckpoints == null) {
pendingCheckpoints = new ArrayDeque<>();
}
if (idsProcessedButNotAcknowledged == null) {
idsProcessedButNotAcknowledged = new HashSet<>();
public void initializeState(FunctionInitializationContext context) throws Exception {
Preconditions.checkState(this.checkpointedState == null,
"The " + getClass().getSimpleName() + " has already been initialized.");

this.checkpointedState = context
.getOperatorStateStore()
.getSerializableListState("message-acknowledging-source-state");

this.idsForCurrentCheckpoint = new ArrayList<>(64);
this.pendingCheckpoints = new ArrayDeque<>();
this.idsProcessedButNotAcknowledged = new HashSet<>();

if (context.isRestored()) {
LOG.info("Restoring state for the {}.", getClass().getSimpleName());

List<SerializedCheckpointData[]> retrievedStates = new ArrayList<>();
for (SerializedCheckpointData[] entry : this.checkpointedState.get()) {
retrievedStates.add(entry);
}

// given that the parallelism of the function is 1, we can only have at most 1 state
Preconditions.checkArgument(retrievedStates.size() == 1,
getClass().getSimpleName() + " retrieved invalid state.");

pendingCheckpoints = SerializedCheckpointData.toDeque(retrievedStates.get(0), idSerializer);
// build a set which contains all processed ids. It may be used to check if we have
// already processed an incoming message.
for (Tuple2<Long, List<UId>> checkpoint : pendingCheckpoints) {
idsProcessedButNotAcknowledged.addAll(checkpoint.f1);
}
} else {
LOG.info("No state to restore for the {}.", getClass().getSimpleName());
}
}

@@ -166,26 +198,20 @@ protected boolean addId(UId uid) {
// ------------------------------------------------------------------------

@Override
public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
LOG.debug("Snapshotting state. Messages: {}, checkpoint id: {}, timestamp: {}",
idsForCurrentCheckpoint, checkpointId, checkpointTimestamp);
public void snapshotState(FunctionSnapshotContext context) throws Exception {
Preconditions.checkState(this.checkpointedState != null,
"The " + getClass().getSimpleName() + " has not been properly initialized.");

pendingCheckpoints.addLast(new Tuple2<>(checkpointId, idsForCurrentCheckpoint));
if (LOG.isDebugEnabled()) {
LOG.debug("{} checkpointing: Messages: {}, checkpoint id: {}, timestamp: {}",
idsForCurrentCheckpoint, context.getCheckpointId(), context.getCheckpointTimestamp());
}

pendingCheckpoints.addLast(new Tuple2<>(context.getCheckpointId(), idsForCurrentCheckpoint));
idsForCurrentCheckpoint = new ArrayList<>(64);

return SerializedCheckpointData.fromDeque(pendingCheckpoints, idSerializer);
}

@Override
public void restoreState(SerializedCheckpointData[] state) throws Exception {
idsProcessedButNotAcknowledged = new HashSet<>();
pendingCheckpoints = SerializedCheckpointData.toDeque(state, idSerializer);
// build a set which contains all processed ids. It may be used to check if we have
// already processed an incoming message.
for (Tuple2<Long, List<UId>> checkpoint : pendingCheckpoints) {
idsProcessedButNotAcknowledged.addAll(checkpoint.f1);
}
this.checkpointedState.clear();
this.checkpointedState.add(SerializedCheckpointData.fromDeque(pendingCheckpoints, idSerializer));
}

@Override
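The operator-state plumbing introduced above generalizes beyond this class. A minimal sketch, assuming a hypothetical PendingIdsFunction (not part of the commit), that mirrors the same three steps: obtain a serializable list state in initializeState, repopulate in-memory structures when isRestored() is true, and rewrite the list state on every snapshotState call:

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.runtime.state.FunctionInitializationContext;
    import org.apache.flink.runtime.state.FunctionSnapshotContext;
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

    // Hypothetical example that keeps a list of string IDs in operator state,
    // using the same calls as MessageAcknowledgingSourceBase above.
    public class PendingIdsFunction implements CheckpointedFunction {

        private transient ListState<String> checkpointedState;
        private transient List<String> pendingIds;

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            // Java-serialization-based operator list state, as in the diff above
            checkpointedState = context
                    .getOperatorStateStore()
                    .getSerializableListState("pending-ids");

            pendingIds = new ArrayList<>();
            if (context.isRestored()) {
                for (String id : checkpointedState.get()) {
                    pendingIds.add(id);
                }
            }
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // rewrite the whole state on every checkpoint, as the base class does
            checkpointedState.clear();
            for (String id : pendingIds) {
                checkpointedState.add(id);
            }
        }
    }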
MultipleIdsMessageAcknowledgingSourceBase.java
@@ -22,7 +22,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.SerializedCheckpointData;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -133,9 +133,9 @@ protected final void acknowledgeIDs(long checkpointId, List<UId> uniqueIds) {


@Override
public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
sessionIdsPerSnapshot.add(new Tuple2<>(checkpointId, sessionIds));
public void snapshotState(FunctionSnapshotContext context) throws Exception {
sessionIdsPerSnapshot.add(new Tuple2<>(context.getCheckpointId(), sessionIds));
sessionIds = new ArrayList<>(64);
return super.snapshotState(checkpointId, checkpointTimestamp);
super.snapshotState(context);
}
}
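Because the hunk above interleaves the removed and added lines, the override is reassembled here for readability (content taken directly from the diff):

    // before (legacy Checkpointed):
    @Override
    public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
        sessionIdsPerSnapshot.add(new Tuple2<>(checkpointId, sessionIds));
        sessionIds = new ArrayList<>(64);
        return super.snapshotState(checkpointId, checkpointTimestamp);
    }

    // after (CheckpointedFunction):
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        sessionIdsPerSnapshot.add(new Tuple2<>(context.getCheckpointId(), sessionIds));
        sessionIds = new ArrayList<>(64);
        super.snapshotState(context);
    }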
