-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
186251c
commit 8bd9774
Showing
12 changed files
with
1,956 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Learning Conditioned Graph Structures for Interpretable Visual Question Answering | ||
|
||
This code provides a pytorch implementation of our graph learning method for Visual Question Answering as described in [Learning Conditioned Graph Structures for Interpretable Visual Question Answering](https://arxiv.org/abs/1806.07243) | ||
|
||
### Model diagram | ||
![](./figures/model.png) | ||
### Examples of learned graph structures | ||
![](./figures/examples.png) | ||
|
||
## Getting Started | ||
|
||
### Reference | ||
|
||
If you use our code or any of the ideas from our paper please cite: | ||
``` | ||
@article{learningconditionedgraph, | ||
author = {Will Norcliffe-Brown and Efstathios Vafeias and Sarah Parisot}, | ||
title = {Learning Conditioned Graph Structures for Interpretable Visual Question Answering}, | ||
journal = {arXiv preprint arXiv:1806.07243}, | ||
year = {2018} | ||
} | ||
``` | ||
|
||
### Requirements | ||
|
||
- [pytorch (0.2.0) (with CUDA)](https://pytorch.org/) | ||
- [zarr (v2.2.0rc3)](https://github.com/zarr-developers/zarr) | ||
- [tdqm](https://github.com/tqdm/tqdm) | ||
- [spacy](https://spacy.io/usage/) | ||
|
||
### Data | ||
|
||
To download and unzip the required datasets, change to the data folder and run | ||
``` | ||
$ cd data; python download_data.py | ||
``` | ||
|
||
To preprocess the image data and text data the following commands can be executed respectively. Setting the data variable to trainval or test for preprocess_image.py and train, val or test for preprocess_text.py depending on which dataset you want to preprocess | ||
``` | ||
$ python preprocess_image.py --data trainval; python preprocess_text.py --data train | ||
``` | ||
|
||
### Training | ||
|
||
To train a model on the train set with our default parameters run | ||
``` | ||
$ python run.py --train | ||
``` | ||
and to train a model on the train and validation set for evaluation on the test set run | ||
``` | ||
$ python run.py --trainval | ||
``` | ||
Models can be validated via | ||
``` | ||
$ python run.py --eval --model_path path_to_your_model | ||
``` | ||
and a json of results from the test set can be produced with | ||
``` | ||
$ python run.py --test --model_path path_to_your_model | ||
``` | ||
To see a list and description of the model training parameters run | ||
``` | ||
$ python run.py --help | ||
``` | ||
|
||
## Authors | ||
|
||
* **Will Norcliffe-Brown** | ||
* **Sarah Parisot** | ||
* **Stathis Vafeias** | ||
|
||
|
||
## License | ||
|
||
This project is licensed under the Apache 2.0 license - see [Apache license](license.txt) | ||
|
||
## Acknowledgements | ||
|
||
Our code is based on this implementation of the 2017 VQA challenge winner [https://github.com/markdtw/vqa-winner-cvprw-2017](https://github.com/markdtw/vqa-winner-cvprw-2017) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright 2018 AimBrain Ltd. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
# download input questions (training, validation and test sets) | ||
os.system( | ||
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Train_mscoco.zip -P zip/') | ||
os.system( | ||
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Val_mscoco.zip -P zip/') | ||
os.system( | ||
'wget http://visualqa.org/data/mscoco/vqa/v2_Questions_Test_mscoco.zip -P zip/') | ||
|
||
# download annotations (training and validation sets) | ||
os.system( | ||
'wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P zip/') | ||
os.system( | ||
'wget http://visualqa.org/data/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P zip/') | ||
|
||
# download pre-trained glove embeddings | ||
os.system('wget http://nlp.stanford.edu/data/glove.6B.zip -P zip/') | ||
|
||
# download rcnn extracted features (may take a while, both very large files) | ||
os.system( | ||
'wget https://imagecaption.blob.core.windows.net/imagecaption/trainval_36.zip -P zip/') | ||
os.system( | ||
'wget https://imagecaption.blob.core.windows.net/imagecaption/test2015_36.zip -P zip/') | ||
|
||
# extract them | ||
os.system('unzip zip/v2_Questions_Train_mscoco.zip -d raw/') | ||
os.system('unzip zip/v2_Questions_Val_mscoco.zip -d raw/') | ||
os.system('unzip zip/v2_Questions_Test_mscoco.zip -d raw/') | ||
os.system('unzip zip/v2_Annotations_Train_mscoco.zip -d raw/') | ||
os.system('unzip zip/v2_Annotations_Val_mscoco.zip -d raw/') | ||
os.system('unzip zip/glove.6B.zip -d data/') | ||
os.system('unzip zip/trainval_36.zip -d raw/') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright 2018 AimBrain Ltd. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import absolute_import | ||
|
||
import os | ||
import argparse | ||
import base64 | ||
import numpy as np | ||
import csv | ||
import sys | ||
import h5py | ||
import pandas as pd | ||
import zarr | ||
from tqdm import tqdm | ||
|
||
|
||
csv.field_size_limit(sys.maxsize) | ||
|
||
|
||
def features_to_zarr(phase): | ||
FIELDNAMES = ['image_id', 'image_w', 'image_h', | ||
'num_boxes', 'boxes', 'features'] | ||
|
||
if phase == 'trainval': | ||
infiles = [ | ||
'raw/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv', | ||
] | ||
elif phase == 'test': | ||
infiles = [ | ||
'raw/test2015_36/test2015_resnet101_faster_rcnn_genome_36.tsv', | ||
] | ||
else: | ||
raise SystemExit('Unrecognised phase') | ||
|
||
# Read the tsv and load data in a dictionary | ||
in_data = {} | ||
for infile in infiles: | ||
print(infile) | ||
with open(infile, "r") as tsv_in_file: | ||
reader = csv.DictReader( | ||
tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES) | ||
for item in reader: | ||
item['image_id'] = str(item['image_id']) | ||
item['image_h'] = int(item['image_h']) | ||
item['image_w'] = int(item['image_w']) | ||
item['num_boxes'] = int(item['num_boxes']) | ||
for field in ['boxes', 'features']: | ||
encoded_str = base64.decodestring( | ||
item[field].encode('utf-8')) | ||
item[field] = np.frombuffer(encoded_str, | ||
dtype=np.float32).reshape((item['num_boxes'], -1)) | ||
in_data[item['image_id']] = item | ||
|
||
# convert dict to pandas dataframe | ||
train = pd.DataFrame.from_dict(in_data) | ||
train = train.transpose() | ||
|
||
# create image sizes csv | ||
print('Writing image sizes csv...') | ||
d = train.to_dict() | ||
dw = d['image_w'] | ||
dh = d['image_h'] | ||
d = [dw, dh] | ||
dwh = {} | ||
for k in dw.keys(): | ||
dwh[k] = np.array([d0[k] for d0 in d]) | ||
image_sizes = pd.DataFrame(dwh) | ||
image_sizes.to_csv(phase + '_image_size.csv') | ||
|
||
# select bounding box coordinates and fill hdf5 | ||
h = h5py.File(phase + 'box.hdf5', mode='w') | ||
t = train['boxes'] | ||
d = t.to_dict() | ||
print('Creating bounding box file...') | ||
for k, v in tqdm(d.items()): | ||
h.create_dataset(str(k), data=v) | ||
if h: | ||
h.close() | ||
|
||
# convert to zarr | ||
print('Writing zarr file...') | ||
i_feat = h5py.File(phase + 'box.hdf5', 'r', libver='latest') | ||
dest = zarr.open_group(phase + '_boxes.zarr', mode='w') | ||
zarr.copy_all(i_feat, dest) | ||
i_feat.close() | ||
dest.close() | ||
os.remove(phase + 'box.hdf5') | ||
|
||
# select features and fill hdf5 | ||
h = h5py.File(phase + '.hdf5', mode='w') | ||
t = train['features'] | ||
d = t.to_dict() | ||
print('Creating image features file...') | ||
for k, v in tqdm(d.items()): | ||
h.create_dataset(str(k), data=v) | ||
if h: | ||
h.close() | ||
|
||
# convert to zarr | ||
print('Writing zarr file...') | ||
i_feat = h5py.File(phase + '.hdf5', 'r', libver='latest') | ||
dest = zarr.open_group(phase + '.zarr', mode='w') | ||
zarr.copy_all(i_feat, dest) | ||
i_feat.close() | ||
dest.close() | ||
os.remove(phase + '.hdf5') | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser( | ||
description='Preprocessing for VQA v2 image data') | ||
parser.add_argument('-d', '--data', nargs='+', help='trainval, and/or test, list of data phases to be processed', required=True) | ||
args, unparsed = parser.parse_known_args() | ||
if len(unparsed) != 0: | ||
raise SystemExit('Unknown argument: {}'.format(unparsed)) | ||
|
||
phase_list = args.data | ||
|
||
for phase in phase_list: | ||
# First download and extract | ||
|
||
if not os.path.exists(phase + '.zarr'): | ||
print('Converting features tsv to zarr file...') | ||
features_to_zarr(phase) | ||
|
||
print('Done') |
Oops, something went wrong.