Code for the paper
Visual Dialog
Abhishek Das, Satwik Kottur, Khushi Gupta, Avi Singh, Deshraj Yadav, José M. F. Moura, Devi Parikh, Dhruv Batra
arxiv.org/abs/1611.08669
CVPR 2017 (Spotlight)
Visual Dialog requires an AI agent to hold a meaningful dialog with humans in natural, conversational language about visual content. Given an image, dialog history, and a follow-up question about the image, the AI agent has to answer the question.
Demo: demo.visualdialog.org
This repository contains code for training, evaluating and visualizing results for all combinations of encoder-decoder architectures described in the paper. Specifically, we have 3 encoders: Late Fusion (LF), Hierarchical Recurrent Encoder (HRE), Memory Network (MN), and 2 kinds of decoding: Generative (G) and Discriminative (D).
If you find this code useful, consider citing our work:
@inproceedings{visdial,
title={{V}isual {D}ialog},
author={Abhishek Das and Satwik Kottur and Khushi Gupta and Avi Singh
and Deshraj Yadav and Jos\'e M.F. Moura and Devi Parikh and Dhruv Batra},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2017}
}
All our code is implemented in Torch (Lua). Installation instructions are as follows:
git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install-deps;
TORCH_LUA_VERSION=LUA51 ./install.sh
Additionally, our code uses the following packages: torch/torch7, torch/nn, torch/nngraph, Element-Research/rnn, torch/image, lua-cjson, loadcaffe, torch-hdf5. After Torch is installed, these can be installed/updated using:
luarocks install torch
luarocks install nn
luarocks install nngraph
luarocks install image
luarocks install lua-cjson
luarocks install loadcaffe
luarocks install luabitop
luarocks install totem
NOTE: `luarocks install rnn` defaults to torch/rnn; follow these steps to install Element-Research/rnn instead.
git clone https://github.com/Element-Research/rnn.git
cd rnn
luarocks make rocks/rnn-scm-1.rockspec
Installation instructions for torch-hdf5 are given here.
NOTE: torch-hdf5 does not work with some versions of gcc. We recommend gcc 4.8 or gcc 4.9 with Lua 5.1 for a working installation of torch-hdf5.
Although our code should work on CPUs, it is highly recommended to use GPU acceleration with CUDA. You'll also need torch/cutorch, torch/cudnn and torch/cunn.
luarocks install cutorch
luarocks install cunn
luarocks install cudnn
The preprocessing script is in Python; you'll need to install NLTK, NumPy and h5py, and download the NLTK data:
pip install nltk
pip install numpy
pip install h5py
python -c "import nltk; nltk.download('all')"
The VisDial v1.0 dataset can be downloaded and preprocessed as specified below. The path provided as `-image_root` must have four subdirectories: `train2014` and `val2014` (as per the COCO dataset), and `VisualDialog_val2018` and `VisualDialog_test2018`, which can be downloaded from here.
cd data
python prepro.py -download -image_root /path/to/images
cd ..
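The `-image_root` layout described above can be checked before running `prepro.py`. The following is a minimal Python sketch; the helper name `missing_image_dirs` is hypothetical and not part of this repository. It only lists which of the four required subdirectories are absent:

```python
import os
from typing import List

# The four subdirectories that -image_root must contain (as listed in this README).
REQUIRED_SUBDIRS = ("train2014", "val2014", "VisualDialog_val2018", "VisualDialog_test2018")


def missing_image_dirs(image_root: str) -> List[str]:
    """Return the required subdirectories that do not exist under image_root (hypothetical helper)."""
    return [d for d in REQUIRED_SUBDIRS
            if not os.path.isdir(os.path.join(image_root, d))]
```

For example, `missing_image_dirs('/path/to/images')` returning an empty list means all four subdirectories exist; it does not check that the images inside them are complete.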
To download and preprocess the VisDial v0.9 dataset instead, pass the extra argument `-version 0.9` when running the script.
This script will generate the files `data/visdial_data.h5` (containing tokenized captions, questions, answers and image indices) and `data/visdial_params.json` (containing vocabulary mappings and COCO image ids).
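To illustrate what the vocabulary mappings in `data/visdial_params.json` are used for, here is a hedged sketch that maps a question to word indices. It assumes the mapping is stored under a `word2ind` key and that unknown words map to a `UNK` entry; both names are assumptions to verify against your generated file. The regex tokenizer is a simplified stand-in for the NLTK-based tokenization the preprocessing script relies on, so indices may differ on punctuation-heavy text.

```python
import json
import re
from typing import Dict, List


def load_word2ind(params_path: str, key: str = "word2ind") -> Dict[str, int]:
    """Load the word -> index mapping. ASSUMPTION: it is stored under `key`."""
    with open(params_path) as f:
        return json.load(f)[key]


def encode(text: str, word2ind: Dict[str, int], unk_token: str = "UNK") -> List[int]:
    """Map text to vocabulary indices, sending out-of-vocabulary words to unk_token.

    ASSUMPTION: unk_token exists in the vocabulary. The tokenizer below (lowercase,
    split into words and single punctuation marks) only approximates NLTK.
    """
    tokens = re.findall(r"\w+|[^\w\s]", text.lower())
    unk_index = word2ind[unk_token]
    return [word2ind.get(tok, unk_index) for tok in tokens]
```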
Since we don't finetune the CNN, training is significantly faster if image features are pre-extracted. Currently this repository provides support for extraction from VGG-16 and ResNets. We use image features from VGG-16. The VGG-16 model can be downloaded and features extracted using:
sh scripts/download_model.sh vgg 16 # works for 19 as well
cd data
# For all models except mn-att-ques-im-hist
th prepro_img_vgg16.lua -imageRoot /path/to/images -gpuid 0
# For mn-att-ques-im-hist
th prepro_img_vgg16.lua -imageRoot /path/to/images -imgSize 448 -layerName pool5 -gpuid 0
Similarly, the ResNet models released by Facebook can be used for feature extraction, in the same manner as VGG-16:
sh scripts/download_model.sh resnet 200 # works for 18, 34, 50, 101, 152 as well
cd data
th prepro_img_resnet.lua -imageRoot /path/to/images -cnnModel /path/to/t7/model -gpuid 0
Running either of these should generate `data/data_img.h5`, containing features for the `train`, `val` and `test` splits of VisDial v1.0.
Finally, we can get to training models! All supported encoders are in the `encoders/` folder (`lf-ques`, `lf-ques-im`, `lf-ques-hist`, `lf-ques-im-hist`, `hre-ques-hist`, `hre-ques-im-hist`, `hrea-ques-im-hist`, `mn-ques-hist`, `mn-ques-im-hist`, `mn-att-ques-im-hist`), and decoders in the `decoders/` folder (`gen` and `disc`).
Generative (`gen`) decoding maximizes the likelihood of the ground-truth response and only has access to single input-output pairs of dialog, while discriminative (`disc`) decoding makes use of the 100 candidate responses provided for every round of dialog and maximizes the likelihood of the correct option.
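The difference between the two training objectives can be sketched in a few lines of Python. This is a generic illustration, not the repository's Torch code: it assumes the discriminative decoder has already produced one score per candidate option (how those scores are computed is not described here), and that the generative decoder has produced the log-probability of each ground-truth answer token.

```python
import math
from typing import List, Sequence


def log_softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax via the log-sum-exp trick."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def discriminative_loss(option_scores: Sequence[float], gt_index: int) -> float:
    """Negative log-likelihood of the correct option under a softmax over all candidates."""
    return -log_softmax(option_scores)[gt_index]


def generative_loss(token_log_probs: Sequence[float]) -> float:
    """Negative log-likelihood of the ground-truth answer: sum of -log p(token | prefix)."""
    return -sum(token_log_probs)
```

With 100 equally scored options the discriminative loss is log(100), about 4.6; it approaches 0 as the correct option's score dominates the others.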
Encoders and decoders can be arbitrarily plugged together. For example, to train an HRE model with question and history information only (no images), and generative decoding:
th train.lua -encoder hre-ques-hist -decoder gen -gpuid 0
Similarly, to train a Memory Network model with question, image and history information, and discriminative decoding:
th train.lua -encoder mn-ques-im-hist -decoder disc -gpuid 0
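Since any encoder can be paired with any decoder, a sweep over all combinations can be scripted. The following is a minimal sketch that only builds the command lines, using the flags from the examples above (`-encoder`, `-decoder`, `-gpuid`). The encoder list is copied from the folder listing above, and the helper name `train_commands` is hypothetical.

```python
from itertools import product
from typing import List

# Copied from the encoders/ and decoders/ listing in this README.
ENCODERS = [
    "lf-ques", "lf-ques-im", "lf-ques-hist", "lf-ques-im-hist",
    "hre-ques-hist", "hre-ques-im-hist", "hrea-ques-im-hist",
    "mn-ques-hist", "mn-ques-im-hist", "mn-att-ques-im-hist",
]
DECODERS = ["gen", "disc"]


def train_commands(gpuid: int = 0) -> List[str]:
    """One `th train.lua` command per encoder-decoder pair (nothing is executed)."""
    return [
        " ".join(["th", "train.lua", "-encoder", enc, "-decoder", dec, "-gpuid", str(gpuid)])
        for enc, dec in product(ENCODERS, DECODERS)
    ]
```

Each string can be written to a shell script or run with `subprocess.run(cmd, shell=True)`. Keep in mind that `mn-att-ques-im-hist` uses the `pool5` features extracted with `-imgSize 448 -layerName pool5`, and attention-based encoders need the extra image-size parameters described in the note on attention-based encoders.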
Note: For attention-based encoders, set both the `imgSpatialSize` and `imgFeatureSize` command-line parameters; feature dimensions are interpreted as (batch x spatial x spatial x feature). For other encoders, `imgSpatialSize` is redundant.
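As a sketch of how the (batch x spatial x spatial x feature) interpretation maps onto the two parameters, the hypothetical helper below reads them off a 4-d shape and shows the row-major offset of one element in that layout. It assumes square spatial maps, as a single `imgSpatialSize` value implies. For standard VGG-16 `pool5` features extracted at `-imgSize 448`, each image would be 14 x 14 x 512 once channels are placed last.

```python
from typing import Dict, Tuple


def attention_flags(feature_shape: Tuple[int, int, int, int]) -> Dict[str, int]:
    """Read imgSpatialSize / imgFeatureSize off a (batch, spatial, spatial, feature) shape."""
    if len(feature_shape) != 4:
        raise ValueError("expected a (batch, spatial, spatial, feature) shape")
    _, height, width, channels = feature_shape
    if height != width:
        # A single imgSpatialSize value implies a square spatial grid.
        raise ValueError("spatial dimensions must be equal")
    return {"imgSpatialSize": height, "imgFeatureSize": channels}


def flat_offset(b: int, i: int, j: int, f: int, spatial: int, feature: int) -> int:
    """Row-major offset of element [b][i][j][f] in a (batch, spatial, spatial, feature) array."""
    return ((b * spatial + i) * spatial + j) * feature + f
```

The offsets show that, in this layout, the feature vector of one grid cell is contiguous, and neighbouring cells along a row are `feature` elements apart.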
The training script saves model snapshots at regular intervals in the `checkpoints/` folder.
It takes about 15-20 epochs to train models with generative decoding to convergence, and 4-8 epochs for discriminative decoding.
We evaluate model performance by the rank a model assigns to the human response among the 100 response options for every round of dialog, using retrieval metrics: mean reciprocal rank (MRR), recall@k for k = 1, 5, 10 (R@1, R@5, R@10) and mean rank (MR).
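These metrics follow directly from the rank of the human response in each round. Here is a minimal Python sketch, assuming the model outputs one score per option and that ties are broken in favour of the ground-truth option (the repository's tie handling is not specified here):

```python
from typing import Dict, Sequence


def gt_rank(option_scores: Sequence[float], gt_index: int) -> int:
    """1-based rank of the ground-truth option (ties resolved in its favour, an assumption)."""
    gt_score = option_scores[gt_index]
    return 1 + sum(1 for s in option_scores if s > gt_score)


def retrieval_metrics(ranks: Sequence[int]) -> Dict[str, float]:
    """MRR, R@1/5/10 (as fractions) and mean rank over a list of 1-based ranks."""
    n = len(ranks)
    metrics = {"MRR": sum(1.0 / r for r in ranks) / n}
    for k in (1, 5, 10):
        metrics["R@%d" % k] = sum(1 for r in ranks if r <= k) / n
    metrics["MR"] = sum(ranks) / n
    return metrics
```

For ranks `[1, 2, 10, 100]` this gives MRR 0.4025, R@1 0.25, R@5 0.5, R@10 0.75 and MR 28.25. Note that the v0.9 results table below reports R@k as fractions, while the v1.0 table reports them as percentages.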
Model evaluation can be run using:
th evaluate.lua -loadPath checkpoints/model.t7 -gpuid 0
Note that evaluation requires image features `data/data_img.h5`, tokenized dialogs `data/visdial_data.h5` and vocabulary mappings `data/visdial_params.json`.
We also include code for running beam search on your model snapshots. This gives noticeably better responses than argmax (greedy) decoding, and can be run as follows:
th generate.lua -loadPath checkpoints/model.t7 -maxThreads 50
This computes predictions for 50 dialog threads from the `val` split and saves the results in `vis/results/results.json`.
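For intuition about why beam search can beat argmax decoding, here is a generic beam-search sketch over a hypothetical toy next-token model; it is not the decoder used by `generate.lua`. Greedy decoding is the special case `beam_size=1`. This variant removes a hypothesis from the beam once it emits `<eos>` and applies no length normalization; `generate.lua` may make different choices.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "<eos>"
# A step function maps a prefix (tuple of tokens) to {next_token: probability}.
StepFn = Callable[[Tuple[str, ...]], Dict[str, float]]


def beam_search(step_fn: StepFn, beam_size: int = 2, max_len: int = 10) -> Tuple[List[str], float]:
    """Return the best sequence (without <eos>) and its total log-probability."""
    beams: List[Tuple[Tuple[str, ...], float]] = [((), 0.0)]
    finished: List[Tuple[Tuple[str, ...], float]] = []
    for _ in range(max_len):
        # Expand every live hypothesis by every possible next token.
        candidates = []
        for prefix, score in beams:
            for token, prob in step_fn(prefix).items():
                if prob > 0:
                    candidates.append((prefix + (token,), score + math.log(prob)))
        # Keep the beam_size highest-scoring expansions overall.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            if seq[-1] == EOS:
                finished.append((seq[:-1], score))  # completed hypotheses leave the beam
            else:
                beams.append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses still open at max_len also compete
    best_seq, best_score = max(finished, key=lambda c: c[1])
    return list(best_seq), best_score


# Hypothetical toy model in which the greedy first choice leads to a worse answer.
TOY_MODEL = {
    (): {"a": 0.6, "b": 0.4},
    ("a",): {EOS: 0.4, "a": 0.3, "b": 0.3},
    ("b",): {EOS: 0.9, "a": 0.05, "b": 0.05},
}


def toy_step(prefix: Tuple[str, ...]) -> Dict[str, float]:
    return TOY_MODEL.get(prefix, {EOS: 1.0})
```

With this toy model, greedy decoding commits to `a` (probability 0.6) and ends with total probability 0.24, while a beam of size 2 keeps `b` alive and finds `b <eos>` with probability 0.36.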
To view the results, start a local web server in the `vis/` folder:
cd vis
# python 3.6
python -m http.server
# python 2.7
# python -m SimpleHTTPServer
Now visit `localhost:8000` in your browser to see the generated results.
Sample results from HRE-QIH-G are available here.
Extracted features for v0.9 train and val are available for download.

- `visdial_data.h5`: Tokenized captions, questions, answers, image indices
- `visdial_params.json`: Vocabulary mappings and COCO image ids
- `data_img_vgg16_relu7.h5`: VGG-16 `relu7` image features
- `data_img_vgg16_pool5.h5`: VGG-16 `pool5` image features
Trained on v0.9 `train`, results on v0.9 `val`.
Encoder | Decoder | CNN | MRR | R@1 | R@5 | R@10 | MR | Download |
---|---|---|---|---|---|---|---|---|
lf-ques | gen | VGG-16 | 0.5048 | 0.3974 | 0.6067 | 0.6649 | 17.8003 | lf-ques-gen-vgg16-18 |
lf-ques-hist | gen | VGG-16 | 0.5099 | 0.4012 | 0.6155 | 0.6740 | 17.3974 | lf-ques-hist-gen-vgg16-18 |
lf-ques-im | gen | VGG-16 | 0.5206 | 0.4206 | 0.6165 | 0.6760 | 17.0578 | lf-ques-im-gen-vgg16-22 |
lf-ques-im-hist | gen | VGG-16 | 0.5146 | 0.4086 | 0.6205 | 0.6828 | 16.7553 | lf-ques-im-hist-gen-vgg16-26 |
lf-att-ques-im-hist | gen | VGG-16 | 0.5354 | 0.4354 | 0.6355 | 0.6941 | 16.7663 | lf-att-ques-im-hist-gen-vgg16-80 |
hre-ques-hist | gen | VGG-16 | 0.5089 | 0.4000 | 0.6154 | 0.6739 | 17.3618 | hre-ques-hist-gen-vgg16-18 |
hre-ques-im-hist | gen | VGG-16 | 0.5237 | 0.4223 | 0.6228 | 0.6811 | 16.9669 | hre-ques-im-hist-gen-vgg16-14 |
hrea-ques-im-hist | gen | VGG-16 | 0.5238 | 0.4213 | 0.6244 | 0.6842 | 16.6044 | hrea-ques-im-hist-gen-vgg16-24 |
mn-ques-hist | gen | VGG-16 | 0.5131 | 0.4057 | 0.6176 | 0.6770 | 17.6253 | mn-ques-hist-gen-vgg16-102 |
mn-ques-im-hist | gen | VGG-16 | 0.5258 | 0.4229 | 0.6274 | 0.6874 | 16.9871 | mn-ques-im-hist-gen-vgg16-78 |
mn-att-ques-im-hist | gen | VGG-16 | 0.5341 | 0.4354 | 0.6318 | 0.6903 | 17.0726 | mn-att-ques-im-hist-gen-vgg16-100 |
lf-ques | disc | VGG-16 | 0.5491 | 0.4113 | 0.7020 | 0.7964 | 7.1519 | lf-ques-disc-vgg16-10 |
lf-ques-hist | disc | VGG-16 | 0.5724 | 0.4319 | 0.7308 | 0.8251 | 6.2847 | lf-ques-hist-disc-vgg16-8 |
lf-ques-im | disc | VGG-16 | 0.5745 | 0.4331 | 0.7398 | 0.8340 | 5.9801 | lf-ques-im-disc-vgg16-12 |
lf-ques-im-hist | disc | VGG-16 | 0.5911 | 0.4490 | 0.7563 | 0.8493 | 5.5493 | lf-ques-im-hist-disc-vgg16-8 |
lf-att-ques-im-hist | disc | VGG-16 | 0.6079 | 0.4692 | 0.7731 | 0.8635 | 5.1965 | lf-att-ques-im-hist-disc-vgg16-20 |
hre-ques-hist | disc | VGG-16 | 0.5668 | 0.4265 | 0.7245 | 0.8207 | 6.3701 | hre-ques-hist-disc-vgg16-4 |
hre-ques-im-hist | disc | VGG-16 | 0.5818 | 0.4461 | 0.7373 | 0.8342 | 5.9647 | hre-ques-im-hist-disc-vgg16-4 |
hrea-ques-im-hist | disc | VGG-16 | 0.5821 | 0.4456 | 0.7378 | 0.8341 | 5.9646 | hrea-ques-im-hist-disc-vgg16-4 |
mn-ques-hist | disc | VGG-16 | 0.5831 | 0.4388 | 0.7507 | 0.8434 | 5.8090 | mn-ques-hist-disc-vgg16-20 |
mn-ques-im-hist | disc | VGG-16 | 0.5971 | 0.4562 | 0.7627 | 0.8539 | 5.4218 | mn-ques-im-hist-disc-vgg16-12 |
mn-att-ques-im-hist | disc | VGG-16 | 0.6082 | 0.4700 | 0.7724 | 0.8623 | 5.2930 | mn-att-ques-im-hist-disc-vgg16-28 |
Extracted features for v1.0 train, val and test are available for download.

- `visdial_data_train.h5`: Tokenized captions, questions, answers, image indices, for training on `train`
- `visdial_params_train.json`: Vocabulary mappings and COCO image ids for training on `train`
- `data_img_vgg16_relu7_train.h5`: VGG-16 `relu7` image features for training on `train`
- `data_img_vgg16_pool5_train.h5`: VGG-16 `pool5` image features for training on `train`
- `visdial_data_trainval.h5`: Tokenized captions, questions, answers, image indices, for training on `train` + `val`
- `visdial_params_trainval.json`: Vocabulary mappings and COCO image ids for training on `train` + `val`
- `data_img_vgg16_relu7_trainval.h5`: VGG-16 `relu7` image features for training on `train` + `val`
- `data_img_vgg16_pool5_trainval.h5`: VGG-16 `pool5` image features for training on `train` + `val`
Trained on v1.0 `train` + v1.0 `val`, results on v1.0 `test-std`. Leaderboard here.
Encoder | Decoder | CNN | NDCG | MRR | R@1 (%) | R@5 (%) | R@10 (%) | MR | Download |
---|---|---|---|---|---|---|---|---|---|
lf-ques-im-hist | disc | VGG-16 | 0.4531 | 0.5542 | 40.95 | 72.45 | 82.83 | 5.9532 | lf-ques-im-hist-disc-vgg16-8 |
hre-ques-im-hist | disc | VGG-16 | 0.4546 | 0.5416 | 39.93 | 70.45 | 81.50 | 6.4082 | hre-ques-im-hist-disc-vgg16-4 |
mn-ques-im-hist | disc | VGG-16 | 0.4750 | 0.5549 | 40.98 | 72.30 | 83.30 | 5.9245 | mn-ques-im-hist-disc-vgg16-12 |
lf-att-ques-im-hist | disc | VGG-16 | 0.4976 | 0.5707 | 42.08 | 74.82 | 85.05 | 5.4092 | lf-att-ques-im-hist-disc-vgg16-24 |
mn-att-ques-im-hist | disc | VGG-16 | 0.4958 | 0.5690 | 42.42 | 74.00 | 84.35 | 5.5852 | mn-att-ques-im-hist-disc-vgg16-24 |
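Unlike the other metrics, NDCG uses graded relevance scores for the answer options rather than a single correct answer. Below is a minimal, standard NDCG sketch. The default cutoff `k` (the number of options with non-zero relevance) is an assumption here, so consult the official VisDial evaluation for the exact definition used on the leaderboard.

```python
import math
from typing import Optional, Sequence


def dcg(gains: Sequence[float]) -> float:
    """Discounted cumulative gain: the gain at 0-based position i is divided by log2(i + 2)."""
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains))


def ndcg(scores: Sequence[float], relevance: Sequence[float], k: Optional[int] = None) -> float:
    """NDCG@k of the options ranked by model score, against graded relevance."""
    if k is None:
        # ASSUMPTION: cut off at the number of options with non-zero relevance.
        k = sum(1 for r in relevance if r > 0)
    if k == 0:
        return 0.0
    # Rank options by descending model score and collect their relevance values.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked_gains = [relevance[i] for i in order[:k]]
    ideal_gains = sorted(relevance, reverse=True)[:k]
    ideal = dcg(ideal_gains)
    return dcg(ranked_gains) / ideal if ideal > 0 else 0.0
```

A model that orders the options exactly by relevance scores 1.0; placing relevant options lower in the ranking reduces the score toward 0.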
License: BSD