Skip to content

Commit

Permalink
output management:
Browse files Browse the repository at this point in the history
- set root output directory with --out_dir.
- set output name with --out_name (or leave blank for it to be generated automatically from hyperparams)
- checkpoints, samples and log saved under out_dir/out_name
- environment variables expanded
- save json for FLAGS for reference
- auto set output size to input size if None
  • Loading branch information
memo committed Jul 3, 2018
1 parent 3dd932f commit c3b734b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
52 changes: 38 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import scipy.misc
import numpy as np
import json

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables
from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp

import tensorflow as tf

Expand All @@ -19,9 +20,11 @@
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_string("data_dir", "$HOME/data", "path to datasets [$HOME/data]")
flags.DEFINE_string("out_dir", "$HOME/out", "Root directory for outputs [$HOME/out]")
flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
Expand All @@ -37,16 +40,35 @@

def main(_):
pp.pprint(flags.FLAGS.__flags)

if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height

if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)

# expand user name and environment variables
FLAGS.data_dir = expand_path(FLAGS.data_dir)
FLAGS.out_dir = expand_path(FLAGS.out_dir)
FLAGS.out_name = expand_path(FLAGS.out_name)
FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
FLAGS.sample_dir = expand_path(FLAGS.sample_dir)

if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

# output folders
if FLAGS.out_name == "":
FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
if FLAGS.train:
FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)

FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
flags_dict = {k:FLAGS[k].value for k in FLAGS}
json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)


#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
run_config = tf.ConfigProto()
Expand All @@ -70,6 +92,7 @@ def main(_):
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir,
out_dir=FLAGS.out_dir,
max_to_keep=FLAGS.max_to_keep)
else:
dcgan = DCGAN(
Expand All @@ -87,6 +110,7 @@ def main(_):
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir,
out_dir=FLAGS.out_dir,
max_to_keep=FLAGS.max_to_keep)

show_all_variables()
Expand Down
5 changes: 3 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
max_to_keep=1,
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
"""
Args:
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
self.input_fname_pattern = input_fname_pattern
self.checkpoint_dir = checkpoint_dir
self.data_dir = data_dir
self.out_dir = out_dir
self.max_to_keep = max_to_keep

if self.dataset_name == 'mnist':
Expand Down Expand Up @@ -173,7 +174,7 @@ def train(self, config):
self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
self.d_sum = merge_summary(
[self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
self.writer = SummaryWriter("./logs", self.sess.graph)
self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph)

sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim))

Expand Down
24 changes: 18 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import pprint
import scipy.misc
import numpy as np
import os
import time
import datetime
from time import gmtime, strftime
from six.moves import xrange

Expand All @@ -18,6 +21,15 @@

get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])


def expand_path(path):
    """Return *path* with environment variables and a leading '~' expanded.

    Variables (e.g. ``$HOME``) are substituted first, then any remaining
    user-home prefix is resolved, matching ``os.path`` semantics.
    """
    with_vars = os.path.expandvars(path)
    return os.path.expanduser(with_vars)

def timestamp(s='%Y%m%d.%H%M%S', ts=None):
    """Format a POSIX timestamp as a local-time string.

    Args:
        s: strftime-style format string (default ``YYYYMMDD.HHMMSS``).
        ts: POSIX timestamp in seconds; when None, the current time is used.

    Returns:
        The formatted local-time string.
    """
    # Use an explicit None check: the old `if not ts` wrongly replaced a
    # legitimate ts == 0 (the Unix epoch) with the current time.
    if ts is None:
        ts = time.time()
    return datetime.datetime.fromtimestamp(ts).strftime(s)

def show_all_variables():
    """Print an analysis table of every trainable variable in the default graph."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
Expand Down Expand Up @@ -169,12 +181,12 @@ def make_frame(t):
clip = mpy.VideoClip(make_frame, duration=duration)
clip.write_gif(fname, fps = len(images) / duration)

def visualize(sess, dcgan, config, option):
def visualize(sess, dcgan, config, option, sample_dir='samples'):
image_frame_dim = int(math.ceil(config.batch_size**.5))
if option == 0:
z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
elif option == 1:
values = np.arange(0, 1, 1./config.batch_size)
for idx in xrange(dcgan.z_dim):
Expand All @@ -192,7 +204,7 @@ def visualize(sess, dcgan, config, option):
else:
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx)))
elif option == 2:
values = np.arange(0, 1, 1./config.batch_size)
for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
Expand All @@ -215,7 +227,7 @@ def visualize(sess, dcgan, config, option):
try:
make_gif(samples, './samples/test_gif_%s.gif' % (idx))
except:
save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
elif option == 3:
values = np.arange(0, 1, 1./config.batch_size)
for idx in xrange(dcgan.z_dim):
Expand All @@ -225,7 +237,7 @@ def visualize(sess, dcgan, config, option):
z[idx] = values[kdx]

samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
make_gif(samples, './samples/test_gif_%s.gif' % (idx))
make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
elif option == 4:
image_set = []
values = np.arange(0, 1, 1./config.batch_size)
Expand All @@ -236,7 +248,7 @@ def visualize(sess, dcgan, config, option):
for kdx, z in enumerate(z_sample): z[idx] = values[kdx]

image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))

new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
for idx in range(64) + range(63, -1, -1)]
Expand Down

0 comments on commit c3b734b

Please sign in to comment.