Skip to content

Commit

Permalink
added another function for setting training data
Browse files Browse the repository at this point in the history
  • Loading branch information
steffennissen committed Oct 28, 2012
1 parent d914171 commit 201ba5e
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 17 deletions.
Binary file modified bin/fanndouble.dll
Binary file not shown.
Binary file modified bin/fanndouble.lib
Binary file not shown.
Binary file modified bin/fannfixed.dll
Binary file not shown.
Binary file modified bin/fannfixed.lib
Binary file not shown.
Binary file modified bin/fannfloat.dll
Binary file not shown.
Binary file modified bin/fannfloat.lib
Binary file not shown.
31 changes: 20 additions & 11 deletions src/fann_train_data.c
Original file line number Diff line number Diff line change
Expand Up @@ -820,28 +820,37 @@ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int n
return data;
}

FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output)
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output)
{
unsigned int i, j;
struct fann_train_data *data;
fprintf(stderr, "***test***\n\n");
data = fann_create_train(num_data, num_input, num_output);
printf("Allocated train data at address %p\n", data);

if(data == NULL)
return NULL;

for (i = 0; i < num_data; ++i)
{
for (j = 0; j < num_input; ++j) {
printf("data->input[%d][%d] = %f\n", i, j, input[i][j]);
data->input[i][j] = input[i][j];
}
memcpy(data->input[i], input[i], num_input*sizeof(fann_type));
memcpy(data->output[i], output[i], num_output*sizeof(fann_type));
}

return data;
}

for (j = 0; j < num_output; ++j) {
printf("data->output[%d][%d] = %f\n", i, j, output[i][j]);
data->output[i][j] = output[i][j];
}
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output)
{
unsigned int i, j;
struct fann_train_data *data;
data = fann_create_train(num_data, num_input, num_output);

if(data == NULL)
return NULL;

for (i = 0; i < num_data; ++i)
{
memcpy(data->input[i], &input[i*num_input], num_input*sizeof(fann_type));
memcpy(data->output[i], &output[i*num_output], num_output*sizeof(fann_type));
}

return data;
Expand Down
25 changes: 25 additions & 0 deletions src/include/fann_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,31 @@ namespace FANN
void set_train_data(unsigned int num_data,
unsigned int num_input, fann_type **input,
unsigned int num_output, fann_type **output)
{
set_train_data(fann_create_train_pointer_array(num_data, num_input, input, num_output, output));
}

/* Method: set_train_data
Set the training data to the input and output data provided.
A copy of the data is made so there are no restrictions on the
allocation of the input/output data and the caller is responsible
for the deallocation of the data pointed to by input and output.
Parameters:
num_data - The number of training data
num_input - The number of inputs per training data
num_output - The number of ouputs per training data
input - The set of inputs (an array with the dimension num_data*num_input)
output - The set of desired outputs (an array with the dimension num_data*num_output)
See also:
<get_input>, <get_output>
*/
void set_train_data(unsigned int num_data,
unsigned int num_input, fann_type *input,
unsigned int num_output, fann_type *output)
{
set_train_data(fann_create_train_array(num_data, num_input, input, num_output, output));
}
Expand Down
29 changes: 25 additions & 4 deletions src/include/fann_train.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,37 @@ FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const c
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output);

/* Function: fann_create_train_array
Creates an training data struct and fills it with data from provided arrays.
/* Function: fann_create_train_pointer_array
Creates an training data struct and fills it with data from provided arrays of pointer.
A copy of the data is made so there are no restrictions on the
allocation of the input/output data and the caller is responsible
for the deallocation of the data pointed to by input and output.
See also:
<fann_read_train_from_file>, <fann_train_on_data>, <fann_destroy_train>,
<fann_save_train>, <fann_create_train>, <fann_create_train_array>
This function appears in FANN >= 2.3.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_pointer_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output);

/* Function: fann_create_train_array
Creates an training data struct and fills it with data from provided arrays, where the arrays must have the dimensions:
input[num_data*num_input]
output[num_data*num_output]
A copy of the data is made so there are no restrictions on the
allocation of the input/output data and the caller is responsible
for the deallocation of the data pointed to by input and output.
See also:
<fann_read_train_from_file>, <fann_train_on_data>, <fann_destroy_train>,
<fann_save_train>, <fann_create_train>
<fann_save_train>, <fann_create_train>, <fann_create_train_pointer_array>
This function appears in FANN >= 2.3.0
*/
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type **input, unsigned int num_output, fann_type **output);
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_array(unsigned int num_data, unsigned int num_input, fann_type *input, unsigned int num_output, fann_type *output);

/* Function: fann_create_train_from_callback
Creates the training data struct from a user supplied function.
Expand Down
19 changes: 17 additions & 2 deletions tests/fann_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ TEST(Create, CreateShortcutArrayFourLayers) {
EXPECT_EQ(FANN_NETTYPE_SHORTCUT, net.get_network_type());
}

TEST(CreateTrain, CreateTrainDataFromArrays) {
TEST(CreateTrain, CreateTrainDataFromPointerArrays) {
FANN::training_data data;
fann_type **input = new fann_type*[2];
fann_type **output = new fann_type*[2];
Expand All @@ -140,7 +140,22 @@ TEST(CreateTrain, CreateTrainDataFromArrays) {
output[i][0] = 2.2f;
}

data.set_train_data(2, 3, (fann_type**)input, 1, (fann_type**)output);
data.set_train_data(2, 3, input, 1, output);

for(int i = 0; i < 2; i++) {
for(int j = 0; j < 3; j++) {
EXPECT_EQ(1.1f, data.get_input()[i][j]);
}
EXPECT_EQ(2.2f, data.get_output()[i][0]);
}
}

TEST(CreateTrain, CreateTrainDataFromArrays) {
FANN::training_data data;
fann_type input[] = {1.1f, 1.1f, 1.1f, 1.1f, 1.1f, 1.1f};
fann_type output[] = {2.2f, 2.2f};

data.set_train_data(2, 3, input, 1, output);

for(int i = 0; i < 2; i++) {
for(int j = 0; j < 3; j++) {
Expand Down

0 comments on commit 201ba5e

Please sign in to comment.