Skip to content

Commit

Permalink
Modify prepro.py to enable preprocessing on v1.0 train.
Browse files Browse the repository at this point in the history
  • Loading branch information
Karan Desai committed Jun 9, 2018
1 parent 5c2c007 commit ecbbfc3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ python -c "import nltk; nltk.download('all')"

```sh
cd data
python prepro.py -download 1
python prepro.py -download 1 -image_root /path/to/coco/images
cd ..
```

Expand Down
31 changes: 18 additions & 13 deletions data/prepro.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import json
import h5py
import argparse
import glob
import h5py
import json
import os
import numpy as np
from nltk.tokenize import word_tokenize
from tqdm import tqdm
Expand All @@ -15,6 +16,7 @@
parser.add_argument('-input_json_train', default='visdial_0.9_train.json', help='Input `train` json file')
parser.add_argument('-input_json_val', default='visdial_0.9_val.json', help='Input `val` json file')
parser.add_argument('-input_json_test', default='visdial_0.9_test.json', help='Input `test` json file')
parser.add_argument('-image_root', default='/path/to/coco/images', help='Path to mscoco images.')

# Output files
parser.add_argument('-output_json', default='visdial_params.json', help='Output json file')
Expand Down Expand Up @@ -174,13 +176,16 @@ def create_data_mats(data, params, dtype):
return data_mats


def get_image_ids(data, params, dtype):
    """Map every dialog's image id to a relative '<folder>/<filename>' path.

    Args:
        data: Loaded split json; dialogs are read from
            ``data['data']['dialogs']`` and the split name (used only for
            logging) from ``data['split']``.
        params: Parsed command-line arguments; ``params.image_root`` is the
            directory whose immediate sub-directories hold the ``.jpg`` images.
        dtype: Split type. For ``'test'`` the path is built from a fixed
            naming template and the disk is never consulted.

    Returns:
        A list of relative image paths (always '/'-separated), one per dialog,
        in the same order as the dialogs.

    Raises:
        IOError: If a non-test image id has no matching file under
            ``params.image_root``.
    """
    image_ids = [dialog['image_id'] for dialog in data['data']['dialogs']]
    print("[%s] Preparing image paths with image_ids..." % data['split'])

    if dtype == 'test':
        # Test paths are fully determined by the id, so skip the disk lookup;
        # previously a missing test image crashed even though the globbed
        # path was thrown away.
        template = '%s2017/VisualDialog_%s2017_%012d.jpg'
        return [template % (dtype, dtype, image_id)
                for image_id in tqdm(image_ids)]

    # Scan the image folders once instead of running one glob per dialog.
    # Keying on the numeric id avoids suffix collisions of the old
    # '*<id>.jpg' pattern (e.g. id 42 also matching '..._000000000142.jpg').
    id_to_path = {}
    pattern = os.path.join(params.image_root, '*', '*.jpg')
    for image_path in sorted(glob.glob(pattern)):
        folder, filename = os.path.split(image_path)
        # The id is the trailing '_'-separated token of the file stem,
        # e.g. 'COCO_train2014_000000000042' -> '000000000042'.
        id_token = os.path.splitext(filename)[0].rsplit('_', 1)[-1]
        if id_token.isdigit():
            # Use os.path helpers rather than splitting on '/', so this also
            # works where the native separator is not '/'.
            relative_path = '/'.join((os.path.basename(folder), filename))
            # setdefault keeps the first (sorted) match if an id is duplicated.
            id_to_path.setdefault(int(id_token), relative_path)

    for i, image_id in enumerate(tqdm(image_ids)):
        key = int(image_id)
        if key not in id_to_path:
            raise IOError("No image found for image_id %s under '%s'"
                          % (image_id, params.image_root))
        image_ids[i] = id_to_path[key]
    return image_ids


Expand Down Expand Up @@ -265,10 +270,10 @@ def get_image_ids(data, dtype):
out['ind2word'] = ind2word
out['word2ind'] = word2ind
if args.train_split == 'train':
out['unique_img_train'] = get_image_ids(data_train, 'train')
out['unique_img_val'] = get_image_ids(data_val, 'val')
out['unique_img_train'] = get_image_ids(data_train, args, 'train')
out['unique_img_val'] = get_image_ids(data_val, args, 'val')
elif args.train_split == 'trainval':
out['unique_img_train'] = get_image_ids(data_train, 'train') + \
get_image_ids(data_val, 'val')
out['unique_img_test'] = get_image_ids(data_test, 'test')
out['unique_img_train'] = get_image_ids(data_train, args, 'train') + \
get_image_ids(data_val, args, 'val')
out['unique_img_test'] = get_image_ids(data_test, args, 'test')
json.dump(out, open(args.output_json, 'w'))

0 comments on commit ecbbfc3

Please sign in to comment.