Skip to content

Commit

Permalink
fixed prior and learning prior code for bair and moving mnist.
Browse files Browse the repository at this point in the history
  • Loading branch information
edenton committed Feb 20, 2018
1 parent a5dc5ba commit fad4feb
Show file tree
Hide file tree
Showing 11 changed files with 1,814 additions and 99 deletions.
100 changes: 1 addition & 99 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,101 +1,3 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
*.pyc
63 changes: 63 additions & 0 deletions data/bair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import io
from scipy.misc import imresize
import numpy as np
from PIL import Image
from scipy.misc import imresize
from scipy.misc import imread


ROOT_DIR = '/misc/vlgscratch4/FergusGroup/denton/data/bair_robot_push/processed_data/'
class RobotPush(object):
    """Data handler that loads BAIR robot-pushing video sequences from disk.

    Each sample is a float array of shape (seq_len, image_size, image_size, 3)
    with pixel values scaled to [0, 1], assembled from per-frame PNG files
    laid out as <data_dir>/<d1>/<d2>/<frame_index>.png.

    Args:
        train: if True, read from the 'train' split and sample sequence
            directories at random; if False, read from the 'test' split and
            walk directories in a fixed, wrapping order.
        seq_len: number of consecutive frames per returned sequence.
        image_size: height/width of each square frame. Frames on disk are
            assumed to already be image_size x image_size — TODO confirm;
            no resizing is performed here.
    """

    def __init__(self, train=True, seq_len=20, image_size=64):
        self.root_dir = ROOT_DIR
        if train:
            self.data_dir = '%s/train' % self.root_dir
            self.ordered = False
        else:
            self.data_dir = '%s/test' % self.root_dir
            self.ordered = True
        # Collect every two-level-deep sequence directory under the split.
        self.dirs = []
        for d1 in os.listdir(self.data_dir):
            for d2 in os.listdir('%s/%s' % (self.data_dir, d1)):
                self.dirs.append('%s/%s/%s' % (self.data_dir, d1, d2))
        self.seq_len = seq_len
        self.image_size = image_size
        self.seed_is_set = False  # guard so multi-threaded loaders seed numpy only once
        self.d = 0  # cursor into self.dirs for ordered (test-split) iteration

    def set_seed(self, seed):
        """Seed numpy's RNG exactly once (first call wins; later calls are no-ops)."""
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        # Fixed nominal epoch length, independent of how many sequence
        # directories actually exist.
        return 10000

    def get_seq(self):
        """Load one sequence of self.seq_len frames as a (T, H, W, 3) array in [0, 1].

        In ordered mode, advance through self.dirs cyclically; otherwise pick
        a sequence directory uniformly at random.
        """
        if self.ordered:
            d = self.dirs[self.d]
            if self.d == len(self.dirs) - 1:
                self.d = 0  # wrap around after the last directory
            else:
                self.d += 1
        else:
            d = self.dirs[np.random.randint(len(self.dirs))]
        image_seq = []
        for i in range(self.seq_len):
            fname = '%s/%d.png' % (d, i)
            # Honor the configured frame size (previously hard-coded to 64).
            im = imread(fname).reshape(1, self.image_size, self.image_size, 3)
            image_seq.append(im / 255.)
        return np.concatenate(image_seq, axis=0)

    def __getitem__(self, index):
        """Return one sequence; `index` is used only to seed the RNG once."""
        self.set_seed(index)
        return self.get_seq()


Loading

0 comments on commit fad4feb

Please sign in to comment.