diff --git a/.gitignore b/.gitignore index 7bbc71c..b040238 100644 --- a/.gitignore +++ b/.gitignore @@ -1,101 +1,3 @@ # Byte-compiled / optimized / DLL files __pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# dotenv -.env - -# virtualenv -.venv -venv/ -ENV/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ +*.pyc diff --git a/data/bair.py b/data/bair.py new file mode 100644 index 0000000..f0020df --- /dev/null +++ b/data/bair.py @@ -0,0 +1,63 @@ +import os +import io +from scipy.misc import imresize +import numpy as np +from PIL import Image +from scipy.misc import imresize +from scipy.misc import imread + + +ROOT_DIR = '/misc/vlgscratch4/FergusGroup/denton/data/bair_robot_push/processed_data/' +class RobotPush(object): + + """Data Handler that loads robot pushing data.""" + + def __init__(self, train=True, seq_len=20, image_size=64): + self.root_dir = ROOT_DIR + if train: + self.data_dir = '%s/train' % self.root_dir + 
self.ordered = False + else: + self.data_dir = '%s/test' % self.root_dir + self.ordered = True + #self.data_dir = '/misc/vlgscratch4/FergusGroup/denton/data/push-dataset/processed_data/push_testseen' + self.dirs = [] + for d1 in os.listdir(self.data_dir): + for d2 in os.listdir('%s/%s' % (self.data_dir, d1)): + self.dirs.append('%s/%s/%s' % (self.data_dir, d1, d2)) + self.seq_len = seq_len + self.image_size = image_size + self.seed_is_set = False # multi threaded loading + self.d = 0 + + def set_seed(self, seed): + if not self.seed_is_set: + self.seed_is_set = True + np.random.seed(seed) + + def __len__(self): + return 10000 + + def get_seq(self): + if self.ordered: + d = self.dirs[self.d] + if self.d == len(self.dirs) - 1: + self.d = 0 + else: + self.d+=1 + else: + d = self.dirs[np.random.randint(len(self.dirs))] + image_seq = [] + for i in range(self.seq_len): + fname = '%s/%d.png' % (d, i) + im = imread(fname).reshape(1, 64, 64, 3) + image_seq.append(im/255.) + image_seq = np.concatenate(image_seq, axis=0) + return image_seq + + + def __getitem__(self, index): + self.set_seed(index) + return self.get_seq() + + diff --git a/data/moving_mnist.py b/data/moving_mnist.py new file mode 100644 index 0000000..9807190 --- /dev/null +++ b/data/moving_mnist.py @@ -0,0 +1,247 @@ +import socket +import numpy as np +from torchvision import datasets, transforms + +class MovingMNIST(object): + + """Data Handler that creates Bouncing MNIST dataset on the fly.""" + + def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64, deterministic=True): + path = '/misc/vlgscratch4/FergusGroup/denton/data/mnist/' + self.seq_len = seq_len + self.num_digits = num_digits + self.image_size = image_size + self.step_length = 0.1 + self.digit_size = 32 + self.deterministic = deterministic + self.seed_is_set = False # multi threaded loading + self.channels = 1 + + self.data = datasets.MNIST( + path, + train=train, + download=True, + transform=transforms.Compose( + 
[transforms.Scale(self.digit_size), + transforms.ToTensor()])) + + self.N = len(self.data) + + def set_seed(self, seed): + if not self.seed_is_set: + self.seed_is_set = True + np.random.seed(seed) + + def __len__(self): + return self.N + + def __getitem__(self, index): + self.set_seed(index) + image_size = self.image_size + digit_size = self.digit_size + x = np.zeros((self.seq_len, + image_size, + image_size, + self.channels), + dtype=np.float32) + for n in range(self.num_digits): + idx = np.random.randint(self.N) + digit, _ = self.data[idx] + + sx = np.random.randint(image_size-digit_size) + sy = np.random.randint(image_size-digit_size) + dx = np.random.randint(-4, 5) + dy = np.random.randint(-4, 5) + for t in range(self.seq_len): + if sy < 0: + sy = 0 + if self.deterministic: + dy = -dy + else: + dy = np.random.randint(1, 5) + dx = np.random.randint(-4, 5) + elif sy >= image_size-32: + sy = image_size-32-1 + if self.deterministic: + dy = -dy + else: + dy = np.random.randint(-4, 0) + dx = np.random.randint(-4, 5) + + if sx < 0: + sx = 0 + if self.deterministic: + dx = -dx + else: + dx = np.random.randint(1, 5) + dy = np.random.randint(-4, 5) + elif sx >= image_size-32: + sx = image_size-32-1 + if self.deterministic: + dx = -dx + else: + dx = np.random.randint(-4, 0) + dy = np.random.randint(-4, 5) + + x[t, sy:sy+32, sx:sx+32, 0] += digit.numpy().squeeze() + sy += dy + sx += dx + + x[x>1] = 1. 
+ return x + + +class MovingMNISTSynced(object): + + """Data Handler that creates Bouncing MNIST dataset on the fly.""" + + def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64): + path = '/misc/vlgscratch4/FergusGroup/denton/data/mnist/' + self.seq_len = seq_len + self.num_digits = num_digits + self.image_size = image_size + self.step_length = 0.1 + self.digit_size = 32 + self.seed_is_set = False # multi threaded loading + self.channels = 1 + + self.data = datasets.MNIST( + path, + train=train, + download=True, + transform=transforms.Compose( + [transforms.Scale(self.digit_size), + transforms.ToTensor()])) + + self.N = len(self.data) + self.sx = np.random.randint(image_size-self.digit_size) + self.sy = np.random.randint(image_size-self.digit_size) + self.dx = [np.random.randint(-4, 5) for i in range(100)] + self.dy = [np.random.randint(-4, 5) for i in range(100)] + + self.pos_dx = [np.random.randint(1, 5) for i in range(100)] + self.pos_dy = [np.random.randint(1, 5) for i in range(100)] + + self.neg_dx = [np.random.randint(-4, 0) for i in range(100)] + self.neg_dy = [np.random.randint(-4, 0) for i in range(100)] + + def set_seed(self, seed): + if not self.seed_is_set: + self.seed_is_set = True + np.random.seed(seed) + + def __len__(self): + return self.N + + def get_sample(self, digit): + image_size = self.image_size + digit_size = self.digit_size + x = np.zeros((self.seq_len, + self.channels, + image_size, + image_size, ), + dtype=np.float32) + + for n in range(self.num_digits): + sx = self.sx + sy = self.sy + dx = self.dx[0] + dy = self.dy[0] + k = 1 + times = [] + for t in range(self.seq_len): + if sy < 0: + sy = 0 + dy = np.random.randint(1, 5) + dx = np.random.randint(-4, 5) + k+=1 + times.append(t) + elif sy >= image_size-32: + sy = image_size-32-1 + dy = np.random.randint(-4, 0) + dx = np.random.randint(-4, 0) + k+=1 + times.append(t) + + if sx < 0: + sx = 0 + dx = np.random.randint(1, 5) + dy = np.random.randint(-4, 5) + k+=1 + 
times.append(t) + elif sx >= image_size-32: + sx = image_size-32-1 + dx = np.random.randint(-4, 0) + dy = np.random.randint(-4, 5) + k+=1 + times.append(t) + + x[t,0, sy:sy+32, sx:sx+32] += digit.numpy().squeeze() + sy += dy + sx += dx + + x[x>1] = 1. + return x + + + def __getitem__(self, index): + self.set_seed(index) + image_size = self.image_size + digit_size = self.digit_size + x = np.zeros((self.seq_len, + image_size, + image_size, + self.channels), + dtype=np.float32) + + all_times = [] + for n in range(self.num_digits): + idx = np.random.randint(self.N) + digit, _ = self.data[idx] + + sx = self.sx + sy = self.sy + dx = self.dx[0] + dy = self.dy[0] + k = 1 + times = [] + for t in range(self.seq_len): + if sy < 0: + sy = 0 + dy = self.pos_dy[k] #np.random.randint(1, 4) + dx = self.dx[k] #np.random.randint(-4, 4) + k+=1 + times.append(t-1) + elif sy >= image_size-32: + sy = image_size-32-1 + dy = self.neg_dy[k+2*n] #np.random.randint(-4, -1) + dx = self.dx[k+2*n] #np.random.randint(-4, 4) + k+=1 + times.append(t-1) + + if sx < 0: + sx = 0 + dx = self.pos_dx[k+2*n] #np.random.randint(1, 4) + dy = self.dy[k+2*n] #np.random.randint(-4, 4) + k+=1 + times.append(t-1) + elif sx >= image_size-32: + sx = image_size-32-1 + dx = self.neg_dx[k+2*n] #np.random.randint(-4, -1) + dy = self.dy[k+2*n] #np.random.randint(-4, 4) + k+=1 + times.append(t-1) + + x[t, sy:sy+32, sx:sx+32, 0] += digit.numpy().squeeze() + sy += dy + sx += dx + all_times.append(times) + + x[x>1] = 1. 
+ + x_sample = [] + for s in range(100): + #x_sample.append(0) + x_sample.append(self.get_sample(digit)) + + return x, digit, np.array(all_times[0]), np.array(all_times[0]), np.array(x_sample) diff --git a/models/dcgan_128.py b/models/dcgan_128.py new file mode 100644 index 0000000..fe3bdce --- /dev/null +++ b/models/dcgan_128.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn + +class dcgan_conv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_conv, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, input): + return self.main(input) + +class dcgan_upconv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_upconv, self).__init__() + self.main = nn.Sequential( + nn.ConvTranspose2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, input): + return self.main(input) + +class encoder(nn.Module): + def __init__(self, dim, nc=1): + super(encoder, self).__init__() + self.dim = dim + nf = 64 + # input is (nc) x 128 x 128 + self.c1 = dcgan_conv(nc, nf) + # state size. (nf) x 64 x 64 + self.c2 = dcgan_conv(nf, nf * 2) + # state size. (nf*2) x 32 x 32 + self.c3 = dcgan_conv(nf * 2, nf * 4) + # state size. (nf*4) x 16 x 16 + self.c4 = dcgan_conv(nf * 4, nf * 8) + # state size. (nf*8) x 8 x 8 + self.c5 = dcgan_conv(nf * 8, nf * 8) + # state size. 
(nf*8) x 4 x 4 + self.c6 = nn.Sequential( + nn.Conv2d(nf * 8, dim, 4, 1, 0), + nn.BatchNorm2d(dim), + nn.Tanh() + ) + + def forward(self, input): + h1 = self.c1(input) + h2 = self.c2(h1) + h3 = self.c3(h2) + h4 = self.c4(h3) + h5 = self.c5(h4) + h6 = self.c6(h5) + return h6.view(-1, self.dim), [h1, h2, h3, h4, h5] + + +class decoder(nn.Module): + def __init__(self, dim, nc=1): + super(decoder, self).__init__() + self.dim = dim + nf = 64 + self.upc1 = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0), + nn.BatchNorm2d(nf * 8), + nn.LeakyReLU(0.2, inplace=True) + ) + # state size. (nf*8) x 4 x 4 + self.upc2 = dcgan_upconv(nf * 8 * 2, nf * 8) + # state size. (nf*8) x 8 x 8 + self.upc3 = dcgan_upconv(nf * 8 * 2, nf * 4) + # state size. (nf*4) x 16 x 16 + self.upc4 = dcgan_upconv(nf * 4 * 2, nf * 2) + # state size. (nf*2) x 32 x 32 + self.upc5 = dcgan_upconv(nf * 2 * 2, nf) + # state size. (nf) x 64 x 64 + self.upc6 = nn.Sequential( + nn.ConvTranspose2d(nf * 2, nc, 4, 2, 1), + nn.Sigmoid() + # state size. 
(nc) x 128 x 128 + ) + + def forward(self, input): + vec, skip = input + d1 = self.upc1(vec.view(-1, self.dim, 1, 1)) + d2 = self.upc2(torch.cat([d1, skip[4]], 1)) + d3 = self.upc3(torch.cat([d2, skip[3]], 1)) + d4 = self.upc4(torch.cat([d3, skip[2]], 1)) + d5 = self.upc5(torch.cat([d4, skip[1]], 1)) + output = self.upc6(torch.cat([d5, skip[0]], 1)) + return output + diff --git a/models/dcgan_64.py b/models/dcgan_64.py new file mode 100644 index 0000000..c67dade --- /dev/null +++ b/models/dcgan_64.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn + +class dcgan_conv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_conv, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, input): + return self.main(input) + +class dcgan_upconv(nn.Module): + def __init__(self, nin, nout): + super(dcgan_upconv, self).__init__() + self.main = nn.Sequential( + nn.ConvTranspose2d(nin, nout, 4, 2, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, input): + return self.main(input) + +class encoder(nn.Module): + def __init__(self, dim, nc=1): + super(encoder, self).__init__() + self.dim = dim + nf = 64 + # input is (nc) x 64 x 64 + self.c1 = dcgan_conv(nc, nf) + # state size. (nf) x 32 x 32 + self.c2 = dcgan_conv(nf, nf * 2) + # state size. (nf*2) x 16 x 16 + self.c3 = dcgan_conv(nf * 2, nf * 4) + # state size. (nf*4) x 8 x 8 + self.c4 = dcgan_conv(nf * 4, nf * 8) + # state size. 
(nf*8) x 4 x 4 + self.c5 = nn.Sequential( + nn.Conv2d(nf * 8, dim, 4, 1, 0), + nn.BatchNorm2d(dim), + nn.Tanh() + ) + + def forward(self, input): + h1 = self.c1(input) + h2 = self.c2(h1) + h3 = self.c3(h2) + h4 = self.c4(h3) + h5 = self.c5(h4) + return h5.view(-1, self.dim), [h1, h2, h3, h4] + + +class decoder(nn.Module): + def __init__(self, dim, nc=1): + super(decoder, self).__init__() + self.dim = dim + nf = 64 + self.upc1 = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0), + nn.BatchNorm2d(nf * 8), + nn.LeakyReLU(0.2, inplace=True) + ) + # state size. (nf*8) x 4 x 4 + self.upc2 = dcgan_upconv(nf * 8 * 2, nf * 4) + # state size. (nf*4) x 8 x 8 + self.upc3 = dcgan_upconv(nf * 4 * 2, nf * 2) + # state size. (nf*2) x 16 x 16 + self.upc4 = dcgan_upconv(nf * 2 * 2, nf) + # state size. (nf) x 32 x 32 + self.upc5 = nn.Sequential( + nn.ConvTranspose2d(nf * 2, nc, 4, 2, 1), + nn.Sigmoid() + # state size. (nc) x 64 x 64 + ) + + def forward(self, input): + vec, skip = input + d1 = self.upc1(vec.view(-1, self.dim, 1, 1)) + d2 = self.upc2(torch.cat([d1, skip[3]], 1)) + d3 = self.upc3(torch.cat([d2, skip[2]], 1)) + d4 = self.upc4(torch.cat([d3, skip[1]], 1)) + output = self.upc5(torch.cat([d4, skip[0]], 1)) + return output + diff --git a/models/lstm.py b/models/lstm.py new file mode 100644 index 0000000..9ea0d80 --- /dev/null +++ b/models/lstm.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable + +class lstm(nn.Module): + def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size): + super(lstm, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.batch_size = batch_size + self.n_layers = n_layers + self.embed = nn.Linear(input_size, hidden_size) + self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for i in range(self.n_layers)]) + self.output = nn.Sequential( + nn.Linear(hidden_size, 
output_size), + #nn.BatchNorm1d(output_size), + nn.Tanh()) + self.hidden = self.init_hidden() + + def init_hidden(self): + hidden = [] + for i in range(self.n_layers): + hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()), + Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()))) + return hidden + + def forward(self, input): + embedded = self.embed(input.view(-1, self.input_size)) + h_in = embedded + for i in range(self.n_layers): + self.hidden[i] = self.lstm[i](h_in, self.hidden[i]) + h_in = self.hidden[i][0] + + return self.output(h_in) + +class gaussian_lstm(nn.Module): + def __init__(self, input_size, output_size, hidden_size, batch_size): + super(gaussian_lstm, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.batch_size = batch_size + self.embed = nn.Linear(input_size, hidden_size) + self.lstm = nn.LSTMCell(hidden_size, hidden_size) + self.mu_net = nn.Linear(hidden_size, output_size) + self.logvar_net = nn.Linear(hidden_size, output_size) + self.hidden = self.init_hidden() + + def init_hidden(self): + return (Variable(torch.zeros(self.batch_size, self.hidden_size).cuda()), + Variable(torch.zeros(self.batch_size, self.hidden_size).cuda())) + + + + def reparameterize(self, mu, logvar): + if self.training: + logvar = logvar.mul(0.5).exp_() + eps = Variable(logvar.data.new(logvar.size()).normal_()) + return eps.mul(logvar).add_(mu) + else: + return mu + + def forward(self, input): + embedded = self.embed(input.view(-1, self.input_size)) + self.hidden = self.lstm(embedded, self.hidden) + mu = self.mu_net(self.hidden[0]) + logvar = self.logvar_net(self.hidden[0]) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + diff --git a/models/vgg_128.py b/models/vgg_128.py new file mode 100644 index 0000000..0c5b617 --- /dev/null +++ b/models/vgg_128.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn + +class vgg_layer(nn.Module): + def 
__init__(self, nin, nout): + super(vgg_layer, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 3, 1, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True) + ) + + def forward(self, input): + return self.main(input) + +class encoder(nn.Module): + def __init__(self, dim, nc=1): + super(encoder, self).__init__() + self.dim = dim + # 128 x 128 + self.c1 = nn.Sequential( + vgg_layer(nc, 64), + vgg_layer(64, 64), + ) + # 64 x 64 + self.c2 = nn.Sequential( + vgg_layer(64, 128), + vgg_layer(128, 128), + ) + # 32 x 32 + self.c3 = nn.Sequential( + vgg_layer(128, 256), + vgg_layer(256, 256), + vgg_layer(256, 256), + ) + # 16 x 16 + self.c4 = nn.Sequential( + vgg_layer(256, 512), + vgg_layer(512, 512), + vgg_layer(512, 512), + ) + # 8 x 8 + self.c5 = nn.Sequential( + vgg_layer(512, 512), + vgg_layer(512, 512), + vgg_layer(512, 512), + ) + # 4 x 4 + self.c6 = nn.Sequential( + nn.Conv2d(512, dim, 4, 1, 0), + nn.BatchNorm2d(dim), + nn.Tanh() + ) + self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + + def forward(self, input): + h1 = self.c1(input) # 128 -> 64 + h2 = self.c2(self.mp(h1)) # 64 -> 32 + h3 = self.c3(self.mp(h2)) # 32 -> 16 + h4 = self.c4(self.mp(h3)) # 16 -> 8 + h5 = self.c5(self.mp(h4)) # 8 -> 4 + h6 = self.c6(self.mp(h5)) # 4 -> 1 + return h6.view(-1, self.dim), [h1, h2, h3, h4, h5] + + +class decoder(nn.Module): + def __init__(self, dim, nc=1): + super(decoder, self).__init__() + self.dim = dim + # 1 x 1 -> 4 x 4 + self.upc1 = nn.Sequential( + nn.ConvTranspose2d(dim, 512, 4, 1, 0), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True) + ) + # 8 x 8 + self.upc2 = nn.Sequential( + vgg_layer(512*2, 512), + vgg_layer(512, 512), + vgg_layer(512, 512) + ) + # 16 x 16 + self.upc3 = nn.Sequential( + vgg_layer(512*2, 512), + vgg_layer(512, 512), + vgg_layer(512, 256) + ) + # 32 x 32 + self.upc4 = nn.Sequential( + vgg_layer(256*2, 256), + vgg_layer(256, 256), + vgg_layer(256, 128) + ) + # 64 x 64 + self.upc5 = nn.Sequential( + 
vgg_layer(128*2, 128), + vgg_layer(128, 64) + ) + # 128 x 128 + self.upc6 = nn.Sequential( + vgg_layer(64*2, 64), + nn.ConvTranspose2d(64, nc, 3, 1, 1), + nn.Sigmoid() + ) + self.up = nn.UpsamplingNearest2d(scale_factor=2) + + def forward(self, input): + vec, skip = input + d1 = self.upc1(vec.view(-1, self.dim, 1, 1)) # 1 -> 4 + up1 = self.up(d1) # 4 -> 8 + d2 = self.upc2(torch.cat([up1, skip[4]], 1)) # 8 x 8 + up2 = self.up(d2) # 8 -> 16 + d3 = self.upc3(torch.cat([up2, skip[3]], 1)) # 16 x 16 + up3 = self.up(d3) # 16 -> 32 + d4 = self.upc4(torch.cat([up3, skip[2]], 1)) # 32 x 32 + up4 = self.up(d4) # 32 -> 64 + d5 = self.upc5(torch.cat([up4, skip[1]], 1)) # 64 x 64 + up5 = self.up(d5) # 64 -> 128 + output = self.upc6(torch.cat([up5, skip[0]], 1)) # 128 x 128 + return output + diff --git a/models/vgg_64.py b/models/vgg_64.py new file mode 100644 index 0000000..bd7d9f2 --- /dev/null +++ b/models/vgg_64.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +class vgg_layer(nn.Module): + def __init__(self, nin, nout): + super(vgg_layer, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nin, nout, 3, 1, 1), + nn.BatchNorm2d(nout), + nn.LeakyReLU(0.2, inplace=True) + ) + + def forward(self, input): + return self.main(input) + +class encoder(nn.Module): + def __init__(self, dim, nc=1): + super(encoder, self).__init__() + self.dim = dim + # 64 x 64 + self.c1 = nn.Sequential( + vgg_layer(nc, 64), + vgg_layer(64, 64), + ) + # 32 x 32 + self.c2 = nn.Sequential( + vgg_layer(64, 128), + vgg_layer(128, 128), + ) + # 16 x 16 + self.c3 = nn.Sequential( + vgg_layer(128, 256), + vgg_layer(256, 256), + vgg_layer(256, 256), + ) + # 8 x 8 + self.c4 = nn.Sequential( + vgg_layer(256, 512), + vgg_layer(512, 512), + vgg_layer(512, 512), + ) + # 4 x 4 + self.c5 = nn.Sequential( + nn.Conv2d(512, dim, 4, 1, 0), + nn.BatchNorm2d(dim), + nn.Tanh() + ) + self.mp = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + + def forward(self, input): + h1 = self.c1(input) # 64 -> 32 + h2 
= self.c2(self.mp(h1)) # 32 -> 16 + h3 = self.c3(self.mp(h2)) # 16 -> 8 + h4 = self.c4(self.mp(h3)) # 8 -> 4 + h5 = self.c5(self.mp(h4)) # 4 -> 1 + return h5.view(-1, self.dim), [h1, h2, h3, h4] + + +class decoder(nn.Module): + def __init__(self, dim, nc=1): + super(decoder, self).__init__() + self.dim = dim + # 1 x 1 -> 4 x 4 + self.upc1 = nn.Sequential( + nn.ConvTranspose2d(dim, 512, 4, 1, 0), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True) + ) + # 8 x 8 + self.upc2 = nn.Sequential( + vgg_layer(512*2, 512), + vgg_layer(512, 512), + vgg_layer(512, 256) + ) + # 16 x 16 + self.upc3 = nn.Sequential( + vgg_layer(256*2, 256), + vgg_layer(256, 256), + vgg_layer(256, 128) + ) + # 32 x 32 + self.upc4 = nn.Sequential( + vgg_layer(128*2, 128), + vgg_layer(128, 64) + ) + # 64 x 64 + self.upc5 = nn.Sequential( + vgg_layer(64*2, 64), + nn.ConvTranspose2d(64, nc, 3, 1, 1), + nn.Sigmoid() + ) + self.up = nn.UpsamplingNearest2d(scale_factor=2) + + def forward(self, input): + vec, skip = input + d1 = self.upc1(vec.view(-1, self.dim, 1, 1)) # 1 -> 4 + up1 = self.up(d1) # 4 -> 8 + d2 = self.upc2(torch.cat([up1, skip[3]], 1)) # 8 x 8 + up2 = self.up(d2) # 8 -> 16 + d3 = self.upc3(torch.cat([up2, skip[2]], 1)) # 16 x 16 + up3 = self.up(d3) # 16 -> 32 + d4 = self.upc4(torch.cat([up3, skip[1]], 1)) # 32 x 32 + up4 = self.up(d4) # 32 -> 64 + output = self.upc5(torch.cat([up4, skip[0]], 1)) # 64 x 64 + return output + diff --git a/train_svg_fp.py b/train_svg_fp.py new file mode 100644 index 0000000..b6c1106 --- /dev/null +++ b/train_svg_fp.py @@ -0,0 +1,356 @@ +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, 
type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=100, type=int, help='batch size') +parser.add_argument('--log_dir', default='/misc/vlgscratch4/FergusGroup/denton/svg_logs/svg_fp/', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=600, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=3, type=int) +parser.add_argument('--dataset', default='moving_dot', help='dataset to train with') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=5, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict at eval time') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--rnn_layers', type=int, default=2, help='number of layers') +parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=64, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan | vgg', help='model type') +parser.add_argument('--data_threads', type=int, default=5, 
help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') + + +opt = parser.parse_args() +if opt.model_dir != '': + saved_model = torch.load('%s/model.pth' % opt.model_dir) + optimizer = opt.optimizer + model_dir = opt.model_dir + opt = saved_model['opt'] + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + name = 'model=%s%dx%d-rnn_size=%d-rnn_layers=%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%d-beta=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +opt.max_step = opt.n_past+opt.n_future + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + posterior = saved_model['posterior'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim+opt.z_dim, 
opt.g_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size) + posterior = lstm_models.gaussian_lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions ------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu, logvar): + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + KLD /= opt.batch_size + return KLD + + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) +test_loader 
= DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting functions ------------------------------------ +def plot(x, epoch): + nsample = 20 + gen_seq = [[] for i in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + z_t = torch.cuda.FloatTensor(opt.batch_size, opt.z_dim).normal_() + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(x_in) + else: + h = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 10) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + # best sequence + min_mse = 1e7 + for s in range(nsample): + mse = 0 + for t in range(opt.n_eval): + mse += torch.sum( (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2 ) + if mse < min_mse: + min_mse = mse + min_idx = s + + s_list = [min_idx, + np.random.randint(nsample), + np.random.randint(nsample), + np.random.randint(nsample), + np.random.randint(nsample)] + for ss in range(len(s_list)): + s = s_list[ss] + row = [] + for t in range(opt.n_eval): + 
row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for ss in range(len(s_list)): + s = s_list[ss] + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + for i in range(1, opt.n_past+opt.n_future): + h = encoder(x[i-1]) + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h_target, _ = h_target + h = h.detach() + h_target = h_target.detach() + z_t, mu, logvar = posterior(h_target) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 10) + for i in range(nrow): + row = [] + for t in range(opt.n_past+opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training functions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. 
+ frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + + mse = 0 + kld = 0 + for i in range(1, opt.n_past+opt.n_future): + h = encoder(x[i-1]) + h_target = encoder(x[i])[0] + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + z_t, mu, logvar = posterior(h_target) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + mse += mse_criterion(x_pred, x[i]) + kld += kl_criterion(mu, logvar) + + loss = mse + kld*opt.beta + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + return mse.data.cpu().numpy()/(opt.n_past+opt.n_future), kld.data.cpu().numpy()/(opt.n_future+opt.n_past) + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_kld = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, kld = train(x) + epoch_mse += mse + epoch_kld += kld + + + progress.finish() + utils.clear_progressbar() + + print('[%02d] mse loss: %.5f | kld loss: %.5f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_kld/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + encoder.eval() + decoder.eval() + posterior.eval() + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'opt': opt}, + '%s/model.pth' % opt.log_dir) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/train_svg_lp.py b/train_svg_lp.py new file mode 100644 index 0000000..9c71408 --- /dev/null +++ 
b/train_svg_lp.py @@ -0,0 +1,379 @@ +import torch +import torch.optim as optim +import torch.nn as nn +import argparse +import os +import random +from torch.autograd import Variable +from torch.utils.data import DataLoader +import utils +import itertools +import progressbar +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', default=0.002, type=float, help='learning rate') +parser.add_argument('--beta1', default=0.9, type=float, help='momentum term for adam') +parser.add_argument('--batch_size', default=100, type=int, help='batch size') +parser.add_argument('--log_dir', default='/misc/vlgscratch4/FergusGroup/denton/svg_logs/svg_lp/', help='base directory to save logs') +parser.add_argument('--model_dir', default='', help='base directory to save logs') +parser.add_argument('--name', default='', help='identifier for directory') +parser.add_argument('--data_root', default='', help='root directory for data') +parser.add_argument('--optimizer', default='adam', help='optimizer to train with') +parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for') +parser.add_argument('--seed', default=1, type=int, help='manual seed') +parser.add_argument('--epoch_size', type=int, default=600, help='epoch size') +parser.add_argument('--image_width', type=int, default=64, help='the height / width of the input image to network') +parser.add_argument('--channels', default=3, type=int) +parser.add_argument('--dataset', default='moving_dot', help='dataset to train with') +parser.add_argument('--n_past', type=int, default=5, help='number of frames to condition on') +parser.add_argument('--n_future', type=int, default=5, help='number of frames to predict') +parser.add_argument('--n_eval', type=int, default=30, help='number of frames to predict') +parser.add_argument('--rnn_size', type=int, default=256, help='dimensionality of hidden layer') +parser.add_argument('--rnn_layers', type=int, default=2, help='number of layers') 
+parser.add_argument('--z_dim', type=int, default=10, help='dimensionality of z_t') +parser.add_argument('--g_dim', type=int, default=64, help='dimensionality of encoder output vector and decoder input vector') +parser.add_argument('--beta', type=float, default=0.0001, help='weighting on KL to prior') +parser.add_argument('--model', default='dcgan | vgg', help='model type') +parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads') +parser.add_argument('--num_digits', type=int, default=2, help='number of digits for moving mnist') +parser.add_argument('--last_frame_skip', action='store_true', help='if true, skip connections go between frame t and frame t+t rather than last ground truth frame') + + + +opt = parser.parse_args() +if opt.model_dir != '': + saved_model = torch.load('%s/model.pth' % opt.model_dir) + optimizer = opt.optimizer + model_dir = opt.model_dir + opt = saved_model['opt'] + opt.optimizer = optimizer + opt.model_dir = model_dir + opt.log_dir = '%s/continued' % opt.log_dir +else: + name = 'model=%s%dx%d-rnn_size=%d-rnn_layers=%d-n_past=%d-n_future=%d-lr=%.4f-g_dim=%d-z_dim=%d-last_frame_skip=%s-beta=%.7f%s' % (opt.model, opt.image_width, opt.image_width, opt.rnn_size, opt.rnn_layers, opt.n_past, opt.n_future, opt.lr, opt.g_dim, opt.z_dim, opt.last_frame_skip, opt.beta, opt.name) + if opt.dataset == 'smmnist': + opt.log_dir = '%s/%s-%d/%s' % (opt.log_dir, opt.dataset, opt.num_digits, name) + else: + opt.log_dir = '%s/%s/%s' % (opt.log_dir, opt.dataset, name) + +os.makedirs('%s/gen/' % opt.log_dir, exist_ok=True) +os.makedirs('%s/plots/' % opt.log_dir, exist_ok=True) +opt.max_step = opt.n_past+opt.n_future + +print("Random Seed: ", opt.seed) +random.seed(opt.seed) +torch.manual_seed(opt.seed) +torch.cuda.manual_seed_all(opt.seed) +dtype = torch.cuda.FloatTensor + + +# ---------------- load the models ---------------- + +print(opt) + +# ---------------- optimizers ---------------- +if opt.optimizer == 'adam': + 
opt.optimizer = optim.Adam +elif opt.optimizer == 'rmsprop': + opt.optimizer = optim.RMSprop +elif opt.optimizer == 'sgd': + opt.optimizer = optim.SGD +else: + raise ValueError('Unknown optimizer: %s' % opt.optimizer) + + +import models.lstm as lstm_models +if opt.model_dir != '': + frame_predictor = saved_model['frame_predictor'] + posterior = saved_model['posterior'] + prior = saved_model['prior'] +else: + frame_predictor = lstm_models.lstm(opt.g_dim+opt.z_dim, opt.g_dim, opt.rnn_size, opt.rnn_layers, opt.batch_size) + posterior = lstm_models.gaussian_lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.batch_size) + prior = lstm_models.gaussian_lstm(opt.g_dim, opt.z_dim, opt.rnn_size, opt.batch_size) + frame_predictor.apply(utils.init_weights) + posterior.apply(utils.init_weights) + prior.apply(utils.init_weights) + +if opt.model == 'dcgan': + if opt.image_width == 64: + import models.dcgan_64 as model + elif opt.image_width == 128: + import models.dcgan_128 as model +elif opt.model == 'vgg': + if opt.image_width == 64: + import models.vgg_64 as model + elif opt.image_width == 128: + import models.vgg_128 as model +else: + raise ValueError('Unknown model: %s' % opt.model) + +if opt.model_dir != '': + decoder = saved_model['decoder'] + encoder = saved_model['encoder'] +else: + encoder = model.encoder(opt.g_dim, opt.channels) + decoder = model.decoder(opt.g_dim, opt.channels) + encoder.apply(utils.init_weights) + decoder.apply(utils.init_weights) + +frame_predictor_optimizer = opt.optimizer(frame_predictor.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +posterior_optimizer = opt.optimizer(posterior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +prior_optimizer = opt.optimizer(prior.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +encoder_optimizer = opt.optimizer(encoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) +decoder_optimizer = opt.optimizer(decoder.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + +# --------- loss functions 
------------------------------------ +mse_criterion = nn.MSELoss() +def kl_criterion(mu1, logvar1, mu2, logvar2): + # KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2)) = + # log( sqrt( + # + sigma1 = logvar1.mul(0.5).exp() + sigma2 = logvar2.mul(0.5).exp() + kld = torch.log(sigma2/sigma1) + (torch.exp(logvar1) + (mu1 - mu2)**2)/(2*torch.exp(logvar2)) - 1/2 + return kld.sum() / opt.batch_size + +# --------- transfer to gpu ------------------------------------ +frame_predictor.cuda() +posterior.cuda() +prior.cuda() +encoder.cuda() +decoder.cuda() +mse_criterion.cuda() + +# --------- load a dataset ------------------------------------ +train_data, test_data = utils.load_dataset(opt) + +train_loader = DataLoader(train_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) +test_loader = DataLoader(test_data, + num_workers=opt.data_threads, + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + pin_memory=True) + +def get_training_batch(): + while True: + for sequence in train_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +training_batch_generator = get_training_batch() + +def get_testing_batch(): + while True: + for sequence in test_loader: + batch = utils.normalize_data(opt, dtype, sequence) + yield batch +testing_batch_generator = get_testing_batch() + +# --------- plotting funtions ------------------------------------ +def plot(x, epoch): + nsample = 20 + gen_seq = [[] for i in range(nsample)] + gt_seq = [x[i] for i in range(len(x))] + + for s in range(nsample): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + gen_seq[s].append(x[0]) + x_in = x[0] + for i in range(1, opt.n_eval): + h = encoder(x_in) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h = h.detach() + if i < opt.n_past: + h_target = encoder(x[i]) + h_target = h_target[0].detach() + z_t, 
_, _ = posterior(h_target) + frame_predictor(torch.cat([h, z_t], 1)) + x_in = x[i] + gen_seq[s].append(x_in) + else: + z_t, _, _ = prior(h) + h = frame_predictor(torch.cat([h, z_t], 1)).detach() + x_in = decoder([h, skip]).detach() + gen_seq[s].append(x_in) + + to_plot = [] + gifs = [ [] for t in range(opt.n_eval) ] + nrow = min(opt.batch_size, 10) + for i in range(nrow): + # ground truth sequence + row = [] + for t in range(opt.n_eval): + row.append(gt_seq[t][i]) + to_plot.append(row) + + # best sequence + min_mse = 1e7 + for s in range(nsample): + mse = 0 + for t in range(opt.n_eval): + mse += torch.sum( (gt_seq[t][i].data.cpu() - gen_seq[s][t][i].data.cpu())**2 ) + if mse < min_mse: + min_mse = mse + min_idx = s + + s_list = [min_idx, + np.random.randint(nsample), + np.random.randint(nsample), + np.random.randint(nsample), + np.random.randint(nsample)] + for ss in range(len(s_list)): + s = s_list[ss] + row = [] + for t in range(opt.n_eval): + row.append(gen_seq[s][t][i]) + to_plot.append(row) + for t in range(opt.n_eval): + row = [] + row.append(gt_seq[t][i]) + for ss in range(len(s_list)): + s = s_list[ss] + row.append(gen_seq[s][t][i]) + gifs[t].append(row) + + fname = '%s/gen/sample_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + fname = '%s/gen/sample_%d.gif' % (opt.log_dir, epoch) + utils.save_gif(fname, gifs) + + +def plot_rec(x, epoch): + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + gen_seq = [] + gen_seq.append(x[0]) + x_in = x[0] + for i in range(1, opt.n_past+opt.n_future): + h = encoder(x[i-1]) + h_target = encoder(x[i]) + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h, _ = h + h_target, _ = h_target + h = h.detach() + h_target = h_target.detach() + z_t, _, _= posterior(h_target) + if i < opt.n_past: + frame_predictor(torch.cat([h, z_t], 1)) + gen_seq.append(x[i]) + else: + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = 
decoder([h_pred, skip]).detach() + gen_seq.append(x_pred) + + to_plot = [] + nrow = min(opt.batch_size, 10) + for i in range(nrow): + row = [] + for t in range(opt.n_past+opt.n_future): + row.append(gen_seq[t][i]) + to_plot.append(row) + fname = '%s/gen/rec_%d.png' % (opt.log_dir, epoch) + utils.save_tensors_image(fname, to_plot) + + +# --------- training funtions ------------------------------------ +def train(x): + frame_predictor.zero_grad() + posterior.zero_grad() + prior.zero_grad() + encoder.zero_grad() + decoder.zero_grad() + + # initialize the hidden state. + frame_predictor.hidden = frame_predictor.init_hidden() + posterior.hidden = posterior.init_hidden() + prior.hidden = prior.init_hidden() + + mse = 0 + kld = 0 + for i in range(1, opt.n_past+opt.n_future): + h = encoder(x[i-1]) + h_target = encoder(x[i])[0] + if opt.last_frame_skip or i < opt.n_past: + h, skip = h + else: + h = h[0] + z_t, mu, logvar = posterior(h_target) + _, mu_p, logvar_p = prior(h) + h_pred = frame_predictor(torch.cat([h, z_t], 1)) + x_pred = decoder([h_pred, skip]) + mse += mse_criterion(x_pred, x[i]) + kld += kl_criterion(mu, logvar, mu_p, logvar_p) + + loss = mse + kld*opt.beta + loss.backward() + + frame_predictor_optimizer.step() + posterior_optimizer.step() + prior_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() + + + return mse.data.cpu().numpy()/(opt.n_past+opt.n_future), kld.data.cpu().numpy()/(opt.n_future+opt.n_past) + +# --------- training loop ------------------------------------ +for epoch in range(opt.niter): + frame_predictor.train() + posterior.train() + prior.train() + encoder.train() + decoder.train() + epoch_mse = 0 + epoch_kld = 0 + progress = progressbar.ProgressBar(max_value=opt.epoch_size).start() + for i in range(opt.epoch_size): + progress.update(i+1) + x = next(training_batch_generator) + + # train frame_predictor + mse, kld = train(x) + epoch_mse += mse + epoch_kld += kld + + + progress.finish() + utils.clear_progressbar() + + 
print('[%02d] mse loss: %.5f | kld loss: %.5f (%d)' % (epoch, epoch_mse/opt.epoch_size, epoch_kld/opt.epoch_size, epoch*opt.epoch_size*opt.batch_size)) + + # plot some stuff + frame_predictor.eval() + encoder.eval() + decoder.eval() + posterior.eval() + prior.eval() + + x = next(testing_batch_generator) + plot(x, epoch) + plot_rec(x, epoch) + + # save the model + torch.save({ + 'encoder': encoder, + 'decoder': decoder, + 'frame_predictor': frame_predictor, + 'posterior': posterior, + 'prior': prior, + 'opt': opt}, + '%s/model.pth' % opt.log_dir) + if epoch % 10 == 0: + print('log dir: %s' % opt.log_dir) + + diff --git a/utils.py b/utils.py new file mode 100755 index 0000000..830e234 --- /dev/null +++ b/utils.py @@ -0,0 +1,286 @@ +import math +import torch +import socket +import argparse +import os +import numpy as np +from sklearn.manifold import TSNE +import scipy.misc +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +import functools +from skimage.measure import compare_psnr as psnr_metric +from skimage.measure import compare_ssim as ssim_metric +from scipy import signal +from scipy import ndimage +from PIL import Image, ImageDraw + + +from torchvision import datasets, transforms +from torch.autograd import Variable +import imageio + + +hostname = socket.gethostname() + +def load_dataset(opt): + if opt.dataset == 'smmnist': + from data.moving_mnist import MovingMNIST + train_data = MovingMNIST( + train=True, + data_root='', + seq_len=opt.max_step, + image_size=opt.image_width, + deterministic=False, + num_digits=opt.num_digits) + test_data = MovingMNIST( + train=False, + data_root='', + seq_len=opt.n_eval, + image_size=opt.image_width, + deterministic=False, + num_digits=opt.num_digits) + elif opt.dataset == 'bair': + from data.bair import RobotPush + train_data = RobotPush( + train=True, + seq_len=opt.max_step, + image_size=opt.image_width) + test_data = RobotPush( + train=False, + seq_len=opt.n_eval, + image_size=opt.image_width) + + 
return train_data, test_data + +def sequence_input(seq, dtype): + return [Variable(x.type(dtype)) for x in seq] + +def normalize_data(opt, dtype, sequence): + if opt.dataset == 'smmnist' or opt.dataset == 'kth' or opt.dataset == 'bair' : + sequence.transpose_(0, 1) + sequence.transpose_(3, 4).transpose_(2, 3) + else: + sequence.transpose_(0, 1) + + return sequence_input(sequence, dtype) + +def is_sequence(arg): + return (not hasattr(arg, "strip") and + not type(arg) is np.ndarray and + not hasattr(arg, "dot") and + (hasattr(arg, "__getitem__") or + hasattr(arg, "__iter__"))) + +def image_tensor(inputs, padding=1): + # assert is_sequence(inputs) + assert len(inputs) > 0 + # print(inputs) + + # if this is a list of lists, unpack them all and grid them up + if is_sequence(inputs[0]) or (hasattr(inputs, "dim") and inputs.dim() > 4): + images = [image_tensor(x) for x in inputs] + if images[0].dim() == 3: + c_dim = images[0].size(0) + x_dim = images[0].size(1) + y_dim = images[0].size(2) + else: + c_dim = 1 + x_dim = images[0].size(0) + y_dim = images[0].size(1) + + result = torch.ones(c_dim, + x_dim * len(images) + padding * (len(images)-1), + y_dim) + for i, image in enumerate(images): + result[:, i * x_dim + i * padding : + (i+1) * x_dim + i * padding, :].copy_(image) + + return result + + # if this is just a list, make a stacked image + else: + images = [x.data if isinstance(x, torch.autograd.Variable) else x + for x in inputs] + # print(images) + if images[0].dim() == 3: + c_dim = images[0].size(0) + x_dim = images[0].size(1) + y_dim = images[0].size(2) + else: + c_dim = 1 + x_dim = images[0].size(0) + y_dim = images[0].size(1) + + result = torch.ones(c_dim, + x_dim, + y_dim * len(images) + padding * (len(images)-1)) + for i, image in enumerate(images): + result[:, :, i * y_dim + i * padding : + (i+1) * y_dim + i * padding].copy_(image) + return result + +def save_np_img(fname, x): + if x.shape[0] == 1: + x = np.tile(x, (3, 1, 1)) + img = scipy.misc.toimage(x, + 
high=255*x.max(), + channel_axis=0) + img.save(fname) + +def make_image(tensor): + tensor = tensor.cpu().clamp(0, 1) + if tensor.size(0) == 1: + tensor = tensor.expand(3, tensor.size(1), tensor.size(2)) + # pdb.set_trace() + return scipy.misc.toimage(tensor.numpy(), + high=255*tensor.max(), + channel_axis=0) + +def draw_text_tensor(tensor, text): + np_x = tensor.transpose(0, 1).transpose(1, 2).data.cpu().numpy() + pil = Image.fromarray(np.uint8(np_x*255)) + draw = ImageDraw.Draw(pil) + draw.text((4, 64), text, (0,0,0)) + img = np.asarray(pil) + return Variable(torch.Tensor(img / 255.)).transpose(1, 2).transpose(0, 1) + +def save_gif(filename, inputs, duration=0.25): + images = [] + for tensor in inputs: + img = image_tensor(tensor, padding=0) + img = img.cpu() + img = img.transpose(0,1).transpose(1,2).clamp(0,1) + images.append(img.numpy()) + imageio.mimsave(filename, images, duration=duration) + +def save_gif_with_text(filename, inputs, text, duration=0.25): + images = [] + for tensor, text in zip(inputs, text): + img = image_tensor([draw_text_tensor(ti, texti) for ti, texti in zip(tensor, text)], padding=0) + img = img.cpu() + img = img.transpose(0,1).transpose(1,2).clamp(0,1).numpy() + images.append(img) + imageio.mimsave(filename, images, duration=duration) + +def save_image(filename, tensor): + img = make_image(tensor) + img.save(filename) + +def save_tensors_image(filename, inputs, padding=1): + images = image_tensor(inputs, padding) + return save_image(filename, images) + +def prod(l): + return functools.reduce(lambda x, y: x * y, l) + +def batch_flatten(x): + return x.resize(x.size(0), prod(x.size()[1:])) + +def clear_progressbar(): + # moves up 3 lines + print("\033[2A") + # deletes the whole line, regardless of character position + print("\033[2K") + # moves up two lines again + print("\033[2A") + +def mse_metric(x1, x2): + err = np.sum((x1 - x2) ** 2) + err /= float(x1.shape[0] * x1.shape[1] * x1.shape[2]) + return err + +def eval_seq(gt, pred): + T = 
len(gt) + bs = gt[0].shape[0] + ssim = np.zeros((bs, T)) + psnr = np.zeros((bs, T)) + mse = np.zeros((bs, T)) + for i in range(bs): + for t in range(T): + for c in range(gt[t][i].shape[0]): + ssim[i, t] += ssim_metric(gt[t][i][c], pred[t][i][c]) + psnr[i, t] += psnr_metric(gt[t][i][c], pred[t][i][c]) + ssim[i, t] /= gt[t][i].shape[0] + psnr[i, t] /= gt[t][i].shape[0] + mse[i, t] = mse_metric(gt[t][i], pred[t][i]) + + return mse, ssim, psnr + +# ssim function used in Babaeizadeh et al. (2017), Fin et al. (2016), etc. +def finn_eval_seq(gt, pred): + T = len(gt) + bs = gt[0].shape[0] + ssim = np.zeros((bs, T)) + psnr = np.zeros((bs, T)) + mse = np.zeros((bs, T)) + for i in range(bs): + for t in range(T): + for c in range(gt[t][i].shape[0]): + res = finn_ssim(gt[t][i][c], pred[t][i][c]).mean() + if math.isnan(res): + ssim[i, t] += -1 + else: + ssim[i, t] += res + psnr[i, t] += finn_psnr(gt[t][i][c], pred[t][i][c]) + ssim[i, t] /= gt[t][i].shape[0] + psnr[i, t] /= gt[t][i].shape[0] + mse[i, t] = mse_metric(gt[t][i], pred[t][i]) + + return mse, ssim, psnr + + +def finn_psnr(x, y): + mse = ((x - y)**2).mean() + return 10*np.log(1/mse)/np.log(10) + + +def gaussian2(size, sigma): + A = 1/(2.0*np.pi*sigma**2) + x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] + g = A*np.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) + return g + +def fspecial_gauss(size, sigma): + x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] + g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) + return g/g.sum() + +def finn_ssim(img1, img2, cs_map=False): + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + size = 11 + sigma = 1.5 + window = fspecial_gauss(size, sigma) + K1 = 0.01 + K2 = 0.03 + L = 1 #bitdepth of image + C1 = (K1*L)**2 + C2 = (K2*L)**2 + mu1 = signal.fftconvolve(img1, window, mode='valid') + mu2 = signal.fftconvolve(img2, window, mode='valid') + mu1_sq = mu1*mu1 + mu2_sq = mu2*mu2 + mu1_mu2 = mu1*mu2 + sigma1_sq = 
signal.fftconvolve(img1*img1, window, mode='valid') - mu1_sq + sigma2_sq = signal.fftconvolve(img2*img2, window, mode='valid') - mu2_sq + sigma12 = signal.fftconvolve(img1*img2, window, mode='valid') - mu1_mu2 + if cs_map: + return (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* + (sigma1_sq + sigma2_sq + C2)), + (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2)) + else: + return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)* + (sigma1_sq + sigma2_sq + C2)) + + +def init_weights(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1 or classname.find('Linear') != -1: + m.weight.data.normal_(0.0, 0.02) + m.bias.data.fill_(0) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) +