Commit: Add rebar model.
gjtucker committed Jul 24, 2017
1 parent e88d0cf commit 2012e8d
Showing 9 changed files with 1,798 additions and 0 deletions.
1 change: 1 addition & 0 deletions CODEOWNERS
@@ -19,6 +19,7 @@ object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
pcl_rl/* @ofirnachum
ptn/* @xcyan @arkanath @hellojas @honglaklee
real_nvp/* @laurent-dinh
rebar/* @gjtucker
resnet/* @panyx0718
skip_thoughts/* @cshallue
slim/* @sguada @nathansilberman
122 changes: 122 additions & 0 deletions rebar/README.md
@@ -0,0 +1,122 @@
# REINFORCing Concrete with REBAR
*Implementation of REBAR (and other closely related methods) as described
in "REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models" by
George Tucker, Andriy Mnih, Chris J. Maddison, Dieterich Lawson, Jascha Sohl-Dickstein [https://arxiv.org/abs/1703.07370](https://arxiv.org/abs/1703.07370).*

Learning in models with discrete latent variables is challenging due to high variance gradient estimators. Generally, approaches have relied on control variates to reduce the variance of the REINFORCE estimator. Recent work (Jang et al. 2016; Maddison et al. 2016) has taken a different approach, introducing a continuous relaxation of discrete variables to produce low-variance, but biased, gradient estimates. In this work, we combine the two approaches through a novel control variate that produces low-variance, unbiased gradient estimates. Then, we introduce a novel continuous relaxation and show that the tightness of the relaxation can be adapted online, removing it as a hyperparameter. We show state-of-the-art variance reduction on several benchmark generative modeling tasks, generally leading to faster convergence to a better final log likelihood.
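
As a concrete illustration of the construction (our sketch, not code from this commit), the following toy NumPy example estimates the gradient of E[f(b)] for a single Bernoulli(theta) latent variable. The helper names, the toy objective f, and the hand-derived gradients are assumptions made for this example:

```
# Toy single-Bernoulli REBAR estimator with hand-derived gradients (no autodiff).
# All names and the objective f below are illustrative, not from rebar.py.
import numpy as np

rng = np.random.RandomState(0)
LAM, ETA = 0.5, 1.0  # temperature and control variate scale


def f(x):  # toy objective; accepts a hard b or a relaxed x
  return (x - 0.45) ** 2


def df(x):  # df/dx
  return 2.0 * (x - 0.45)


def sig(z):  # relaxation sigma_lambda(z) = sigmoid(z / lambda)
  return 1.0 / (1.0 + np.exp(-z / LAM))


def rebar_grad(theta):
  u = rng.uniform(1e-9, 1.0)
  v = rng.uniform(1e-9, 1.0)
  # Reparameterized logistic sample z; the hard sample is b = H(z).
  z = np.log(theta / (1.0 - theta)) + np.log(u / (1.0 - u))
  b = float(z >= 0.0)
  # Conditional relaxed sample z_tilde with H(z_tilde) = b, and dz_tilde/dtheta.
  if b == 1.0:
    zt = np.log((v * theta + 1.0 - theta) / ((1.0 - theta) * (1.0 - v)))
    dzt = (v - 1.0) / (v * theta + 1.0 - theta) + 1.0 / (1.0 - theta)
  else:
    zt = np.log(v * theta / (1.0 - v * (1.0 - theta)))
    dzt = 1.0 / theta - v / (1.0 - v * (1.0 - theta))
  s, st = sig(z), sig(zt)
  # Pathwise derivatives of f(sigma(z)) and f(sigma(z_tilde)) w.r.t. theta.
  g_relaxed = df(s) * s * (1.0 - s) / LAM / (theta * (1.0 - theta))
  g_cond = df(st) * st * (1.0 - st) / LAM * dzt
  # Score function: d log p(b | theta) / d theta.
  dlogp = (b - theta) / (theta * (1.0 - theta))
  # REINFORCE with f(sigma(z_tilde)) as a control variate; the two pathwise
  # terms correct the bias, so the estimator is unbiased for any LAM and ETA.
  return (f(b) - ETA * f(st)) * dlogp + ETA * g_relaxed - ETA * g_cond


theta = 0.3
est = np.mean([rebar_grad(theta) for _ in range(200000)])
print('estimate %.4f vs exact %.4f' % (est, f(1.0) - f(0.0)))
```

Averaging many samples recovers the exact gradient f(1) - f(0) for any temperature and scale; the temperature therefore only affects the variance, which is what allows DynamicRebar to adapt it online.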

REBAR applied to multilayer sigmoid belief networks is implemented in rebar.py, and rebar_train.py provides a training/evaluation setup. For comparison, we also implemented the following methods:
* [NVIL](https://arxiv.org/abs/1402.0030)
* [MuProp](https://arxiv.org/abs/1511.05176)
* [Gumbel-Softmax](https://arxiv.org/abs/1611.01144)

The code is not optimized and some computation is repeated for ease of
implementation. We hope that this code will be a useful starting point for future research in this area.

## Quick Start:

Requirements:
* TensorFlow (see tensorflow.org for how to install)
* MNIST dataset
* Omniglot dataset

First, choose URLs to download the datasets from and fill them in at the top
of the download_data.py script, like so:

```
MNIST_URL = 'http://yann.lecun.com/exdb/mnist'
MNIST_BINARIZED_URL = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist'
OMNIGLOT_URL = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'
```

Then run the script to download the data:

```
python download_data.py
```

Then run the training script:

```
python rebar_train.py --hparams="model=SBNDynamicRebar,learning_rate=0.0003,n_layer=2,task=sbn"
```

and you should see something like:

```
Step 2084: [-231.026474 0.3711713 1. 1.06934261 1.07023323
1.02173257 1.02171052 1. 1. 1. 1. ]
-3.6465678215
Step 4168: [-156.86795044 0.3097114 1. 1.03964758 1.03936625
1.02627242 1.02629256 1. 1. 1. 1. ]
-4.42727231979
Step 6252: [-143.4650116 0.26153237 1. 1.03633797 1.03600132
1.02639604 1.02639794 1. 1. 1. 1. ]
-4.85577583313
Step 8336: [-137.65275574 0.22313026 1. 1.03467286 1.03428006
1.02336085 1.02335203 0.99999988 1. 0.99999988
1. ]
-4.95563364029
```

The first number in the list is the log likelihood lower bound and the number
after the list is the log of the variance of the gradient estimator. The rest of
the numbers are for debugging.

We can also compare the variance between methods:

```
python rebar_train.py \
--hparams="model=SBNTrackGradVariances,learning_rate=0.0003,n_layer=2,task=omni"
```

and you should see something like:

```
Step 959: [ -2.60478699e+02 3.84281784e-01 6.31126612e-02 3.27319391e-02
6.13379292e-03 1.98278503e-04 1.96425783e-04 8.83973844e-04
8.70995224e-04 -inf]
('DynamicREBAR', -3.725339889526367)
('MuProp', -0.033569782972335815)
('NVIL', 2.7640280723571777)
('REBAR', -3.539274215698242)
('SimpleMuProp', -0.040744658559560776)
Step 1918: [ -2.06948471e+02 3.35904926e-01 5.20901568e-03 7.81541676e-05
2.06885766e-03 1.08521657e-04 1.07351625e-04 2.30646547e-04
2.26554010e-04 -8.22885323e+00]
('DynamicREBAR', -3.864381790161133)
('MuProp', -0.7183765172958374)
('NVIL', 2.266523599624634)
('REBAR', -3.662022113800049)
('SimpleMuProp', -0.7071359157562256)
```
where the tuples show the log of the variance of the gradient estimators.

The training script has a number of hyperparameter configuration flags (an
example combining several of them follows this list):
* task (sbn): one of {sbn, sp, omni}, which correspond to MNIST generative
  modeling, structured prediction on MNIST, and Omniglot generative modeling,
  respectively
* model (SBNGumbel): one of {SBN, SBNNVIL, SBNMuProp, SBNSimpleMuProp,
  SBNRebar, SBNDynamicRebar, SBNGumbel, SBNTrackGradVariances}. DynamicRebar
  automatically adjusts the temperature, whereas Rebar and Gumbel-Softmax
  require tuning the temperature. The models named after a method use that
  method to estimate the gradients (SBN refers to REINFORCE).
  SBNTrackGradVariances runs multiple methods and follows a single
  optimization trajectory.
* n_hidden (200): number of hidden nodes per layer
* n_layer (1): number of layers in the model
* nonlinear (false): if true, use two tanh layers between each stochastic
  layer; otherwise use a single linear layer
* learning_rate (0.001): learning rate
* temperature (0.5): temperature hyperparameter (for DynamicRebar, this is the initial
value of the temperature)
* n_samples (1): number of samples used to compute the gradient estimator (for the
experiments in the paper, set to 1)
* batch_size (24): batch size
* muprop_relaxation (true): if true, use the new relaxation described in the
  paper; otherwise use the Concrete/Gumbel-Softmax relaxation
* dynamic_b (false): if true, dynamically binarize the training set. This
  increases the effective training set size and reduces overfitting, though
  the result is not a standard benchmark dataset
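
For example, several of these flags can be combined in a single run
(illustrative values, not a configuration from the paper):

```
python rebar_train.py \
  --hparams="model=SBNDynamicRebar,task=omni,n_layer=2,nonlinear=true,temperature=0.4,dynamic_b=true"
```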

Maintained by George Tucker ([email protected], github user: gjtucker).
25 changes: 25 additions & 0 deletions rebar/config.py
@@ -0,0 +1,25 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Configuration variables."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

DATA_DIR = 'data'
MNIST_BINARIZED = 'mnist_salakhutdinov_07-19-2017.pkl'
MNIST_FLOAT = 'mnist_train_xs_07-19-2017.npy'
OMNIGLOT = 'omniglot_07-19-2017.mat'
100 changes: 100 additions & 0 deletions rebar/datasets.py
@@ -0,0 +1,100 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library of datasets for REBAR."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import os
import scipy.io
import numpy as np
import cPickle as pickle
import tensorflow as tf
import config
gfile = tf.gfile


def load_data(hparams):
  """Loads the dataset splits for the task specified in hparams."""
  if hparams.task in ['sbn', 'sp']:
    reader = read_MNIST
  elif hparams.task == 'omni':
    reader = read_omniglot
  else:
    raise ValueError('Unrecognized task: %s' % hparams.task)
  x_train, x_valid, x_test = reader(binarize=not hparams.dynamic_b)

  return x_train, x_valid, x_test

def read_MNIST(binarize=False):
"""Reads in MNIST images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: 50k training images
x_valid: 10k validation images
x_test: 10k test images
"""
with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'r') as f:
(x_train, _), (x_valid, _), (x_test, _) = pickle.load(f)

if not binarize:
with gfile.FastGFile(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), 'r') as f:
x_train = np.load(f).reshape(-1, 784)

return x_train, x_valid, x_test

def read_omniglot(binarize=False):
"""Reads in Omniglot images.
Args:
binarize: whether to use the fixed binarization
Returns:
x_train: training images
x_valid: validation images
x_test: test images
"""
  n_validation = 1345

  def reshape_data(data):
    # Flattening each image in Fortran order transposes the 28x28 images.
    return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F')

omni_raw = scipy.io.loadmat(os.path.join(config.DATA_DIR, config.OMNIGLOT))

train_data = reshape_data(omni_raw['data'].T.astype('float32'))
test_data = reshape_data(omni_raw['testdata'].T.astype('float32'))

# Binarize the data with a fixed seed
if binarize:
np.random.seed(5)
train_data = (np.random.rand(*train_data.shape) < train_data).astype(float)
test_data = (np.random.rand(*test_data.shape) < test_data).astype(float)

shuffle_seed = 123
permutation = np.random.RandomState(seed=shuffle_seed).permutation(train_data.shape[0])
train_data = train_data[permutation]

x_train = train_data[:-n_validation]
x_valid = train_data[-n_validation:]
x_test = test_data

return x_train, x_valid, x_test

89 changes: 89 additions & 0 deletions rebar/download_data.py
@@ -0,0 +1,89 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Download MNIST, Omniglot datasets for Rebar."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import urllib
import gzip
import os
import config
import struct
import numpy as np
import cPickle as pickle
import datasets

MNIST_URL = 'see README'
MNIST_BINARIZED_URL = 'see README'
OMNIGLOT_URL = 'see README'

MNIST_FLOAT_TRAIN = 'train-images-idx3-ubyte'


def load_mnist_float(local_filename):
  """Reads images from a raw MNIST (IDX format) file as floats in [0, 1]."""
  with open(local_filename, 'rb') as f:
    # Skip the 4-byte magic number, then read the big-endian header counts.
    f.seek(4)
    nimages, rows, cols = struct.unpack('>iii', f.read(12))
    dim = rows * cols

    images = np.fromfile(f, dtype=np.dtype(np.ubyte))
    images = (images / 255.0).astype('float32').reshape((nimages, dim))

  return images

if __name__ == '__main__':
if not os.path.exists(config.DATA_DIR):
os.makedirs(config.DATA_DIR)

# Get MNIST and convert to npy file
local_filename = os.path.join(config.DATA_DIR, MNIST_FLOAT_TRAIN)
if not os.path.exists(local_filename):
urllib.urlretrieve("%s/%s.gz" % (MNIST_URL, MNIST_FLOAT_TRAIN), local_filename+'.gz')
with gzip.open(local_filename+'.gz', 'rb') as f:
file_content = f.read()
with open(local_filename, 'wb') as f:
f.write(file_content)
os.remove(local_filename+'.gz')

mnist_float_train = load_mnist_float(local_filename)[:-10000]
# save in a nice format
np.save(os.path.join(config.DATA_DIR, config.MNIST_FLOAT), mnist_float_train)

# Get binarized MNIST
splits = ['train', 'valid', 'test']
mnist_binarized = []
for split in splits:
filename = 'binarized_mnist_%s.amat' % split
url = '%s/binarized_mnist_%s.amat' % (MNIST_BINARIZED_URL, split)
local_filename = os.path.join(config.DATA_DIR, filename)
if not os.path.exists(local_filename):
urllib.urlretrieve(url, local_filename)

with open(local_filename, 'rb') as f:
mnist_binarized.append((np.array([map(int, line.split()) for line in f.readlines()]).astype('float32'), None))

# save in a nice format
with open(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED), 'w') as out:
pickle.dump(mnist_binarized, out)

# Get Omniglot
local_filename = os.path.join(config.DATA_DIR, config.OMNIGLOT)
if not os.path.exists(local_filename):
urllib.urlretrieve(OMNIGLOT_URL,
local_filename)

30 changes: 30 additions & 0 deletions rebar/logger.py
@@ -0,0 +1,30 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Logger for REBAR"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

class Logger:
  """No-op logger; subclass and override these methods to record statistics."""

  def __init__(self):
    pass

  def log(self, key, value):
    """Records a (key, value) measurement; this stub discards it."""
    pass

  def flush(self):
    """Writes out any buffered measurements; this stub does nothing."""
    pass