Skip to content

Commit

Permalink
added results
Browse files Browse the repository at this point in the history
  • Loading branch information
dariosanfilippo committed May 31, 2021
1 parent 74d5724 commit ecdc84c
Show file tree
Hide file tree
Showing 612 changed files with 14,325 additions and 6,031 deletions.
13 changes: 7 additions & 6 deletions amp_est.c → amp_est_bi.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ int main(void) {
kad_node_t* t;
kann_t* ann;
size_t inputs = 1024;
size_t outputs = 1;
size_t num_layers = 4;
size_t neurons = 64;
size_t outputs = 4;
size_t num_layers = 1;
size_t neurons = 8;
size_t SR = 192000;

/* Create the neural network */
t = kann_layer_input(inputs);
for (size_t i = 0; i < num_layers; i++) {
t = kann_layer_dense(t, neurons);
t = kad_relu(t);
t = kad_softmax(t);
}
t = kann_layer_cost(t, outputs, KANN_C_MSE);
ann = kann_new(t, 0);
Expand Down Expand Up @@ -71,9 +71,10 @@ int main(void) {
for (size_t j = 0; j < testsize; j++) {
output = kann_apply1(ann, x->vec_space[j]);
printf("Target: %.10f; prediction: %.10f; error factor: %.10f\n",
y->vec_space[j][0], *output, y->vec_space[j][0] / *output);
fprintf(csv, "%f, %f\n", *output, y->vec_space[j][0] / *output);
y->vec_space[j][0], *output, y->vec_space[j][0] - *output);
fprintf(csv, "%f, %f\n", y->vec_space[j][0], y->vec_space[j][0] - *output);
}

sig_free(x);
sig_free(y);
fclose(csv);
Expand Down
19 changes: 10 additions & 9 deletions est_plot.py → amp_est_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

for filename in glob.glob('*.csv'):

freq = []
amp = []
error = []
tmp = []

Expand All @@ -29,16 +29,15 @@
l = len(data)

for i in range(l):
freq = np.append(freq, tmp[i][0])
amp = np.append(amp, tmp[i][0])
error = np.append(error, tmp[i][1])

plt.xlabel('Frequency (Hz)')
plt.ylabel('Error (target-ANN ratio)')

plt.axhline(y = 1,linewidth = 1, color = 'r', label = "Best fit")
plt.scatter(freq, error, linewidth = .5, color = 'black', label = "ANN output")

plt.xscale('log')
plt.xlabel('Amplitude target')
plt.ylabel('Prediction error (target-ANN difference)')

plt.ylim(-1, 1)
plt.axhline(y = 0, linewidth = .25, color = 'r', label = "Best fit")
plt.scatter(amp, error, marker = "o", s = .1, linewidth = 1, color = 'black', label = "Prediction error")

plt.title("Amplitude estimation")

Expand All @@ -49,3 +48,5 @@
plt.grid(True)

plt.savefig(name)

plt.clf()

This file was deleted.

Binary file not shown.

This file was deleted.

Binary file not shown.

This file was deleted.

Loading

0 comments on commit ecdc84c

Please sign in to comment.