Skip to content

Commit

Permalink
Added MAD-GAN impl. + code-refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
sanghoon committed Jun 30, 2017
1 parent 87bc153 commit 4f4fd58
Show file tree
Hide file tree
Showing 13 changed files with 635 additions and 120 deletions.
Empty file added base_gan.py
Empty file.
8 changes: 4 additions & 4 deletions began_mnist.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env python
import os
from tensorflow.examples.tutorials.mnist import input_data
from models import *
from utils import *
from common import *

from common import *
from models.models import *
from utils import *

# TODO: Refactoring
args = parse_args(models.keys())

print args
Expand Down
5 changes: 1 addition & 4 deletions classify_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
if 'DISPLAY' not in os.environ:
import matplotlib
matplotlib.use('Agg') # Use a different backend
import tensorflow as tf
from utils import *
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from models import *
from models.models import *


# Load dataset
Expand Down
44 changes: 34 additions & 10 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
N_ITERS = N_ITERS_PER_EPOCH * 200


DATASETS = ['mnist', 'celeba']


# Helper functions
def sample_z(m, n):
return np.random.uniform(-1., 1., size=[m, n])

def plot(samples, figId=None, retBytes=False):

def plot(samples, figId=None, retBytes=False, shape=None):
if figId is None:
fig = plt.figure(figsize=(4, 4))
else:
Expand All @@ -37,7 +41,12 @@ def plot(samples, figId=None, retBytes=False):
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
if shape and shape[2] == 3:
# FIXME: Naive impl. of rescaling
rescaled = (sample + 1.0) / 2.0
plt.imshow(rescaled.reshape(*shape))
else:
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

if retBytes:
buf = io.BytesIO()
Expand All @@ -48,22 +57,37 @@ def plot(samples, figId=None, retBytes=False):
return fig


def parse_args(modelnames=[]):
def parse_args(modelnames=[], additional_args=[]):
parser = argparse.ArgumentParser()

if len(modelnames) > 0:
parser.add_argument('-net', choices=modelnames, default=None)
parser.add_argument('--net', choices=modelnames, default=None)

parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--batchsize', type=int, default=128)
parser.add_argument('--data', choices=DATASETS, default=DATASETS[0])
parser.add_argument('--lr', type=float, default=1e-5)

parser.add_argument('-w_clip', type=float, default=0.1)
parser.add_argument('-bn', action='store_true', default=False)
parser.add_argument('-nobn', action='store_true', default=False) # FIXME
parser.add_argument('-lr', type=float, default=1e-5)
parser.add_argument('-tag', type=str, default='')
parser.add_argument('-kernel', type=int, default=5) # only for ConvNets
# All of these arguments can be ignored
parser.add_argument('--w_clip', type=float, default=0.1)
parser.add_argument('--bn', action='store_true', default=False)
parser.add_argument('--nobn', action='store_true', default=False) # FIXME

parser.add_argument('--tag', type=str, default='')
parser.add_argument('--kernel', type=int, default=5) # only for ConvNets

for key, kwargs in additional_args:
parser.add_argument(key, **kwargs)

args = parser.parse_args()

if args.nobn == True:
assert(args.bn is not True)

return args


def set_gpu(gpu_id):
print "Override GPU setting: gpu={}".format(gpu_id)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id)
97 changes: 97 additions & 0 deletions data_celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import tensorflow as tf
import os.path
import glob
import cv2
import random


class ImgDataset:
def __init__(self, dataDir, i_from=0, i_to=None, shuffle=False, crop=None, resize=None):
self.dataDir = dataDir
self.img_list = glob.glob(os.path.join(dataDir, "*.jpg"))
self.img_list = sorted(self.img_list)[i_from:i_to]

self.shuffle = shuffle
self._i = 0

self.resize = resize
self.crop = crop

if shuffle:
random.shuffle(self.img_list)

self.images = [self[0]] # Dummy image for size calculation in other codes

def crop_and_resize(self, im):
# Crop
if self.crop:
h, w, = im.shape[:2]
j = int(round((h - self.crop) / 2.))
i = int(round((w - self.crop) / 2.))

im = im[j:j+self.crop, i:i+self.crop, :]

if self.resize:
im = cv2.resize(im, (self.resize, self.resize))

# rescale (range: -1.0~1.0)
im = (im / 127.5 - 1.)

return im

def __getitem__(self, item):
if isinstance(item, tuple) or isinstance(item, slice):
im = map(cv2.imread, self.img_list[item])
im = map(self.crop_and_resize, im)
else:
# Read image
im = cv2.imread(self.img_list[item])
im = self.crop_and_resize(im)

return im

def __len__(self):
return len(self.img_list)

def next_batch(self, batch_size):
samples = self[self._i : self._i + batch_size]
self._i += batch_size

# If reached the end of the dataset
if self._i >= len(self):
# Re-initialize
self._i = 0
if self.shuffle:
random.shuffle(self.img_list)

n_more = batch_size - len(samples)
samples = samples + self.next_batch(n_more)[0]

return samples, None

class CelebA:
def __init__(self, dataDir):
self.train = ImgDataset(dataDir, i_from=0, i_to=150000, shuffle=True, crop=108, resize=64)
self.test = ImgDataset(dataDir, i_from=150000, i_to=None, crop=108, resize=64)

# TODO: Follow the original set's train/val/test pratition
# TODO: Provide label info.


if __name__ == '__main__':
import sys

dataDir = sys.argv[1]
data = CelebA(dataDir)

ims, _ = data.train.next_batch(16)

for i in range(16):
cv2.imshow('image', ims[i])
cv2.waitKey(0)

ims, _ = data.test.next_batch(16)

for i in range(16):
cv2.imshow('image', ims[i])
cv2.waitKey(0)
Loading

0 comments on commit 4f4fd58

Please sign in to comment.