Skip to content

Commit

Permalink
Add evaluate
Browse files Browse the repository at this point in the history
Add disc mse error
  • Loading branch information
ix64 committed Nov 29, 2018
1 parent 573ad41 commit c750d5e
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 53 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Arg:
def __init__(self):
parser = ArgumentParser(prog="LittleGAN", description="The code for paper: LittleGAN")

parser.add_argument("mode", type=str, help="run mode", default="train", choices=["train", "plot", "visual", "random-sample"])
parser.add_argument("mode", type=str, help="run mode", default="train", choices=["train", "plot", "visual", "random-sample", "evaluate"])
parser.add_argument("exp_name", type=str, help="experience name")
parser.add_argument("-e", "--env", type=str, help="config environment", default="default")
parser.add_argument("-g", "--gpu", type=str, required=False, help="gpu ids, eg: 0,1,2,3", default="-1")
Expand Down
3 changes: 2 additions & 1 deletion dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import tensorflow as tf
from glob import glob
from utils import soft, data_rescale
from os import path


class CelebA:
def __init__(self, args):
self.args = args
self._image_list = glob(args.image_path + "/*." + args.image_ext)
self._image_list = glob(path.join(args.image_path, "*." + args.image_ext))
self._attributes_list = self._get_attr_list(args.attr_path, args.attr)
self.batches = len(self._image_list) // args.batch_size
self.all_label = ["有短髭", "柳叶眉", "有魅力", "有眼袋", "秃头", "有刘海", "大嘴唇", "大鼻子", "黑发", "金发", "睡眼惺松", "棕发", "浓眉",
Expand Down
58 changes: 32 additions & 26 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,56 @@
#!/usr/bin/env python3


from argparse import ArgumentParser
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
parser = ArgumentParser()
parser.add_argument("mode")
parser.add_argument("image_path")
parser.add_argument("stats_path")
parser.add_argument("model_path")
parser.add_argument("--gpu", default="-1")
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
import glob
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import numpy as np
import fid
from scipy.misc import imread
import tensorflow as tf

data_path = 'E:/ixarea/dataset/128'
stat_path = '../LittleGAN-test/fid_stats_celeba_128_all.npz'
inception_path = '../LittleGAN-test/'
image_path = '../LittleGAN-result'

print("check for inception model..", end=" ", flush=True)
inception_path = fid.check_or_download_inception(inception_path)
print("ok")
print("check for inception model..")
inception_path = fid.check_or_download_inception(args.model_path)

if not os.path.isfile(stat_path):
print("load images..", end=" ", flush=True)
image_list = glob.glob(os.path.join(data_path, '*.jpg'))
if args.mode == "pre-calculate":
print("load images..")
image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
images = np.array([imread(image).astype(np.float32) for image in image_list])
print("%d images found and loaded" % len(images))

print("create inception graph..", end=" ", flush=True)
fid.create_inception_graph(inception_path)
print("ok")

print("calculte FID stats..", end=" ", flush=True)
print("calculate FID stats..", end=" ", flush=True)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
np.savez_compressed(stat_path, mu=mu, sigma=sigma)
np.savez_compressed(args.stats_path, mu=mu, sigma=sigma)
print("finished")
else:
image_list = glob.glob(os.path.join(args.image_path, '*.jpg'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])

image_list = glob.glob(os.path.join(image_path, '*.jpg'))
images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
f = np.load(args.stats_path)
mu_real, sigma_real = f['mu'][:], f['sigma'][:]
f.close()

f = np.load(stat_path)
mu_real, sigma_real = f['mu'][:], f['sigma'][:]
f.close()

fid.create_inception_graph(inception_path)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess, batch_size=100)
fid.create_inception_graph(inception_path)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
mu_gen, sigma_gen = fid.calculate_activation_statistics(images, sess, batch_size=100)

fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
print("FID: %s" % fid_value)
fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
print("FID: %s" % fid_value)
2 changes: 1 addition & 1 deletion fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class InvalidFIDException(Exception):
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(pth, 'rb') as f:
with tf.gfile.GFile(pth, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='FID_Inception_Net')
Expand Down
10 changes: 7 additions & 3 deletions littlegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def train(self):
self.checkpoint.save(path.join(self.args.result_dir, "checkpoint", str(e)))

def init_result_dir(self):
dirs = [".", "train/gen", "train/adj", "test/adj", "test/gen", "test/disc", "checkpoint", "log", "sample", "evaluate/sample"]
dirs = [".", "train/gen", "train/adj", "test/adj", "test/gen", "test/disc", "checkpoint", "log", "sample", "evaluate/gen", "evaluate/adj"]
for item in dirs:
if not path.exists(path.join(self.args.result_dir, item)):
makedirs(path.join(self.args.result_dir, item))
Expand Down Expand Up @@ -375,8 +375,12 @@ def predict(self, noise, cond, image, gen_image_save_path=None, json_save_path=N
save["real_cond"] = cond
save["real_pr"], save["real_c"] = self.discriminator(image)
save["fake_pr"], save["fake_c"] = self.discriminator(gen_image)
for x in save:
save[x] = (tf.round(save[x] * 10)).numpy().astype(int).tolist()
save["real_pr_mse"] = tf.reduce_mean(tf.keras.metrics.mean_squared_error(soft(1), save["real_pr"]), axis=0).numpy().astype(float)
save["real_c_mse"] = tf.reduce_mean(tf.keras.metrics.mean_squared_error(cond, save["real_c"]), axis=0).numpy().astype(float)
save["fake_pr_mse"] = tf.reduce_mean(tf.keras.metrics.mean_squared_error(soft(0), save["fake_pr"]), axis=0).numpy().astype(float)
save["fake_c_mse"] = tf.reduce_mean(tf.keras.metrics.mean_squared_error(cond, save["fake_c"]), axis=0).numpy().astype(float)
for x in ["real_cond", "real_pr", "real_c", "fake_c", "fake_pr"]:
save[x] = (tf.round(save[x] * 100)).numpy().astype(int).tolist()
if None is not json_save_path:
with open(json_save_path, "w") as f:
json.dump(save, f)
Expand Down
76 changes: 57 additions & 19 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
tf.enable_eager_execution()

from os import path, system

from dataset import CelebA
from utils import save_image
from littlegan import Trainer, Adjuster, Discriminator, Decoder, Encoder, Generator

from git import Repo
import time
import numpy as np

decoder = Decoder(args)
encoder = Encoder(args)
Expand All @@ -26,35 +26,73 @@

print("Using GPUs: ", args.gpu)

data = CelebA(args)
print("\r\nImage Flows From: ", args.image_path, " Image Count: ", args.batch_size * data.batches)
print("\r\nUsing Attribute: ", data.label)

model = Trainer(args, generator, discriminator, adjuster, data)

if args.mode == "train":
repo = Repo(".")
if repo.is_dirty() and not args.debug: # 程序被修改且不是测试模式
raise EnvironmentError("Git repo is Dirty! Please train after committed.")
data = CelebA(args)
print("\r\nImage Flows From: ", args.image_path, " Image Count: ", args.batch_size * data.batches)
print("\r\nUsing Attribute: ", data.label)
model = Trainer(args, generator, discriminator, adjuster, data)
model.train()
elif args.mode == "visual": # loss etc的可视化
print("The result path is ", path.join(args.result_dir, "log"))
system("tensorboard --host 0.0.0.0 --logdir " + path.join(args.result_dir, "log"))
elif args.mode == "plot":
model = Trainer(args, generator, discriminator, adjuster, None)
model.plot() # 输出模型结构图
elif args.mode == "random-sample":
args.batch_size = args.random_sample_size
args.prefetch = args.random_sample_size
data = CelebA(args)
model = Trainer(args, generator, discriminator, adjuster, data)
iterator = data.get_new_iterator()
image, cond = iterator.get_next()
noise = tf.random_uniform([cond.shape[0], args.noise_dim])
time = int(time.time())
model.predict(noise, cond, image,
path.join(args.result_dir, "sample", "generator-%s.jpg" % time),
path.join(args.result_dir, "sample", "discriminator%s.json" % time),
path.join(args.result_dir, "sample", "adjuster%s.jpg" % time)
)
now_time = int(time.time())
for b in range(args.random_sample_batch):
image, cond = iterator.get_next()
noise = tf.random_uniform([cond.shape[0], args.noise_dim])

model.predict(noise, cond, image,
path.join(args.result_dir, "sample", "generator-%s-%d.jpg" % (now_time, b)),
path.join(args.result_dir, "sample", "discriminator-%s-%d.json" % (now_time, b)),
path.join(args.result_dir, "sample", "adjuster-%s-%d.jpg" % (now_time, b))
)
np.savez_compressed(path.join(args.result_dir, "sample", "input_data-%s-%d.npz" % (now_time, b)), n=noise, c=cond, i=image)
elif args.mode == "evaluate":
iterator = data.get_new_iterator()
progress = tf.keras.utils.Progbar(args.evaluate_sample_batch * args.batch_size)
for b in range(args.evaluate_sample_batch):
base_index = b * args.batch_size + 1
image, cond = iterator.get_next()
noise = tf.random_uniform([cond.shape[0], args.noise_dim])
gen_image, save, adj_real_image, adj_fake_image = model.predict(noise, cond, image,
None, path.join(args.result_dir, "evaluate", "discriminator.json"), None)
for i in range(args.batch_size):
save_image(gen_image[i], path.join(args.result_dir, "evaluate", "gen", str(base_index + i) + ".jpg"))
if adj_real_image is not None and adj_fake_image is not None:
save_image(adj_real_image[i], path.join(args.result_dir, "evaluate", "adj", "real_" + str(base_index + i) + ".jpg"))
save_image(adj_fake_image[i], path.join(args.result_dir, "evaluate", "adj", "fake_" + str(base_index + i) + ".jpg"))
progress.add(args.batch_size)

if not args.gpu:
args.gpu = [-1]

gen_cmd = "python evaluate.py calc %s %s %s --gpu %s" % (
path.join(args.result_dir, "evaluate", "gen"),
path.join(args.test_data_dir, args.evaluate_pre_calculated),
args.test_data_dir,
",".join(map(str, args.gpu))
)

print("Running: \"", gen_cmd, "\"")
system(gen_cmd)
if args.train_adj:
adj_cmd = "python evaluate.py calc %s %s %s --gpu %s" % (
path.join(args.result_dir, "evaluate", "adj"),
path.join(args.test_data_dir, args.evaluate_pre_calculated),
args.test_data_dir,
",".join(map(str, args.gpu))
)
print("Running: \"", adj_cmd, "\"")
system(adj_cmd)


else:
print("没有此模式:", args.mode)
Expand Down
5 changes: 3 additions & 2 deletions sample.config.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
"freq_test": 2000,
"all_result_dir": "/path/to/LittleGAN-result",
"test_data_dir": "/path/to/LittleGAN-test",
"random_sample_size": 100,
"evaluate_sample_size": 1000,
"evaluate_pre_calculated": "fid_stats_celeba_128_all.npz",
"random_sample_batch": 4,
"evaluate_sample_batch": 50,
"restore": true,
"train_adj": true,
"prefetch": 320
Expand Down

0 comments on commit c750d5e

Please sign in to comment.