Skip to content

Commit

Permalink
Add activationFunction parameter for Perceptron and Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
akondas committed Aug 11, 2016
1 parent c506a84 commit 2412f15
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
24 changes: 20 additions & 4 deletions src/Phpml/NeuralNetwork/Layer.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,38 @@ class Layer
private $nodes = [];

/**
* @param int $nodesNumber
* @param string $nodeClass
* @param int $nodesNumber
* @param string $nodeClass
* @param ActivationFunction|null $activationFunction
*
* @throws InvalidArgumentException
*/
public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class)
public function __construct(int $nodesNumber = 0, string $nodeClass = Neuron::class, ActivationFunction $activationFunction = null)
{
if (!in_array(Node::class, class_implements($nodeClass))) {
throw InvalidArgumentException::invalidLayerNodeClass();
}

for ($i = 0; $i < $nodesNumber; ++$i) {
$this->nodes[] = new $nodeClass();
$this->nodes[] = $this->createNode($nodeClass, $activationFunction);
}
}

/**
* @param string $nodeClass
* @param ActivationFunction|null $activationFunction
*
* @return Neuron
*/
private function createNode(string $nodeClass, ActivationFunction $activationFunction = null)
{
if (Neuron::class == $nodeClass) {
return new Neuron($activationFunction);
}

return new $nodeClass();
}

/**
* @param Node $node
*/
Expand Down
15 changes: 9 additions & 6 deletions src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\NeuralNetwork\Network;

use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\Layer;
use Phpml\NeuralNetwork\Node\Bias;
use Phpml\NeuralNetwork\Node\Input;
Expand All @@ -14,18 +15,19 @@
class MultilayerPerceptron extends LayeredNetwork
{
/**
* @param array $layers
* @param array $layers
* @param ActivationFunction|null $activationFunction
*
* @throws InvalidArgumentException
*/
public function __construct(array $layers)
public function __construct(array $layers, ActivationFunction $activationFunction = null)
{
if (count($layers) < 2) {
throw InvalidArgumentException::invalidLayersNumber();
}

$this->addInputLayer(array_shift($layers));
$this->addNeuronLayers($layers);
$this->addNeuronLayers($layers, $activationFunction);
$this->addBiasNodes();
$this->generateSynapses();
}
Expand All @@ -39,12 +41,13 @@ private function addInputLayer(int $nodes)
}

/**
* @param array $layers
* @param array $layers
* @param ActivationFunction|null $activationFunction
*/
private function addNeuronLayers(array $layers)
private function addNeuronLayers(array $layers, ActivationFunction $activationFunction = null)
{
foreach ($layers as $neurons) {
$this->addLayer(new Layer($neurons, Neuron::class));
$this->addLayer(new Layer($neurons, Neuron::class, $activationFunction));
}
}

Expand Down

0 comments on commit 2412f15

Please sign in to comment.