Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
wnorcbrown authored Jun 20, 2018
1 parent 186251c commit 8bd9774
Show file tree
Hide file tree
Showing 12 changed files with 1,956 additions and 0 deletions.
79 changes: 79 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Learning Conditioned Graph Structures for Interpretable Visual Question Answering

This code provides a pytorch implementation of our graph learning method for Visual Question Answering as described in [Learning Conditioned Graph Structures for Interpretable Visual Question Answering](https://arxiv.org/abs/1806.07243)

### Model diagram
![](./figures/model.png)
### Examples of learned graph structures
![](./figures/examples.png)

## Getting Started

### Reference

If you use our code or any of the ideas from our paper please cite:
```
@article{learningconditionedgraph,
author = {Will Norcliffe-Brown and Efstathios Vafeias and Sarah Parisot},
title = {Learning Conditioned Graph Structures for Interpretable Visual Question Answering},
journal = {arXiv preprint arXiv:1806.07243},
year = {2018}
}
```

### Requirements

- [pytorch (0.2.0) (with CUDA)](https://pytorch.org/)
- [zarr (v2.2.0rc3)](https://github.com/zarr-developers/zarr)
- [tdqm](https://github.com/tqdm/tqdm)
- [spacy](https://spacy.io/usage/)

### Data

To download and unzip the required datasets, change to the data folder and run
```
$ cd data; python download_data.py
```

To preprocess the image data and text data the following commands can be executed respectively. Setting the data variable to trainval or test for preprocess_image.py and train, val or test for preprocess_text.py depending on which dataset you want to preprocess
```
$ python preprocess_image.py --data trainval; python preprocess_text.py --data train
```

### Training

To train a model on the train set with our default parameters run
```
$ python run.py --train
```
and to train a model on the train and validation set for evaluation on the test set run
```
$ python run.py --trainval
```
Models can be validated via
```
$ python run.py --eval --model_path path_to_your_model
```
and a json of results from the test set can be produced with
```
$ python run.py --test --model_path path_to_your_model
```
To see a list and description of the model training parameters run
```
$ python run.py --help
```

## Authors

* **Will Norcliffe-Brown**
* **Sarah Parisot**
* **Stathis Vafeias**


## License

This project is licensed under the Apache 2.0 license - see [Apache license](license.txt)

## Acknowledgements

Our code is based on this implementation of the 2017 VQA challenge winner [https://github.com/markdtw/vqa-winner-cvprw-2017](https://github.com/markdtw/vqa-winner-cvprw-2017)
47 changes: 47 additions & 0 deletions data/download_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2018 AimBrain Ltd.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

# download input questions (training, validation and test sets)
os.system(
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip -P zip/')
os.system(
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip -P zip/')
os.system(
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip -P zip/')

# download annotations (training and validation sets)
os.system(
'wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P zip/')
os.system(
'wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P zip/')

# download pre-trained glove embeddings
os.system('wget http://nlp.stanford.edu/data/glove.6B.zip -P zip/')

# download rcnn extracted features (may take a while, both very large files)
os.system(
'wget https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip -P zip/')
os.system(
'wget https://imagecaption.blob.core.windows.net/imagecaption/test2015_36.zip -P zip/')

# extract them
os.system('unzip zip/v2_Questions_Train_mscoco.zip -d raw/')
os.system('unzip zip/v2_Questions_Val_mscoco.zip -d raw/')
os.system('unzip zip/v2_Questions_Test_mscoco.zip -d raw/')
os.system('unzip zip/v2_Annotations_Train_mscoco.zip -d raw/')
os.system('unzip zip/v2_Annotations_Val_mscoco.zip -d raw/')
os.system('unzip zip/glove.6B.zip -d data/')
os.system('unzip zip/trainval_36.zip -d raw/')
140 changes: 140 additions & 0 deletions data/preprocess_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2018 AimBrain Ltd.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import argparse
import base64
import numpy as np
import csv
import sys
import h5py
import pandas as pd
import zarr
from tqdm import tqdm


csv.field_size_limit(sys.maxsize)


def features_to_zarr(phase):
FIELDNAMES = ['image_id', 'image_w', 'image_h',
'num_boxes', 'boxes', 'features']

if phase == 'trainval':
infiles = [
'raw/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv',
]
elif phase == 'test':
infiles = [
'raw/test2015_36/test2015_resnet101_faster_rcnn_genome_36.tsv',
]
else:
raise SystemExit('Unrecognised phase')

# Read the tsv and load data in a dictionary
in_data = {}
for infile in infiles:
print(infile)
with open(infile, "r") as tsv_in_file:
reader = csv.DictReader(
tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
for item in reader:
item['image_id'] = str(item['image_id'])
item['image_h'] = int(item['image_h'])
item['image_w'] = int(item['image_w'])
item['num_boxes'] = int(item['num_boxes'])
for field in ['boxes', 'features']:
encoded_str = base64.decodestring(
item[field].encode('utf-8'))
item[field] = np.frombuffer(encoded_str,
dtype=np.float32).reshape((item['num_boxes'], -1))
in_data[item['image_id']] = item

# convert dict to pandas dataframe
train = pd.DataFrame.from_dict(in_data)
train = train.transpose()

# create image sizes csv
print('Writing image sizes csv...')
d = train.to_dict()
dw = d['image_w']
dh = d['image_h']
d = [dw, dh]
dwh = {}
for k in dw.keys():
dwh[k] = np.array([d0[k] for d0 in d])
image_sizes = pd.DataFrame(dwh)
image_sizes.to_csv(phase + '_image_size.csv')

# select bounding box coordinates and fill hdf5
h = h5py.File(phase + 'box.hdf5', mode='w')
t = train['boxes']
d = t.to_dict()
print('Creating bounding box file...')
for k, v in tqdm(d.items()):
h.create_dataset(str(k), data=v)
if h:
h.close()

# convert to zarr
print('Writing zarr file...')
i_feat = h5py.File(phase + 'box.hdf5', 'r', libver='latest')
dest = zarr.open_group(phase + '_boxes.zarr', mode='w')
zarr.copy_all(i_feat, dest)
i_feat.close()
dest.close()
os.remove(phase + 'box.hdf5')

# select features and fill hdf5
h = h5py.File(phase + '.hdf5', mode='w')
t = train['features']
d = t.to_dict()
print('Creating image features file...')
for k, v in tqdm(d.items()):
h.create_dataset(str(k), data=v)
if h:
h.close()

# convert to zarr
print('Writing zarr file...')
i_feat = h5py.File(phase + '.hdf5', 'r', libver='latest')
dest = zarr.open_group(phase + '.zarr', mode='w')
zarr.copy_all(i_feat, dest)
i_feat.close()
dest.close()
os.remove(phase + '.hdf5')


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Preprocessing for VQA v2 image data')
parser.add_argument('-d', '--data', nargs='+', help='trainval, and/or test, list of data phases to be processed', required=True)
args, unparsed = parser.parse_known_args()
if len(unparsed) != 0:
raise SystemExit('Unknown argument: {}'.format(unparsed))

phase_list = args.data

for phase in phase_list:
# First download and extract

if not os.path.exists(phase + '.zarr'):
print('Converting features tsv to zarr file...')
features_to_zarr(phase)

print('Done')
Loading

0 comments on commit 8bd9774

Please sign in to comment.