Skip to content

Commit

Permalink
test refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
steffennissen committed Nov 1, 2015
1 parent 409c832 commit 6bc3459
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 35 deletions.
37 changes: 33 additions & 4 deletions tests/fann_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void FannTest::TearDown() {
}

void FannTest::AssertCreate(FANN::neural_net net, unsigned int num_layers, unsigned int *layers,
unsigned int neurons, unsigned int connections) {
unsigned int neurons, unsigned int connections) {
EXPECT_EQ(num_layers, net.get_num_layers());
EXPECT_EQ(layers[0], net.get_num_input());
EXPECT_EQ(layers[num_layers - 1], net.get_num_output());
Expand Down Expand Up @@ -44,14 +44,43 @@ void FannTest::AssertWeights(fann_type expected_min_weight, fann_type expected_m
fann_type max_weight = connections[0].weight;
fann_type total_weight = 0.0;
for (int i = 1; i < net.get_total_connections(); ++i) {
if(connections[i].weight < min_weight)
if (connections[i].weight < min_weight)
min_weight = connections[i].weight;
if(connections[i].weight > max_weight)
if (connections[i].weight > max_weight)
max_weight = connections[i].weight;
total_weight += connections[i].weight;
}

EXPECT_NEAR(expected_min_weight, min_weight, 0.01);
EXPECT_NEAR(expected_max_weight, max_weight, 0.01);
EXPECT_NEAR(expected_avg_weight, total_weight/(fann_type)net.get_total_connections(), 0.1);
EXPECT_NEAR(expected_avg_weight, total_weight / (fann_type) net.get_total_connections(), 0.1);
}


void FannTest::InitializeTrainDataStructure(unsigned int num_data,
unsigned int num_input,
unsigned int num_output,
float input_value, float output_value,
fann_type **input,
fann_type **output) {
for (unsigned int i = 0; i < num_data; i++) {
input[i] = new fann_type[num_input];
output[i] = new fann_type[num_output];
for (unsigned int j = 0; j < num_input; j++)
input[i][j] = input_value;
for (unsigned int j = 0; j < num_output; j++)
output[i][j] = output_value;
}
}


void FannTest::AssertTrainData(unsigned int num_data, unsigned int num_input, unsigned int num_output,
fann_type input_value, fann_type output_value) {
for (int i = 0; i < num_data; i++) {
for (int j = 0; j < num_input; j++)
EXPECT_EQ(input_value, this->data.get_input()[i][j]);
for (int j = 0; j < num_output; j++)
EXPECT_EQ(output_value, this->data.get_output()[i][j]);
}
}

6 changes: 6 additions & 0 deletions tests/fann_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ class FannTest : public testing::Test {
void AssertWeights(fann_type expected_min_weight, fann_type expected_max_weight,
fann_type expected_avg_weight);

void AssertTrainData(unsigned int num_data, unsigned int num_input, unsigned int num_output, fann_type input_value,
fann_type output_value);

virtual void SetUp();

virtual void TearDown();

void InitializeTrainDataStructure(unsigned int num_data, unsigned int num_input, unsigned int num_output,
float input_value, float output_value, fann_type **input,
fann_type **output);
};

#endif //FANN_FANN_TESTFIXTURE_H
56 changes: 25 additions & 31 deletions tests/train_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,31 @@ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#include "fann_test.h"

TEST_F(FannTest, CreateTrainDataFromPointerArrays) {
fann_type **input = new fann_type *[2];
fann_type **output = new fann_type *[2];
for (int i = 0; i < 2; i++) {
input[i] = new fann_type[3];
output[i] = new fann_type[1];
for (int j = 0; j < 3; j++) {
input[i][j] = 1.1f;
}
output[i][0] = 2.2f;
}

data.set_train_data(2, 3, input, 1, output);

for (int i = 0; i < 2; i++) {
for (int j = 0; j < 3; j++) {
EXPECT_EQ(1.1f, data.get_input()[i][j]);
}
EXPECT_EQ(2.2f, data.get_output()[i][0]);
}
unsigned int num_data = 2;
unsigned int num_input = 3;
unsigned int num_output = 1;
float input_value = 1.1f;
float output_value = 2.2f;
fann_type **input = new fann_type *[num_data];
fann_type **output = new fann_type *[num_data];

InitializeTrainDataStructure(num_data, num_input, num_output, input_value, output_value, input, output);

data.set_train_data(num_data, num_input, input, num_output, output);

AssertTrainData(num_data, num_input, num_output, input_value, output_value);
}

TEST_F(FannTest, CreateTrainDataFromArrays) {
fann_type input[] = {1.1f, 1.1f, 1.1f, 1.1f, 1.1f, 1.1f};
fann_type output[] = {2.2f, 2.2f};

data.set_train_data(2, 3, input, 1, output);

for (int i = 0; i < 2; i++) {
for (int j = 0; j < 3; j++) {
EXPECT_EQ(1.1f, data.get_input()[i][j]);
}
EXPECT_EQ(2.2f, data.get_output()[i][0]);
}
}
unsigned int num_data = 2;
unsigned int num_input = 3;
unsigned int num_output = 1;
float input_value = 1.1f;
float output_value = 2.2f;

fann_type input[] = {input_value, input_value, input_value, input_value, input_value, input_value};
fann_type output[] = {output_value, output_value};
data.set_train_data(num_data, num_input, input, num_output, output);

AssertTrainData(num_data, num_input, num_output, input_value, output_value);
}

0 comments on commit 6bc3459

Please sign in to comment.