Skip to content

Commit

Permalink
Support for multiple training datasets (#38)
Browse files Browse the repository at this point in the history
* Multiple training data sets allowed

* Tests with multiple training data sets

* Updating docs according to #38

Documenting all models which predictions will be based on all
training data provided.

Some models already supported multiple training data sets.
  • Loading branch information
dmonllao authored and akondas committed Feb 1, 2017
1 parent 6281da2 commit c1b1a5d
Show file tree
Hide file tree
Showing 13 changed files with 61 additions and 32 deletions.
2 changes: 2 additions & 0 deletions docs/machine-learning/association/apriori.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ $associator = new Apriori($support = 0.5, $confidence = 0.5);
$associator->train($samples, $labels);
```

You can train the associator using multiple data sets, predictions will be based on all the training data.

### Predict

To predict sample label use `predict` method. You can provide one sample or array of samples:
Expand Down
2 changes: 2 additions & 0 deletions docs/machine-learning/classification/k-nearest-neighbors.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ $classifier = new KNearestNeighbors();
$classifier->train($samples, $labels);
```

You can train the classifier using multiple data sets, predictions will be based on all the training data.

## Predict

To predict sample label use `predict` method. You can provide one sample or array of samples:
Expand Down
2 changes: 2 additions & 0 deletions docs/machine-learning/classification/naive-bayes.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ $classifier = new NaiveBayes();
$classifier->train($samples, $labels);
```

You can train the classifier using multiple data sets, predictions will be based on all the training data.

### Predict

To predict sample label use `predict` method. You can provide one sample or array of samples:
Expand Down
2 changes: 2 additions & 0 deletions docs/machine-learning/classification/svc.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ $classifier = new SVC(Kernel::LINEAR, $cost = 1000);
$classifier->train($samples, $labels);
```

You can train the classifier using multiple data sets, predictions will be based on all the training data.

### Predict

To predict sample label use `predict` method. You can provide one sample or array of samples:
Expand Down
1 change: 1 addition & 0 deletions docs/machine-learning/neural-network/backpropagation.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ $training->train(
$maxIteraions = 30000
);
```
You can train the neural network using multiple data sets, predictions will be based on all the training data.
2 changes: 2 additions & 0 deletions docs/machine-learning/regression/least-squares.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ $regression = new LeastSquares();
$regression->train($samples, $targets);
```

You can train the model using multiple data sets, predictions will be based on all the training data.

### Predict

To predict sample target value use `predict` method with sample to check (as `array`). Example:
Expand Down
2 changes: 2 additions & 0 deletions docs/machine-learning/regression/svr.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ $regression = new SVR(Kernel::LINEAR);
$regression->train($samples, $targets);
```

You can train the model using multiple data sets, predictions will be based on all the training data.

### Predict

To predict sample target value use `predict` method. You can provide one sample or array of samples:
Expand Down
13 changes: 7 additions & 6 deletions src/Phpml/Classification/DecisionTree.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ public function __construct($maxDepth = 10)
*/
public function train(array $samples, array $targets)
{
$this->featureCount = count($samples[0]);
$this->columnTypes = $this->getColumnTypes($samples);
$this->samples = $samples;
$this->targets = $targets;
$this->labels = array_keys(array_count_values($targets));
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);

$this->featureCount = count($this->samples[0]);
$this->columnTypes = $this->getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
}

protected function getColumnTypes(array $samples)
Expand Down
10 changes: 5 additions & 5 deletions src/Phpml/Classification/NaiveBayes.php
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ class NaiveBayes implements Classifier
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->sampleCount = count($samples);
$this->featureCount = count($samples[0]);
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$this->sampleCount = count($this->samples);
$this->featureCount = count($this->samples[0]);

$labelCounts = array_count_values($targets);
$labelCounts = array_count_values($this->targets);
$this->labels = array_keys($labelCounts);
foreach ($this->labels as $label) {
$samples = $this->getSamplesByLabel($label);
Expand Down
8 changes: 4 additions & 4 deletions src/Phpml/Helper/Trainable.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@ trait Trainable
/**
* @var array
*/
private $samples;
private $samples = [];

/**
* @var array
*/
private $targets;
private $targets = [];

/**
* @param array $samples
* @param array $targets
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
}
}
8 changes: 4 additions & 4 deletions src/Phpml/Regression/LeastSquares.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class LeastSquares implements Regression
/**
* @var array
*/
private $samples;
private $samples = [];

/**
* @var array
*/
private $targets;
private $targets = [];

/**
* @var float
Expand All @@ -36,8 +36,8 @@ class LeastSquares implements Regression
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);

$this->computeCoefficients();
}
Expand Down
31 changes: 18 additions & 13 deletions tests/Phpml/Classification/DecisionTreeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class DecisionTreeTest extends \PHPUnit_Framework_TestCase
{
public $data = [
private $data = [
['sunny', 85, 85, 'false', 'Dont_play' ],
['sunny', 80, 90, 'true', 'Dont_play' ],
['overcast', 83, 78, 'false', 'Play' ],
Expand All @@ -25,34 +25,39 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase
['rain', 71, 80, 'true', 'Dont_play' ]
];

public function getData()
private $extraData = [
['scorching', 90, 95, 'false', 'Dont_play'],
['scorching', 100, 93, 'true', 'Dont_play'],
];

private function getData($input)
{
static $data = null, $targets = null;
if ($data == null) {
$data = $this->data;
$targets = array_column($data, 4);
array_walk($data, function (&$v) {
array_splice($v, 4, 1);
});
}
return [$data, $targets];
$targets = array_column($input, 4);
array_walk($input, function (&$v) {
array_splice($v, 4, 1);
});
return [$input, $targets];
}

public function testPredictSingleSample()
{
list($data, $targets) = $this->getData();
list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5);
$classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));

list($data, $targets) = $this->getData($this->extraData);
$classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['scorching', 95, 90, 'true']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
return $classifier;
}

public function testTreeDepth()
{
list($data, $targets) = $this->getData();
list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5);
$classifier->train($data, $targets);
$this->assertTrue(5 >= $classifier->actualDepth);
Expand Down
10 changes: 10 additions & 0 deletions tests/Phpml/Classification/NaiveBayesTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,15 @@ public function testPredictArrayOfSamples()
$predicted = $classifier->predict($testSamples);

$this->assertEquals($testLabels, $predicted);

// Feed an extra set of training data.
$samples = [[1, 1, 6]];
$labels = ['d'];
$classifier->train($samples, $labels);

$testSamples = [[1, 1, 6], [5, 1, 1]];
$testLabels = ['d', 'a'];
$this->assertEquals($testLabels, $classifier->predict($testSamples));

}
}

0 comments on commit c1b1a5d

Please sign in to comment.