
Commit

Merge pull request accord-net#1103 from FireDragonGameStudio/GH-1075_Add_QLearning_implementation_for_Animat_and_Unity_samples

GH-1075 Add custom QLearning implementation for Animat and add Unity Animat sample
cesarsouza authored Dec 8, 2017
2 parents 6064bb1 + 3410ced commit 95bcfd5
Showing 77 changed files with 376,274 additions and 145 deletions.
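The new QLearning_FDGS class referenced in this commit lives in QLearning_Revisited\QLearning_FDGS.cs, and its source is not reproduced in this excerpt. Judging purely from how MainForm.cs calls it in the diff below, its public surface looks roughly like the following sketch. The member names and signatures are taken from those call sites; everything inside the bodies (Q-table layout, reward values, discount factor) is an illustrative assumption, not the committed code.

using System;
using Accord.MachineLearning;

namespace SampleApp.QLearning_Revisited
{
    // Hypothetical sketch only; not the committed QLearning_FDGS.cs.
    public class QLearning_FDGS
    {
        private readonly int actionsCount;   // number of possible actions (4: N, E, S, W)
        private readonly int goalX, goalY;   // goal cell the agent tries to reach
        private readonly int mapWidth;       // width of the world map
        private readonly double[][] qvalues; // Q-table: one row of action values per state

        public IExplorationPolicy ExplorationPolicy { get; set; }
        public double LearningRate { get; set; } = 0.5;
        public double DiscountFactor { get; set; } = 0.5;

        public QLearning_FDGS(int actions, int goalX, int goalY, int[,] map,
            IExplorationPolicy explorationPolicy)
        {
            this.actionsCount = actions;
            this.goalX = goalX;
            this.goalY = goalY;
            this.mapWidth = map.GetLength(1);
            this.ExplorationPolicy = explorationPolicy;

            // one row of action values per cell of the map
            int states = map.GetLength(0) * map.GetLength(1);
            qvalues = new double[states][];
            for (int i = 0; i < states; i++)
                qvalues[i] = new double[actionsCount];
        }

        // map (x, y) <-> linear state index
        public int GetStateFromCoordinates(int x, int y)
        {
            return y * mapWidth + x;
        }

        public Tuple<int, int> GetCoordinatesFromState(int state)
        {
            return Tuple.Create(state % mapWidth, state / mapWidth);
        }

        // action chosen through the exploration policy (used while learning)
        public int GetAction(int state)
        {
            return ExplorationPolicy.ChooseAction(qvalues[state]);
        }

        // greedy action with the highest Q-value (used when replaying the solution)
        public int GetLearnedAction(int state)
        {
            int best = 0;
            for (int a = 1; a < actionsCount; a++)
                if (qvalues[state][a] > qvalues[state][best])
                    best = a;
            return best;
        }

        // one Q-learning update; the reward is derived from the observed transition
        public void LearnStep(int previousState, int action, int nextState)
        {
            var next = GetCoordinatesFromState(nextState);

            double reward;
            if (next.Item1 == goalX && next.Item2 == goalY)
                reward = 1.0;    // reached the goal
            else if (nextState == previousState)
                reward = -1.0;   // move was rejected (wall or out of bounds)
            else
                reward = -0.1;   // ordinary step, small cost to encourage short paths

            // best action value obtainable from the next state
            double bestNext = double.NegativeInfinity;
            foreach (double q in qvalues[nextState])
                bestNext = Math.Max(bestNext, q);

            // standard Q-learning update rule
            double[] row = qvalues[previousState];
            row[action] += LearningRate * (reward + DiscountFactor * bestNext - row[action]);
        }
    }
}

MainForm.cs drives exactly this surface from its worker thread: GetAction while learning, LearnStep after each move, and GetLearnedAction when replaying the solution.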
@@ -90,6 +90,7 @@
</Compile>
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="QLearning_Revisited\QLearning_FDGS.cs" />
<EmbeddedResource Include="CellWorld.resx">
<SubType>Designer</SubType>
<DependentUpon>CellWorld.cs</DependentUpon>
3 changes: 2 additions & 1 deletion Samples/MachineLearning/Animat/MainForm.Designer.cs


175 changes: 153 additions & 22 deletions Samples/MachineLearning/Animat/MainForm.cs
@@ -7,16 +7,13 @@
//

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Text;
using System.Windows.Forms;
using System.IO;
using System.Threading;

using Accord.MachineLearning;
using SampleApp.QLearning_Revisited;

namespace SampleApp
{
@@ -28,12 +25,20 @@ public partial class MainForm : Form
private int mapWidth;
private int mapHeight;

// agent' start and stop position
// agent's start and stop position
private int agentStartX;
private int agentStartY;
private int agentStopX;
private int agentStopY;

// agent's current position
private int agentCurrX;
private int agentCurrY;

// temp next state coordinates of the agent
private int agentNextX;
private int agentNextY;

// flag to stop background job
private volatile bool needToStop = false;

@@ -53,6 +58,8 @@ public partial class MainForm : Form
private QLearning qLearning = null;
// Sarsa algorithm
private Sarsa sarsa = null;
// self implemented Q-Learning
private QLearning_FDGS qLearning_FDGS = null;

// Form constructor
public MainForm()
@@ -321,19 +328,25 @@ private void startLearningButton_Click(object sender, EventArgs e)
// destroy algorithms
qLearning = null;
sarsa = null;
qLearning_FDGS = null;

if (algorithmCombo.SelectedIndex == 0)
{
// create new QLearning algorithm's instance
qLearning = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
workerThread = new Thread(new ThreadStart(QLearningThread));
}
else
else if (algorithmCombo.SelectedIndex == 1)
{
// create new Sarsa algorithm's instance
sarsa = new Sarsa(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
workerThread = new Thread(new ThreadStart(SarsaThread));
}
else
{
qLearning_FDGS = new QLearning_FDGS(4, agentStopX, agentStopY, map, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
workerThread = new Thread(new ThreadStart(QLearningThread_FDGS));
}

// disable all settings controls except "Stop" button
EnableControls(false);
@@ -354,13 +367,18 @@ private void stopButton_Click(object sender, EventArgs e)
Application.DoEvents();
workerThread = null;
}

// reset learning class values
qLearning = null;
sarsa = null;
qLearning_FDGS = null;
}

// On "Show Solution" button
private void showSolutionButton_Click(object sender, EventArgs e)
{
// check if learning algorithm was run before
if ((qLearning == null) && (sarsa == null))
if ((qLearning == null) && (sarsa == null) && (qLearning_FDGS == null))
return;

// disable all settings controls except "Stop" button
@@ -376,7 +394,7 @@ private void showSolutionButton_Click(object sender, EventArgs e)
private void QLearningThread()
{
int iteration = 0;
// curent coordinates of the agent
// current coordinates of the agent
int agentCurrentX, agentCurrentY;
// exploration policy
TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy;
@@ -500,6 +518,60 @@ private void SarsaThread()
EnableControls(true);
}

// self implemented Q-Learning thread
private void QLearningThread_FDGS()
{
int iteration = 0;

// exploration policy
TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning_FDGS.ExplorationPolicy;
EpsilonGreedyExploration explorationPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy;

// loop
while ((!needToStop) && (iteration < learningIterations))
{
// set exploration rate for this iteration
explorationPolicy.Epsilon = explorationRate - ((double)iteration / learningIterations) * explorationRate;
// set learning rate for this iteration
qLearning_FDGS.LearningRate = learningRate - ((double)iteration / learningIterations) * learningRate;
// clear tabu list
tabuPolicy.ResetTabuList();

// reset agent's coordinates to the starting position
agentCurrX = agentStartX;
agentCurrY = agentStartY;

// steps performed by agent to get to the goal
int steps = 0;

while ((!needToStop) && ((agentCurrX != agentStopX) || (agentCurrY != agentStopY)))
{
steps++;
// get agent's current state
int currentState = qLearning_FDGS.GetStateFromCoordinates(agentCurrX, agentCurrY);
// get the action for this state
int action = qLearning_FDGS.GetAction(currentState);
// update agent and get next state
int nextState = UpdateAgentPosition(currentState, action);
// do learning of the agent - update his Q-function
qLearning_FDGS.LearnStep(currentState, action, nextState);

// set tabu action (prevent going back for the next iteration)
tabuPolicy.SetTabuAction((action + 2) % 4, 1);
}

System.Diagnostics.Debug.WriteLine(steps);

iteration++;

// show current iteration
SetText(iterationBox, iteration.ToString());
}

// enable settings controls
EnableControls(true);
}

// Show solution thread
private void ShowSolutionThread()
{
@@ -509,8 +581,10 @@ private void ShowSolutionThread()

if (qLearning != null)
tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy;
else
else if (sarsa != null)
tabuPolicy = (TabuSearchExploration)sarsa.ExplorationPolicy;
else
tabuPolicy = (TabuSearchExploration)qLearning_FDGS.ExplorationPolicy;

exploratioPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy;

@@ -549,12 +623,27 @@ private void ShowSolutionThread()
// remove agent from current position
mapToDisplay[agentCurrentY, agentCurrentX] = 0;

// get agent's current state
int currentState = GetStateNumber(agentCurrentX, agentCurrentY);
// get the action for this state
int action = (qLearning != null) ? qLearning.GetAction(currentState) : sarsa.GetAction(currentState);
// update agent's current position and get his reward
double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action);
if ((qLearning != null) || (sarsa != null))
{
// get agent's current state
int currentState = GetStateNumber(agentCurrentX, agentCurrentY);
// get the action for this state
int action = (qLearning != null) ? qLearning.GetAction(currentState) : sarsa.GetAction(currentState);
// update agent's current position and get his reward
double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action);
}
else
{
// get agent's current state
int currentState = qLearning_FDGS.GetStateFromCoordinates(agentCurrentX, agentCurrentY);
// get the action for this state
int action = qLearning_FDGS.GetLearnedAction(currentState);
// update agent's current position
UpdateAgentPosition(currentState, action);
// update current positions (due to current Animat implementation)
agentCurrentX = agentCurrX;
agentCurrentY = agentCurrY;
}

// put agent to the new position
mapToDisplay[agentCurrentY, agentCurrentX] = 2;
@@ -564,6 +653,48 @@ private void ShowSolutionThread()
EnableControls(true);
}

// Update agent position without reward calculation (will be done during learning step)
private int UpdateAgentPosition(int state, int action)
{
// moving direction
int dx = 0, dy = 0;

switch (action)
{
case 0: // go to north (up)
dy = -1;
break;
case 1: // go to east (right)
dx = 1;
break;
case 2: // go to south (down)
dy = 1;
break;
case 3: // go to west (left)
dx = -1;
break;
}

var currentCoordinates = qLearning_FDGS.GetCoordinatesFromState(state);
agentNextX = currentCoordinates.Item1 + dx; // calc new X
agentNextY = currentCoordinates.Item2 + dy; // calc new Y

// check new agent's coordinates and set if not hitting a wall
// or going out of bounds
if (!((map[agentNextY, agentNextX] != 0) ||
(agentNextX < 0) || (agentNextX >= mapWidth) ||
(agentNextY < 0) || (agentNextY >= mapHeight)))
{

agentCurrX = agentNextX;
agentCurrY = agentNextY;

return qLearning_FDGS.GetStateFromCoordinates(agentNextX, agentNextY);
}

return qLearning_FDGS.GetStateFromCoordinates(currentCoordinates.Item1, currentCoordinates.Item2);
}

// Update agent position and return reward for the move
private double UpdateAgentPosition(ref int currentX, ref int currentY, int action)
{
@@ -588,23 +719,23 @@ private double UpdateAgentPosition(ref int currentX, ref int currentY, int action)
break;
}

int newX = currentX + dx;
int newY = currentY + dy;
agentNextX = currentX + dx;
agentNextY = currentY + dy;

// check new agent's coordinates
if (
(map[newY, newX] != 0) ||
(newX < 0) || (newX >= mapWidth) ||
(newY < 0) || (newY >= mapHeight)
(map[agentNextY, agentNextX] != 0) ||
(agentNextX < 0) || (agentNextX >= mapWidth) ||
(agentNextY < 0) || (agentNextY >= mapHeight)
)
{
// we found a wall or got outside of the world
reward = wallReward;
}
else
{
currentX = newX;
currentY = newY;
currentX = agentNextX;
currentY = agentNextY;

// check if we found the goal
if ((currentX == agentStopX) && (currentY == agentStopY))
