From c3b734b4c907d41bf9ead03aef5f3ff9f0824469 Mon Sep 17 00:00:00 2001 From: Memo Akten Date: Tue, 3 Jul 2018 19:14:37 +0100 Subject: [PATCH] output management: - set root output directory with --out_dir. - set output name with --out_name (or leave blank for it to be generated automatically from hyperparams) - checkpoints, samples and log saved under out_dir/out_name - environment variables expanded - save json for FLAGS for reference - auto set output size to input size if None --- main.py | 52 ++++++++++++++++++++++++++++++++++++++-------------- model.py | 5 +++-- utils.py | 24 ++++++++++++++++++------ 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index b499124a6..dcfd5acaf 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,10 @@ import os import scipy.misc import numpy as np +import json from model import DCGAN -from utils import pp, visualize, to_json, show_all_variables +from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp import tensorflow as tf @@ -19,9 +20,11 @@ flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") -flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") -flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]") -flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") +flags.DEFINE_string("data_dir", "$HOME/data", "path to datasets [$HOME/data]") +flags.DEFINE_string("out_dir", "$HOME/out", "Root directory for outputs [$HOME/out]") +flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []") +flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]") +flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]") flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") @@ -37,16 +40,35 @@ def main(_): pp.pprint(flags.FLAGS.__flags) - - if FLAGS.input_width is None: - FLAGS.input_width = FLAGS.input_height - if FLAGS.output_width is None: - FLAGS.output_width = FLAGS.output_height - - if not os.path.exists(FLAGS.checkpoint_dir): - os.makedirs(FLAGS.checkpoint_dir) - if not os.path.exists(FLAGS.sample_dir): - os.makedirs(FLAGS.sample_dir) + + # expand user name and environment variables + FLAGS.data_dir = expand_path(FLAGS.data_dir) + FLAGS.out_dir = expand_path(FLAGS.out_dir) + FLAGS.out_name = expand_path(FLAGS.out_name) + FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir) + FLAGS.sample_dir = expand_path(FLAGS.sample_dir) + + if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height + if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height + if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height + + # output folders + if FLAGS.out_name == "": + FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path + if FLAGS.train: + FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size) + + FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name) + FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir) + FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir) + + if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir) + if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir) + + with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f: + flags_dict = {k:FLAGS[k].value for k in FLAGS} + json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False) + #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) run_config = tf.ConfigProto() @@ -70,6 +92,7 @@ def main(_): checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir, data_dir=FLAGS.data_dir, + out_dir=FLAGS.out_dir, max_to_keep=FLAGS.max_to_keep) else: dcgan = DCGAN( @@ -87,6 +110,7 @@ def main(_): checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir, data_dir=FLAGS.data_dir, + out_dir=FLAGS.out_dir, max_to_keep=FLAGS.max_to_keep) show_all_variables() diff --git a/model.py b/model.py index c47b146ae..0d6526b94 100644 --- a/model.py +++ b/model.py @@ -26,7 +26,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True, y_dim=None, z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', max_to_keep=1, - input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'): + input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'): """ Args: @@ -78,6 +78,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True, self.input_fname_pattern = input_fname_pattern self.checkpoint_dir = checkpoint_dir self.data_dir = data_dir + self.out_dir = out_dir self.max_to_keep = max_to_keep if self.dataset_name == 'mnist': @@ -173,7 +174,7 @@ def train(self, config): self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) self.d_sum = merge_summary( [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) - self.writer = SummaryWriter("./logs", self.sess.graph) + self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph) sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim)) diff --git a/utils.py b/utils.py index cd52cb5ec..879aa546d 100644 --- a/utils.py +++ b/utils.py @@ -8,6 +8,9 @@ import pprint import scipy.misc import numpy as np +import os +import time +import datetime from time import gmtime, strftime from six.moves import xrange @@ -18,6 +21,15 @@ get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) + +def expand_path(path): + return os.path.expanduser(os.path.expandvars(path)) + +def timestamp(s='%Y%m%d.%H%M%S', ts=None): + if not ts: ts = time.time() + st = datetime.datetime.fromtimestamp(ts).strftime(s) + return st + def show_all_variables(): model_vars = tf.trainable_variables() slim.model_analyzer.analyze_vars(model_vars, print_info=True) @@ -169,12 +181,12 @@ def make_frame(t): clip = mpy.VideoClip(make_frame, duration=duration) clip.write_gif(fname, fps = len(images) / duration) -def visualize(sess, dcgan, config, option): +def visualize(sess, dcgan, config, option, sample_dir='samples'): image_frame_dim = int(math.ceil(config.batch_size**.5)) if option == 0: z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim)) samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() ))) elif option == 1: values = np.arange(0, 1, 1./config.batch_size) for idx in xrange(dcgan.z_dim): @@ -192,7 +204,7 @@ def visualize(sess, dcgan, config, option): else: samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx)) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx))) elif option == 2: values = np.arange(0, 1, 1./config.batch_size) for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]: @@ -215,7 +227,7 @@ def visualize(sess, dcgan, config, option): try: make_gif(samples, './samples/test_gif_%s.gif' % (idx)) except: - save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) + save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() ))) elif option == 3: values = np.arange(0, 1, 1./config.batch_size) for idx in xrange(dcgan.z_dim): @@ -225,7 +237,7 @@ def visualize(sess, dcgan, config, option): z[idx] = values[kdx] samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) - make_gif(samples, './samples/test_gif_%s.gif' % (idx)) + make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx))) elif option == 4: image_set = [] values = np.arange(0, 1, 1./config.batch_size) @@ -236,7 +248,7 @@ def visualize(sess, dcgan, config, option): for kdx, z in enumerate(z_sample): z[idx] = values[kdx] image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) - make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) + make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx))) new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ for idx in range(64) + range(63, -1, -1)]