
Commit

data_mat
fwy423 committed Apr 4, 2017
1 parent 3af941d commit 0fee63a
Showing 2 changed files with 42 additions and 12 deletions.
image_load_save.py: 2 changes (1 addition, 1 deletion)
@@ -133,7 +133,7 @@ def batch_recover(batch_input, image_length=16, rows_in_single_image=64, recover


if __name__ == '__main__':
parse_data(400, 100, 100)
parse_data(40, 10, 10)
rval = load_data()
train_input, train_output = rval[0]
valid_input, valid_output = rval[1]
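For orientation, a minimal sketch of how these helpers appear to fit together, assuming parse_data(n_train, n_valid, n_test) takes per-split sample counts and load_data() returns one (input, output) pair per split, as the unpacking above suggests; the shapes shown are illustrative only:

from image_load_save import parse_data, load_data

# Assumed behaviour: parse_data writes 40/10/10 train/valid/test samples to disk,
# and load_data() returns one (input, output) pair per split.
parse_data(40, 10, 10)
rval = load_data()
train_input, train_output = rval[0]
valid_input, valid_output = rval[1]
print(train_input.shape)  # e.g. (40, 256) if each sample is a flattened 16x16 patch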
tensorflow_recover_CNN.py: 52 changes (41 additions, 11 deletions)
@@ -1,6 +1,6 @@
import tensorflow as tf
import numpy as np
from image_load_save import load_data
from image_load_save import load_data, check_path, batch_recover


def conv_layer(input_x, in_size, out_size, kernal_shape, seed, index=""):
@@ -117,6 +117,7 @@ def __init__(self, input_size, seed, feature_map_size=[64, 128], learning_rate=1

with tf.name_scope("prediction"):
y_pred = tf.reshape(conv_4_1, [-1, input_size * input_size], name="y_pred")
self.y_pred = y_pred

with tf.name_scope("loss"):
# the loss of prediction result
@@ -137,6 +138,9 @@ def training(training_input, training_truth,
############
# Training #
############
check_path("my_params")
check_path("log/")

image_size = int(np.sqrt(training_input.shape[1]))

with tf.name_scope('CNN'):
@@ -149,6 +153,7 @@ def training(training_input, training_truth,
best_valid = np.infty
with tf.Session() as sess:
merged = tf.summary.merge_all()

writer = tf.summary.FileWriter("log/", sess.graph)

sess.run(tf.global_variables_initializer())
@@ -166,19 +171,44 @@ sess.run(my_cnn.train_step, feed_dict={my_cnn.xs: training_batch,
sess.run(my_cnn.train_step, feed_dict={my_cnn.xs: training_batch,
my_cnn.ys: training_batch_truth})

train_loss = sess.run(my_cnn.loss, feed_dict={my_cnn.xs: training_batch,
my_cnn.ys: training_batch_truth})

if iter_total % 100 == 1:
print("iter: %d, training RMES: %.4f" % (iter_total, train_loss))
merged_result = sess.run(merged, feed_dict={my_cnn.xs: training_batch,
my_cnn.ys: training_batch_truth})
valid_loss = sess.run(my_cnn.loss, feed_dict={my_cnn.xs: validate_input,
my_cnn.ys: validate_truth})
print("iter: %d, valid RMES: %4f" % (iter_total, valid_loss))
merged_result = sess.run(merged, feed_dict={my_cnn.xs: validate_input,
my_cnn.ys: validate_truth})
writer.add_summary(merged_result, iter_total)

valid_loss = sess.run(my_cnn.loss, feed_dict={my_cnn.xs: validate_input,
my_cnn.ys: validate_truth})
if valid_loss < best_valid:
best_valid = valid_loss
if valid_loss < best_valid:
best_valid = valid_loss
test_loss = sess.run(my_cnn.loss, feed_dict={my_cnn.xs: test_input,
my_cnn.ys: test_truth})

print("======Best validation, test RMES: %4f" % test_loss)

saver = tf.train.Saver()

saver.save(sess, "my_params/save_net.ckpt")


def load_CNN(pred_input, param_path="my_params/save_net.ckpt", seed=0):
image_size = int(np.sqrt(pred_input.shape[1]))

# seed is a parameter (default 0) so the graph can be rebuilt here;
# the class default learning_rate is used since no training step is run.
my_cnn = My_CNN(
input_size=image_size,
seed=seed
)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()
saver.restore(sess, param_path)
result = sess.run(my_cnn.y_pred, feed_dict={my_cnn.xs: pred_input,
my_cnn.ys: pred_input})

batch_recover(result)


if __name__ == '__main__':
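The body of the __main__ block is collapsed in this diff. A hedged sketch of how the pieces added in this commit could be wired together inside tensorflow_recover_CNN.py (which already imports tf, np, load_data, check_path and batch_recover); the keyword names passed to training() and the third split returned by load_data() are assumptions, not taken from the diff:

if __name__ == '__main__':
    rval = load_data()
    train_input, train_output = rval[0]
    valid_input, valid_output = rval[1]
    test_input, test_output = rval[2]  # assumed third split

    # Assumed keyword names, mirroring the variables used inside training() above.
    training(train_input, train_output,
             validate_input=valid_input, validate_truth=valid_output,
             test_input=test_input, test_truth=test_output)

    # Start from a fresh graph so the restored variable names match the checkpoint.
    tf.reset_default_graph()
    load_CNN(test_input)  # restores my_params/save_net.ckpt, predicts, and calls batch_recover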
