# download_datasets.py — forked from paarthneekhara/text-to-image
# downloads/extracts datasets described in the README.md
import os
import sys
import errno
import tarfile
if sys.version_info >= (3,):
from urllib.request import urlretrieve
else:
from urllib import urlretrieve
DATA_DIR = 'Data'
# http://stackoverflow.com/questions/273192/how-to-check-if-a-directory-exists-and-create-it-if-necessary
def make_sure_path_exists(path):
    """Create directory *path* (and any parents), tolerating an existing one.

    Kept Py2-compatible on purpose (no ``exist_ok=True``); any OSError other
    than "already exists" is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        # A pre-existing directory is fine; everything else is a real failure.
        if err.errno == errno.EEXIST:
            return
        raise
def create_data_paths():
    """Ensure the expected output subdirectories exist under DATA_DIR.

    Raises EnvironmentError when DATA_DIR itself is missing, i.e. the script
    is not being run from the project root.
    """
    if not os.path.isdir(DATA_DIR):
        raise EnvironmentError('Needs to be run from project directory containing ' + DATA_DIR)
    for subdir in ('samples', 'val_samples', 'Models'):
        make_sure_path_exists(os.path.join(DATA_DIR, subdir))
# adapted from http://stackoverflow.com/questions/51212/how-to-write-a-download-progress-indicator-in-python
def dl_progress_hook(count, blockSize, totalSize):
    """urlretrieve ``reporthook``: print in-place download progress.

    Args:
        count: number of blocks transferred so far.
        blockSize: size of each block in bytes.
        totalSize: total file size in bytes; urlretrieve passes -1 (or 0)
            when the server does not report a Content-Length.
    """
    if totalSize > 0:
        # Clamp at 100: the final block usually overshoots totalSize slightly.
        percent = min(int(count * blockSize * 100 / totalSize), 100)
        sys.stdout.write("\r" + "...%d%%" % percent)
    else:
        # Unknown total size — fall back to a byte counter instead of
        # dividing by zero (the original crashed here).
        sys.stdout.write("\r" + "...%d bytes" % (count * blockSize))
    sys.stdout.flush()
def _download_flowers():
    """Fetch and unpack the Oxford 102 flowers images plus the bundled captions."""
    print('== Flowers dataset ==')
    flowers_dir = os.path.join(DATA_DIR, 'flowers')
    make_sure_path_exists(flowers_dir)
    # the original google drive link at https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view
    # from https://github.com/reedscot/icml2016 is problematic to download automatically, so included
    # the text_c10 directory from that archive as a bzipped file in the repo
    captions_tbz = os.path.join(DATA_DIR, 'flowers_text_c10.tar.bz2')
    print('Extracting ' + captions_tbz)
    captions_tar = tarfile.open(captions_tbz, 'r:bz2')
    try:
        captions_tar.extractall(flowers_dir)
    finally:
        captions_tar.close()  # fix: handle was previously leaked
    flowers_jpg_tgz = os.path.join(flowers_dir, '102flowers.tgz')
    flowers_url = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
    print('Downloading ' + flowers_jpg_tgz + ' from ' + flowers_url)
    urlretrieve(flowers_url, flowers_jpg_tgz,
                reporthook=dl_progress_hook)
    print('Extracting ' + flowers_jpg_tgz)
    flowers_jpg_tar = tarfile.open(flowers_jpg_tgz, 'r:gz')
    try:
        flowers_jpg_tar.extractall(flowers_dir)  # archive contains jpg/ folder
    finally:
        flowers_jpg_tar.close()  # fix: handle was previously leaked


def _download_skipthoughts():
    """Download the pretrained skip-thoughts model files."""
    print('== Skipthoughts models ==')
    SKIPTHOUGHTS_DIR = os.path.join(DATA_DIR, 'skipthoughts')
    SKIPTHOUGHTS_BASE_URL = 'http://www.cs.toronto.edu/~rkiros/models/'
    make_sure_path_exists(SKIPTHOUGHTS_DIR)
    # following https://github.com/ryankiros/skip-thoughts#getting-started
    skipthoughts_files = [
        'dictionary.txt', 'utable.npy', 'btable.npy', 'uni_skip.npz', 'uni_skip.npz.pkl', 'bi_skip.npz',
        'bi_skip.npz.pkl',
    ]
    for filename in skipthoughts_files:
        src_url = SKIPTHOUGHTS_BASE_URL + filename
        print('Downloading ' + src_url)
        urlretrieve(src_url, os.path.join(SKIPTHOUGHTS_DIR, filename),
                    reporthook=dl_progress_hook)


def _download_nltk_punkt():
    """Download NLTK's pretrained Punkt sentence tokenizer for English."""
    import nltk
    print('== NLTK pre-trained Punkt tokenizer for English ==')
    nltk.download('punkt')


def _download_pretrained_model():
    """Download the pretrained flowers text-to-image model checkpoint."""
    print('== Pretrained model ==')
    MODEL_DIR = os.path.join(DATA_DIR, 'Models')
    pretrained_model_filename = 'latest_model_flowers_temp.ckpt'
    src_url = 'https://bitbucket.org/paarth_neekhara/texttomimagemodel/raw/74a4bbaeee26fe31e148a54c4f495694680e2c31/' + pretrained_model_filename
    print('Downloading ' + src_url)
    urlretrieve(
        src_url,
        os.path.join(MODEL_DIR, pretrained_model_filename),
        reporthook=dl_progress_hook,
    )


def download_dataset(data_name):
    """Download/extract one dataset by name.

    Args:
        data_name: one of 'flowers', 'skipthoughts', 'nltk_punkt',
            'pretrained_model'.
    Raises:
        ValueError: for any unrecognized *data_name*.
    """
    handlers = {
        'flowers': _download_flowers,
        'skipthoughts': _download_skipthoughts,
        'nltk_punkt': _download_nltk_punkt,
        'pretrained_model': _download_pretrained_model,
    }
    if data_name not in handlers:
        raise ValueError('Unknown dataset name: ' + data_name)
    handlers[data_name]()
def main():
    """Download datasets named on the command line; with no args, get all.

    Usage: download_datasets.py [flowers] [skipthoughts] [nltk_punkt]
    [pretrained_model] — running with no arguments keeps the original
    behavior of downloading everything, so existing invocations still work.
    """
    create_data_paths()
    all_datasets = ['flowers', 'skipthoughts', 'nltk_punkt', 'pretrained_model']
    # resolves the old "TODO: make configurable via command-line"
    requested = sys.argv[1:] or all_datasets
    for data_name in requested:
        download_dataset(data_name)  # raises ValueError on a bad name
    print('Done')


if __name__ == '__main__':
    main()