Skip to content

Commit

Permalink
Return labels in MultilayerPerceptron output (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
marmichalski authored and akondas committed Oct 15, 2018
1 parent e255369 commit d29c590
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 12 deletions.
5 changes: 2 additions & 3 deletions src/Classification/MLPClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,16 @@ protected function predictSample(array $sample)
}
}

return $this->classes[$predictedClass];
return $predictedClass;
}

/**
* @param mixed $target
*/
protected function trainSample(array $sample, $target): void
{

// Feed-forward.
$this->setInput($sample)->getOutput();
$this->setInput($sample);

// Back-propagate.
$this->backpropagation->backpropagate($this->getLayers(), $this->getTargetClass($target));
Expand Down
5 changes: 1 addition & 4 deletions src/NeuralNetwork/Layer.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ public function getNodes(): array
return $this->nodes;
}

/**
* @return Neuron
*/
private function createNode(string $nodeClass, ?ActivationFunction $activationFunction = null): Node
{
if ($nodeClass == Neuron::class) {
if ($nodeClass === Neuron::class) {
return new Neuron($activationFunction);
}

Expand Down
2 changes: 0 additions & 2 deletions src/NeuralNetwork/Network/LayeredNetwork.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ public function getOutput(): array

/**
* @param mixed $input
*
* @return $this
*/
public function setInput($input): Network
{
Expand Down
14 changes: 14 additions & 0 deletions src/NeuralNetwork/Network/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ public function __construct(int $inputLayerFeatures, array $hiddenLayers, array
throw new InvalidArgumentException('Provide at least 2 different classes');
}

if (count($classes) !== count(array_unique($classes))) {
throw new InvalidArgumentException('Classes must be unique');
}

$this->classes = array_values($classes);
$this->iterations = $iterations;
$this->inputLayerFeatures = $inputLayerFeatures;
Expand Down Expand Up @@ -109,6 +113,16 @@ public function setLearningRate(float $learningRate): void
$this->backpropagation->setLearningRate($this->learningRate);
}

public function getOutput(): array
{
$result = [];
foreach ($this->getOutputLayer()->getNodes() as $i => $neuron) {
$result[$this->classes[$i]] = $neuron->getOutput();
}

return $result;
}

/**
* @param mixed $target
*/
Expand Down
2 changes: 1 addition & 1 deletion src/NeuralNetwork/Node/Neuron.php
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public function addSynapse(Synapse $synapse): void
/**
* @return Synapse[]
*/
public function getSynapses()
public function getSynapses(): array
{
return $this->synapses;
}
Expand Down
11 changes: 9 additions & 2 deletions tests/Classification/MLPClassifierTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public function testSaveAndRestore(): void
$testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
$predicted = $classifier->predict($testSamples);

$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid();
$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid('', false);
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);
Expand All @@ -204,7 +204,7 @@ public function testSaveAndRestoreWithPartialTraining(): void
$this->assertEquals('a', $network->predict([1, 0]));
$this->assertEquals('b', $network->predict([0, 1]));

$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid();
$filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid('', false);
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($network, $filepath);
Expand Down Expand Up @@ -245,6 +245,13 @@ public function testThrowExceptionOnInvalidClassesNumber(): void
new MLPClassifier(2, [2], [0]);
}

public function testOutputWithLabels(): void
{
$output = (new MLPClassifier(2, [2, 2], ['T', 'F']))->getOutput();

$this->assertEquals(['T', 'F'], array_keys($output));
}

private function getSynapsesNodes(array $synapses): array
{
$nodes = [];
Expand Down
34 changes: 34 additions & 0 deletions tests/NeuralNetwork/Network/MultilayerPerceptronTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Phpml\Tests\NeuralNetwork\Network;

use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\Layer;
use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
Expand All @@ -13,6 +14,39 @@

class MultilayerPerceptronTest extends TestCase
{
public function testThrowExceptionWhenHiddenLayersAreEmpty(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Provide at least 1 hidden layer');

$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [], [0, 1], 1000, null, 0.42]
);
}

public function testThrowExceptionWhenThereIsOnlyOneClass(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Provide at least 2 different classes');

$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [3], [0], 1000, null, 0.42]
);
}

public function testThrowExceptionWhenClassesAreNotUnique(): void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Classes must be unique');

$this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [3], [0, 1, 2, 3, 1], 1000, null, 0.42]
);
}

public function testLearningRateSetter(): void
{
/** @var MultilayerPerceptron $mlp */
Expand Down

0 comments on commit d29c590

Please sign in to comment.