Skip to content

Commit

Permalink
TFLite RL RefAPP bug fix: disallow repeat move at inference time
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 387263830
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Jul 28, 2021
1 parent 5c4d9ef commit 414e524
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 60 deletions.
14 changes: 7 additions & 7 deletions lite/examples/image_classification/raspberry_pi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ Return here after you perform the `apt-get install` command.

First, clone this Git repo onto your Raspberry Pi like this:

```shell
```
git clone https://github.com/tensorflow/examples --depth 1
```

Then use our script to install a couple of Python packages, and
download the MobileNet model and labels file:

```shell
```
cd examples/lite/examples/image_classification/raspberry_pi
# The script takes an argument specifying where you want to save the model files
Expand All @@ -58,7 +58,7 @@ bash download.sh /tmp

## Run the example

```shell
```
python3 classify_picamera.py \
--model /tmp/mobilenet_v1_1.0_224_quant.tflite \
--labels /tmp/labels_mobilenet_quant_v1_224.txt
Expand Down Expand Up @@ -90,20 +90,20 @@ delegate model execution to the Edge TPU processor:
2. Now open the `classify_picamera.py` file and add the following import at
the top:

```python
```
from tflite_runtime.interpreter import load_delegate
```
And then find the line that initializes the `Interpreter`, which looks like
this:
```python
```
interpreter = Interpreter(args.model)
```
And change it to specify the Edge TPU delegate:
```python
```
interpreter = Interpreter(args.model,
experimental_delegates=[load_delegate('libedgetpu.so.1.0')])
```
Expand All @@ -126,7 +126,7 @@ Now you're ready to execute the TensorFlow Lite model on the Edge TPU. Just run
`classify_picamera.py` again, but be sure you specify the model that's compiled
for the Edge TPU (it uses the same labels file as before):
```shell
```
python3 classify_picamera.py \
--model /tmp/mobilenet_v1_1.0_224_quant_edgetpu.tflite \
--labels /tmp/labels_mobilenet_quant_v1_224.txt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ public void onItemClick(AdapterView<?> adapterView, View view, int position, lon
}

// Agent action
StrikePrediction agentStrikePosition = agent.predictNextMove(playerBoard);
if (agentStrikePosition == null) {
int agentStrikePosition = agent.predictNextMove(playerBoard);
if (agentStrikePosition == -1) {
Toast.makeText(
MainActivity.this,
"Something went wrong with the RL agent! Please restart the app.",
Toast.LENGTH_LONG)
.show();
return;
}
int agentStrikePositionX = agentStrikePosition.x;
int agentStrikePositionY = agentStrikePosition.y;
int agentStrikePositionX = agentStrikePosition / Constants.BOARD_SIZE;
int agentStrikePositionY = agentStrikePosition % Constants.BOARD_SIZE;

if (playerHiddenBoard[agentStrikePositionX][agentStrikePositionY]
== HiddenBoardCellStatus.OCCUPIED_BY_PLANE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ public abstract class PlaneStrikeAgent {
protected Interpreter.Options tfliteOptions;

protected int agentStrikePosition;
protected boolean isPredictedByAgent;

public PlaneStrikeAgent(Activity activity) throws IOException {
tfliteOptions = new Interpreter.Options();
tflite = new Interpreter(loadModelFile(activity), tfliteOptions);
}

/** Predict the next move based on current board state. */
protected abstract StrikePrediction predictNextMove(BoardCellStatus[][] board);
protected abstract int predictNextMove(BoardCellStatus[][] board);

/** Run model inference on current board state. */
protected abstract void runInference();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,37 +41,36 @@ public RLAgent(Activity activity) throws IOException {

/**
 * Predict the next move based on current board state.
 *
 * <p>NOTE(review): this span is a rendered diff hunk without +/- markers, so the pre-change lines
 * (StrikePrediction return type, {@code return null}) and the post-change lines (int return type,
 * {@code return -1}) appear together; it is not compilable exactly as shown.
 *
 * <p>In the int-returning variant the result is a flattened cell index: -1 when {@code tflite} is
 * null, otherwise the index of the UNTRIED cell with the largest value in
 * {@code outputProbArrays[0]}, or -1 when no UNTRIED cell has a value strictly greater than 0.
 */
@Override
protected StrikePrediction predictNextMove(BoardCellStatus[][] board) {
protected int predictNextMove(BoardCellStatus[][] board) {

if (tflite == null) {
Log.e(
Constants.TAG, "Game agent failed to initialize. Please restart the app.");
return null;
Log.e(Constants.TAG, "Game agent failed to initialize. Please restart the app.");
return -1;
} else {
prepareModelInput(board);
runInference();
}

// Pre-change variant: split the flattened index into x (index / BOARD_SIZE) and
// y (index % BOARD_SIZE) and return them wrapped in a StrikePrediction.
StrikePrediction strikePrediction = new StrikePrediction();
strikePrediction.x = agentStrikePosition / Constants.BOARD_SIZE;
strikePrediction.y = agentStrikePosition % Constants.BOARD_SIZE;
return strikePrediction;
// Post-processing (non-repeat argmax)
float[] probArray = outputProbArrays[0]; // batch size is 1 so we use [0] here
// NOTE(review): a field with this same name is declared in PlaneStrikeAgent; if this class
// extends it, this local shadows the field and the field is not updated here -- confirm.
int agentStrikePosition = -1;
// The running maximum starts at 0 and the comparison below is strict, so a cell whose value is
// 0 or negative is never selected; if every UNTRIED cell is like that, -1 is returned.
// NOTE(review): presumably the model outputs positive probabilities -- confirm.
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
// i is a flattened index; recover the board coordinates to look up the cell status.
int x = i / Constants.BOARD_SIZE;
int y = i % Constants.BOARD_SIZE;
// Only cells still marked UNTRIED are candidates, which is what prevents repeat moves.
if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
agentStrikePosition = i;
maxProb = probArray[i];
}
}
return agentStrikePosition;
}

/**
 * Run model inference on current board state.
 *
 * <p>Feeds {@code boardData} to the interpreter, writing the result into
 * {@code outputProbArrays}, then rewinds {@code boardData}.
 *
 * <p>NOTE(review): this span is a rendered diff hunk without +/- markers. The argmax section
 * below appears to be the pre-change code that the "non-repeat argmax" post-processing in
 * predictNextMove replaces -- confirm against the actual file.
 */
@Override
protected void runInference() {
tflite.run(boardData, outputProbArrays);
boardData.rewind();

float[] probArray = outputProbArrays[0]; // batch size is 1 so we use [0] here
// Argmax
// The comparison is strict, so on ties the earliest index is kept. No board state is consulted
// here, so the chosen index is not restricted to untried cells.
int maxIndex = 0;
for (int i = 0; i < probArray.length; i++) {
maxIndex = probArray[i] > probArray[maxIndex] ? i : maxIndex;
}
// Publish the result through fields rather than a return value.
agentStrikePosition = maxIndex;
isPredictedByAgent = true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,18 @@ public RLAgentFromTFAgents(Activity activity) throws IOException {

/**
 * Predict the next move based on current board state.
 *
 * <p>NOTE(review): this span is a rendered diff hunk without +/- markers, so the pre-change lines
 * (StrikePrediction return type, {@code return null}) and the post-change lines (int return type,
 * {@code return -1}) appear together; it is not compilable exactly as shown.
 *
 * <p>In the int-returning variant the result is -1 when {@code tflite} is null; otherwise the
 * model input is prepared, inference is run, and the {@code agentStrikePosition} field is
 * returned as-is.
 */
@Override
protected StrikePrediction predictNextMove(BoardCellStatus[][] board) {
protected int predictNextMove(BoardCellStatus[][] board) {

if (tflite == null) {
Log.e(
Constants.TAG, "Game agent failed to initialize. Please restart the app.");
return null;
return -1;
} else {
prepareModelInput(board);
runInference();
}

// Pre-change variant: split the flattened index into x (index / BOARD_SIZE) and
// y (index % BOARD_SIZE) and return them wrapped in a StrikePrediction.
StrikePrediction strikePrediction = new StrikePrediction();
strikePrediction.x = agentStrikePosition / Constants.BOARD_SIZE;
strikePrediction.y = agentStrikePosition % Constants.BOARD_SIZE;
return strikePrediction;
// NOTE(review): nothing visible in this method checks the board for already-tried cells before
// returning the index; presumably any repeat-move handling lives in the model or in
// prepareModelInput -- confirm.
return agentStrikePosition;
}

/** Run model inference on current board state. */
Expand All @@ -62,7 +59,6 @@ protected void runInference() {
output.put(0, prediction);
tflite.runForMultipleInputsOutputs(inputs, output);
agentStrikePosition = prediction[0][0];
isPredictedByAgent = true;
}

@Override
Expand Down

This file was deleted.

0 comments on commit 414e524

Please sign in to comment.