forked from jorgecasas/php-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement SelectKBest algo for feature selection
- Loading branch information
Showing
14 changed files
with
389 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\FeatureSelection; | ||
|
||
/**
 * Contract for strategies that rate the features (columns) of a dataset
 * against its targets.
 */
interface ScoringFunction
{
    /**
     * Computes a score for the features of the given samples.
     *
     * @param array $samples rows of feature values
     * @param array $targets target value of each sample, matched by array key
     *
     * @return array scores; SelectKBest reads the keys as feature (column)
     *               indexes and keeps the entries with the highest values
     */
    public function score(array $samples, array $targets): array;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\FeatureSelection\ScoringFunction; | ||
|
||
use Phpml\FeatureSelection\ScoringFunction; | ||
use Phpml\Math\Statistic\ANOVA; | ||
|
||
/**
 * Scoring function that rates features with the one-way ANOVA F-value.
 */
final class ANOVAFValue implements ScoringFunction
{
    /**
     * Buckets the samples by their target value and passes the buckets to
     * ANOVA::oneWayF().
     *
     * @param array $samples rows of feature values
     * @param array $targets target of each sample, looked up by the sample's key
     *
     * @return array the result of ANOVA::oneWayF() for the grouped samples
     */
    public function score(array $samples, array $targets): array
    {
        $samplesByClass = [];
        foreach (array_keys($samples) as $position) {
            $label = $targets[$position];
            $samplesByClass[$label][] = $samples[$position];
        }

        // oneWayF() receives a plain list of groups; the labels themselves are dropped.
        $groups = array_values($samplesByClass);

        return ANOVA::oneWayF($groups);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\FeatureSelection; | ||
|
||
use Phpml\Exception\InvalidArgumentException; | ||
use Phpml\Exception\InvalidOperationException; | ||
use Phpml\FeatureSelection\ScoringFunction\ANOVAFValue; | ||
use Phpml\Transformer; | ||
|
||
/**
 * Transformer that keeps only the k features (columns) with the highest score
 * according to a ScoringFunction.
 */
final class SelectKBest implements Transformer
{
    /**
     * Strategy used to score the features during fit().
     *
     * @var ScoringFunction
     */
    private $scoringFunction;

    /**
     * Number of top-scoring features to keep. The value is stored as given;
     * no range validation is performed here.
     *
     * @var int
     */
    private $k;

    /**
     * Scores returned by the scoring function in the last fit(), or null
     * before the first fit.
     *
     * @var array|null
     */
    private $scores = null;

    /**
     * Scores of the selected features, keyed by feature (column) index.
     * Null means "keep every column".
     *
     * @var array|null
     */
    private $keepColumns = null;

    /**
     * @param ScoringFunction|null $scoringFunction scoring strategy; ANOVAFValue is used when null
     * @param int                  $k               number of features to keep
     */
    public function __construct(?ScoringFunction $scoringFunction = null, int $k = 10)
    {
        $this->scoringFunction = $scoringFunction ?? new ANOVAFValue();
        $this->k = $k;
    }

    /**
     * Scores the features and remembers which k columns rank highest.
     *
     * When k is greater than or equal to the number of scored features no
     * column is filtered out by a later transform().
     *
     * @param array      $samples rows of feature values
     * @param array|null $targets target of each sample; required
     *
     * @throws InvalidArgumentException when $targets is null or empty
     */
    public function fit(array $samples, ?array $targets = null): void
    {
        if ($targets === null || empty($targets)) {
            throw InvalidArgumentException::arrayCantBeEmpty();
        }

        $sorted = $this->scoringFunction->score($samples, $targets);
        $this->scores = $sorted;

        // Forget the selection of any earlier fit. Otherwise a refit that
        // takes the early return below would leave transform() filtering
        // with columns chosen for a previous dataset.
        $this->keepColumns = null;

        if ($this->k >= count($sorted)) {
            return;
        }

        // Highest score first; keys (column indexes) are preserved by both calls.
        arsort($sorted);
        $this->keepColumns = array_slice($sorted, 0, $this->k, true);
    }

    /**
     * Removes, in place, every column that was not selected by fit() and
     * reindexes the remaining values of each sample from zero.
     *
     * Does nothing when no selection is stored (fit() not called, or k covers
     * all features).
     *
     * @param array $samples rows of feature values, modified by reference
     */
    public function transform(array &$samples): void
    {
        if ($this->keepColumns === null) {
            return;
        }

        foreach ($samples as &$sample) {
            $sample = array_values(array_intersect_key($sample, $this->keepColumns));
        }

        // Break the reference so a later write to $sample cannot alter the last row.
        unset($sample);
    }

    /**
     * Returns the scores computed by the last fit().
     *
     * @throws InvalidOperationException when fit() has not been called yet
     */
    public function scores(): array
    {
        if ($this->scores === null) {
            throw new InvalidOperationException('SelectKBest require to fit first to get scores');
        }

        return $this->scores;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Math\Statistic; | ||
|
||
use Phpml\Exception\InvalidArgumentException; | ||
|
||
/**
 * Analysis of variance
 * https://en.wikipedia.org/wiki/Analysis_of_variance
 */
final class ANOVA
{
    /**
     * The one-way ANOVA tests the null hypothesis that 2 or more groups have
     * the same population mean. The test is applied to samples from two or
     * more groups, possibly with differing sizes.
     *
     * One F-value is returned per feature, computed as the between-group mean
     * square divided by the within-group mean square. When a divisor is zero
     * (every class holds a single sample, or a feature has no within-class
     * variance) the result for that feature is INF, -INF or NAN instead of a
     * division-by-zero error.
     *
     * @param array|array[] $samples - each row is class samples
     *
     * @return array|float[]
     *
     * @throws InvalidArgumentException when fewer than two classes are given
     */
    public static function oneWayF(array $samples): array
    {
        $classes = count($samples);
        if ($classes < 2) {
            throw InvalidArgumentException::arraySizeToSmall(2);
        }

        $samplesPerClass = array_map(function (array $class): int {
            return count($class);
        }, $samples);
        $allSamples = array_sum($samplesPerClass);
        $ssAllSamples = self::sumOfSquaresPerFeature($samples);
        $sumSamples = self::sumOfFeaturesPerClass($samples);
        $squareSumSamples = self::sumOfSquares($sumSamples);
        $sumSamplesSquare = self::squaresSum($sumSamples);
        $ssbn = self::calculateSsbn($samples, $sumSamplesSquare, $samplesPerClass, $squareSumSamples, $allSamples);
        $sswn = self::calculateSswn($ssbn, $ssAllSamples, $squareSumSamples, $allSamples);

        // Degrees of freedom: between groups (at least 1 because of the guard
        // above) and within groups (0 when every class has exactly one sample).
        $dfbn = $classes - 1;
        $dfwn = $allSamples - $classes;

        $f = [];
        foreach ($ssbn as $index => $ssb) {
            $msb = $ssb / $dfbn;
            $msw = self::divide($sswn[$index], $dfwn);
            $f[$index] = self::divide($msb, $msw);
        }

        return $f;
    }

    /**
     * Division that tolerates a zero divisor.
     *
     * Returns the plain quotient for a non-zero divisor. For a zero divisor
     * it returns NAN when the numerator is zero or NAN, otherwise INF or -INF
     * according to the sign of the numerator.
     *
     * @param int|float $numerator
     * @param int|float $denominator
     *
     * @return int|float
     */
    private static function divide($numerator, $denominator)
    {
        if ($denominator != 0) {
            return $numerator / $denominator;
        }

        if ($numerator == 0 || is_nan((float) $numerator)) {
            return NAN;
        }

        return $numerator > 0 ? INF : -INF;
    }

    /**
     * Sums, per feature index, the squared feature values over all samples of
     * all classes.
     */
    private static function sumOfSquaresPerFeature(array $samples): array
    {
        // The feature count is taken from the first sample of the first class.
        $sum = array_fill(0, count($samples[0][0]), 0);
        foreach ($samples as $class) {
            foreach ($class as $sample) {
                foreach ($sample as $index => $feature) {
                    $sum[$index] += $feature ** 2;
                }
            }
        }

        return $sum;
    }

    /**
     * Sums, for every class separately, the feature values per feature index.
     *
     * @return array one row of per-feature sums for each class
     */
    private static function sumOfFeaturesPerClass(array $samples): array
    {
        return array_map(function (array $class) {
            $sum = array_fill(0, count($class[0]), 0);
            foreach ($class as $sample) {
                foreach ($sample as $index => $feature) {
                    $sum[$index] += $feature;
                }
            }

            return $sum;
        }, $samples);
    }

    /**
     * Adds the given rows together per feature index and squares each total
     * (square of the sum).
     */
    private static function sumOfSquares(array $sums): array
    {
        $squares = array_fill(0, count($sums[0]), 0);
        foreach ($sums as $row) {
            foreach ($row as $index => $sum) {
                $squares[$index] += $sum;
            }
        }

        return array_map(function ($sum) {
            return $sum ** 2;
        }, $squares);
    }

    /**
     * Squares every value of every row, keeping the row structure.
     */
    private static function squaresSum(array $sums): array
    {
        foreach ($sums as &$row) {
            foreach ($row as &$sum) {
                $sum = $sum ** 2;
            }

            // Drop the inner reference before $row is rebound to the next row.
            unset($sum);
        }

        unset($row);

        return $sums;
    }

    /**
     * Between-group sum of squares per feature: the squared class sums, each
     * divided by its class size, added up, minus the squared grand sum divided
     * by the total number of samples.
     */
    private static function calculateSsbn(array $samples, array $sumSamplesSquare, array $samplesPerClass, array $squareSumSamples, int $allSamples): array
    {
        $ssbn = array_fill(0, count($samples[0][0]), 0);
        foreach ($sumSamplesSquare as $classIndex => $class) {
            foreach ($class as $index => $feature) {
                $ssbn[$index] += $feature / $samplesPerClass[$classIndex];
            }
        }

        foreach ($squareSumSamples as $index => $sum) {
            $ssbn[$index] -= $sum / $allSamples;
        }

        return $ssbn;
    }

    /**
     * Within-group sum of squares per feature: the total sum of squares (sum
     * of squared values minus the squared grand sum divided by the sample
     * count) minus the between-group sum of squares.
     */
    private static function calculateSswn(array $ssbn, array $ssAllSamples, array $squareSumSamples, int $allSamples): array
    {
        $sswn = [];
        foreach ($ssAllSamples as $index => $ss) {
            $sswn[$index] = ($ss - $squareSumSamples[$index] / $allSamples) - $ssbn[$index];
        }

        return $sswn;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
tests/FeatureSelection/ScoringFunction/ANOVAFValueTest.php
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace Phpml\Tests\FeatureSelection\ScoringFunction; | ||
|
||
use Phpml\Dataset\Demo\IrisDataset; | ||
use Phpml\FeatureSelection\ScoringFunction\ANOVAFValue; | ||
use PHPUnit\Framework\TestCase; | ||
|
||
final class ANOVAFValueTest extends TestCase
{
    /**
     * Scores the Iris demo dataset and compares the result with four
     * reference values, allowing an absolute difference of 0.0001.
     */
    public function testScoreForANOVAFValue(): void
    {
        $iris = new IrisDataset();
        $expectedScores = [119.2645, 47.3644, 1179.0343, 959.3244];

        $actualScores = (new ANOVAFValue())->score($iris->getSamples(), $iris->getTargets());

        self::assertEquals($expectedScores, $actualScores, '', 0.0001);
    }
}
Oops, something went wrong.