Skip to content

Commit

Permalink
Added support for training binary relation detectors
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 2, 2022
1 parent 03436d5 commit d473a2f
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.1.3 (unreleased)

- Added support for training NER models
- Added support for training binary relation detectors

## 0.1.2 (2022-08-30)

Expand Down
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,34 @@ This returns
[['first' => 'Shopify', 'second' => 'Ottawa', 'score' => 0.17649169745814464]]
```

### Training [unreleased]

Load an NER model into a trainer

```php
$trainer = new Mitie\BinaryRelationTrainer($model);
```

Add positive and negative examples to the trainer

```php
$tokens = ['Shopify', 'was', 'founded', 'in', 'Ottawa'];
$trainer->addPositiveBinaryRelation($tokens, [0, 0], [4, 4]);
$trainer->addNegativeBinaryRelation($tokens, [4, 4], [0, 0]);
```

Train the detector

```php
$detector = $trainer->train();
```

Save the detector

```php
$detector->saveToDisk('binary_relation_detector.svm');
```

## Text Categorization

Load a model into a trainer
Expand Down
15 changes: 10 additions & 5 deletions src/BinaryRelationDetector.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@

class BinaryRelationDetector
{
public function __construct($path)
public function __construct($path = null, $pointer = null)
{
$this->ffi = FFI::instance();

if (!file_exists($path)) {
throw new \InvalidArgumentException('File does not exist');
if (!is_null($path)) {
if (!file_exists($path)) {
throw new \InvalidArgumentException('File does not exist');
}
$this->pointer = $this->ffi->mitie_load_binary_relation_detector($path);
} elseif (!is_null($pointer)) {
$this->pointer = $pointer;
} else {
throw new \InvalidArgumentException('Must pass either a path or a pointer');
}

$this->pointer = $this->ffi->mitie_load_binary_relation_detector($path);
}

public function __destruct()
Expand Down
104 changes: 104 additions & 0 deletions src/BinaryRelationTrainer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
<?php

namespace Mitie;

class BinaryRelationTrainer
{
public function __construct($ner, $name = '')
{
$this->ffi = FFI::instance();

$this->pointer = $this->ffi->mitie_create_binary_relation_trainer($name, $ner->pointer);
}

public function __destruct()
{
FFI::mitie_free($this->pointer);
}

public function addPositiveBinaryRelation($tokens, $range1, $range2)
{
$this->checkAdd($tokens, $range1, $range2);

$tokensPointer = Utils::arrayToPointer($tokens);
$status = $this->ffi->mitie_add_positive_binary_relation($this->pointer, $tokensPointer, $range1[0], $range1[1] - $range1[0] + 1, $range2[0], $range2[1] - $range2[0] + 1);
if ($status != 0) {
throw new Exception('Unable to add binary relation');
}
}

public function addNegativeBinaryRelation($tokens, $range1, $range2)
{
$this->checkAdd($tokens, $range1, $range2);

$tokensPointer = Utils::arrayToPointer($tokens);
$status = $this->ffi->mitie_add_negative_binary_relation($this->pointer, $tokensPointer, $range1[0], $range1[1] - $range1[0] + 1, $range2[0], $range2[1] - $range2[0] + 1);
if ($status != 0) {
throw new Exception('Unable to add binary relation');
}
}

public function beta()
{
return $this->ffi->mitie_binary_relation_trainer_get_beta($this->pointer);
}

public function setBeta($value)
{
if ($value < 0) {
throw new \InvalidArgumentException('beta must be greater than or equal to zero');
}

$this->ffi->mitie_binary_relation_trainer_set_beta($this->pointer, $value);
}

public function numThreads()
{
return $this->ffi->mitie_binary_relation_trainer_get_num_threads($this->pointer);
}

public function setNumThreads($value)
{
return $this->ffi->mitie_binary_relation_trainer_set_num_threads($this->pointer, $value);
}

public function numPositiveExamples()
{
return $this->ffi->mitie_binary_relation_trainer_num_positive_examples($this->pointer);
}

public function numNegativeExamples()
{
return $this->ffi->mitie_binary_relation_trainer_num_negative_examples($this->pointer);
}

public function train()
{
if ($this->numPositiveExamples() + $this->numNegativeExamples() == 0) {
throw new Exception("You can't call train() on an empty trainer");
}

$detector = $this->ffi->mitie_train_binary_relation_detector($this->pointer);

if (is_null($detector)) {
throw new Exception('Unable to create binary relation detector. Probably ran out of RAM.');
}

return new BinaryRelationDetector(pointer: $detector);
}

private function checkAdd($tokens, $range1, $range2)
{
Utils::checkRange($range1[0], $range1[1], count($tokens));
Utils::checkRange($range2[0], $range2[1], count($tokens));

if ($this->entitiesOverlap($range1, $range2)) {
throw new \InvalidArgumentException('Entities overlap');
}
}

private function entitiesOverlap($range1, $range2)
{
return $this->ffi->mitie_entities_overlap($range1[0], $range1[1] - $range1[0] + 1, $range2[0], $range2[1] - $range2[0] + 1) == 1;
}
}
117 changes: 117 additions & 0 deletions tests/BinaryRelationTrainerTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
<?php

use Tests\TestCase;

final class BinaryRelationTrainerTest extends TestCase
{
public function testWorks()
{
$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addPositiveBinaryRelation($this->tokens(), [0, 0], [4, 4]);
$trainer->addNegativeBinaryRelation($this->tokens(), [4, 4], [0, 0]);
$this->assertEquals(1, $trainer->numPositiveExamples());
$this->assertEquals(1, $trainer->numNegativeExamples());
$detector = $trainer->train();
$this->assertEquals('', $detector->name());

$path = tempnam(sys_get_temp_dir(), 'detector');
$detector->saveToDisk($path);
$this->assertFileExists($path);

$detector = new Mitie\BinaryRelationDetector($path);
$doc = $this->model()->doc('Shopify was founded in Ottawa');

$relations = $detector->relations($doc);
$this->assertCount(1, $relations);

$relation = $relations[0];
$this->assertEquals('Shopify', $relation['first']);
$this->assertEquals('Ottawa', $relation['second']);
}

public function testAddPositiveBinaryRelationInvalidRange1()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addPositiveBinaryRelation($this->tokens(), [0, -1], [4, 4]);
}

public function testAddPositiveBinaryRelationInvalidRange2()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addPositiveBinaryRelation($this->tokens(), [0, 0], [4, 3]);
}

public function testAddPositiveBinaryRelationInvalidRange3()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addPositiveBinaryRelation($this->tokens(), [0, 0], [4, 5]);
}

public function testAddNegativeBinaryRelationInvalidRange1()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addNegativeBinaryRelation($this->tokens(), [0, -1], [4, 4]);
}

public function testAddNegativeBinaryRelationInvalidRange2()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addNegativeBinaryRelation($this->tokens(), [0, 0], [4, 3]);
}

public function testAddNegativeBinaryRelationInvalidRange3()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Invalid range');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addNegativeBinaryRelation($this->tokens(), [0, 0], [4, 5]);
}

public function testAddPositiveBinaryRelationEntitiesOverlap()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Entities overlap');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addPositiveBinaryRelation($this->tokens(), [0, 1], [1, 2]);
}

public function testAddNegativeBinaryRelationEntitiesOverlap()
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Entities overlap');

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->addNegativeBinaryRelation($this->tokens(), [0, 1], [1, 2]);
}

public function testEmptyTrainer()
{
$this->expectException(Mitie\Exception::class);
$this->expectExceptionMessage("You can't call train() on an empty trainer");

$trainer = new Mitie\BinaryRelationTrainer($this->model());
$trainer->train();
}

private function tokens()
{
return ['Shopify', 'was', 'founded', 'in', 'Ottawa'];
}
}

0 comments on commit d473a2f

Please sign in to comment.