preprocessing refactored and integrated into main code
Shivanshu-Gupta committed Nov 20, 2017
1 parent 9b723d1 commit 50126e4
Showing 7 changed files with 210 additions and 154 deletions.
36 changes: 22 additions & 14 deletions config/config_vqa_sgd.yml
@@ -6,19 +6,27 @@ use_gpu: True
model_class: vqa
debug: False
seed: 123213
data:
path: '/home/cse/phd/csz178058/scratch/vqadata/'
random_seed: 1
#shuffle: False
questions_path: 'v2_OpenEnded_mscoco_val2014_questions.json'
annotation_path: 'v2_mscoco_val2014_annotations.json'
custom_batch_size: 128
features_dir: 'train2014_vqa_i_1024'
scale_params: [256,256]
crop_params: 228
loader_params:
batch_size: 128
num_workers: 4
data: # Shivanshu: driver code expects config['data'] in the following layout
preprocess: False
dir: '/home/cse/phd/csz178058/scratch/vqadata/'
train:
ques: 'v2_OpenEnded_mscoco_train2014_questions.json'
ans: 'v2_mscoco_train2014_annotations.json'
img_dir: 'train2014'
emb_dir: 'train2014_vqa_i_1024_vgg'
batch_size: 32
val:
ques: 'v2_OpenEnded_mscoco_val2014_questions.json'
ans: 'v2_mscoco_val2014_annotations.json'
img_dir: 'val2014'
emb_dir: 'train2014_vqa_i_1024_vgg'
batch_size: 32
images:
preprocess: False
scale: [256,256]
crop: 224
loader:
workers: 4
pin_memory: True

checkpoints:
@@ -49,7 +57,7 @@ optim:

params:
momentum: 0.9
lr: 0.1 # learning rate
lr: 0.01 # learning rate
# alpha: 0.99 # alpha for adagrad/rmsprop/momentum/adam
# beta: 0.995 # beta used for adam
# eps: 0.00001 # epsilon that goes into denominator in rmsprop
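For reference, a minimal sketch (not part of this commit) of how the driver reads the reorganized config['data'] block; the key names match the YAML above, and the config is loaded the same way main.py loads it:

    import yaml

    config = yaml.load(open('config/config_vqa_sgd.yml'))  # PyYAML, as in main.py
    data_cfg = config['data']

    train_ques = data_cfg['train']['ques']  # 'v2_OpenEnded_mscoco_train2014_questions.json'
    # mirror main.py: fall back to embeddings unless images.preprocess is set
    raw_images = 'preprocess' in data_cfg['images'] and data_cfg['images']['preprocess']
    img_dir = data_cfg['train']['img_dir'] if raw_images else data_cfg['train']['emb_dir']
    num_workers = data_cfg['loader']['workers']  # 4 in the sample config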
43 changes: 33 additions & 10 deletions dataset.py
@@ -1,33 +1,56 @@
import os
import pickle
import torch
import torchvision.transforms as transforms
from PIL import Image
from IPython.core.debugger import Pdb


class VQADataset(torch.utils.data.Dataset):
def __init__(self, qafile, img_dir, phase):
self.examples = pickle.load(open(qafile, 'rb'))
ques_vocab = {}
ans_vocab = {}

def __init__(self, data_dir, qafile, img_dir, phase, raw_images=False):
self.data_dir = data_dir
self.examples = pickle.load(open(os.path.join(data_dir, qafile), 'rb'))
if phase == 'train':
self.load_vocab(data_dir)
self.transforms = transforms.Compose([
transforms.Scale((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
self.image_dir = img_dir
self.img_dir = img_dir
self.phase = phase
self.raw_images = raw_images # if True, load raw images instead of precomputed embeddings

def load_vocab(self, data_dir):
ques_vocab_file = os.path.join(data_dir, 'ques_stoi.tsv')
for line in open(ques_vocab_file):
parts = line.split('\t')
tok, idx = parts[0], int(parts[1].strip())
VQADataset.ques_vocab[idx] = tok
# NOTE: in version 0.1.1 of torchtext, index 0 is assigned to '<unk>' the first time an unknown token is encountered.
VQADataset.ques_vocab[0] = '<unk>'
ans_vocab_file = os.path.join(data_dir, 'ans_itos.tsv')
for line in open(ans_vocab_file):
parts = line.split('\t')
VQADataset.ans_vocab[parts[0]] = parts[1]

def __len__(self):
return len(self.examples)

def __getitem__(self, idx):
_, ques, _, image_id, ans = self.examples[idx]
# img = Image.open('{0}/{1}2014/COCO_{1}2014_{2:012d}.jpg'.format(self.image_dir, self.phase, image_id))
# img = img.convert('RGB')
# img = self.transforms(img)
emb = torch.load('/home/cse/phd/csz178058/scratch/vqadata/train2014_vqa_i_1024_vgg/{}'.format(image_id))
img = emb
return torch.from_numpy(ques).squeeze(), img, ans
_, ques, _, imgid, ans = self.examples[idx]
if self.raw_images:
img = Image.open('{0}/{1}/{2}2014/COCO_{2}2014_{3:012d}.jpg'.format(self.data_dir, self.img_dir, self.phase, imgid))
img = img.convert('RGB')
img = self.transforms(img)
else:
img = torch.load('{}/{}/{}'.format(self.data_dir, self.img_dir, imgid))
return torch.from_numpy(ques), img, imgid, ans


class VQABatchSampler:
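A usage sketch of the new VQADataset signature, mirroring the call in load_datasets in main.py below; the paths are the values from the sample config, and preprocess() is assumed to have already written the .pkl and vocab files:

    from torch.utils.data import DataLoader
    from dataset import VQADataset, VQABatchSampler

    train_set = VQADataset(data_dir='/home/cse/phd/csz178058/scratch/vqadata/',
                           qafile='train.pkl',                  # written by preprocess()
                           img_dir='train2014_vqa_i_1024_vgg',  # emb_dir: precomputed VGG features
                           phase='train',                       # 'train' also loads the vocab files
                           raw_images=False)                    # load embeddings, not JPEGs
    train_loader = DataLoader(train_set,
                              batch_sampler=VQABatchSampler(train_set, 32),
                              num_workers=4)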
75 changes: 48 additions & 27 deletions main.py
@@ -6,49 +6,57 @@
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataset import VQADataset, VQABatchSampler
from train import train_model
import vqa
import san
from IPython.core.debugger import Pdb
# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 300
HIDDEN_DIM = 200

from preprocess import preprocess
from dataset import VQADataset, VQABatchSampler
from train import train_model
from vqa import VQAModel
from san import SANModel

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config.yml')


def load_datasets(data_dir, phases, img_emb_dir):
# Pdb().set_trace()
datasets = {x: VQADataset('{}/{}_data.pkl'.format(data_dir, x), img_emb_dir, x) for x in phases}
def load_datasets(config, phases):
config = config['data']
if 'preprocess' in config and config['preprocess']:
print('Preprocessing datasets')
preprocess(
data_dir=config['dir'],
train_ques_file=config['train']['ques'],
train_ans_file=config['train']['ans'],
val_ques_file=config['val']['ques'],
val_ans_file=config['val']['ans'])

print('Loading preprocessed datasets')
datafiles = {x: '{}.pkl'.format(x) for x in phases}
raw_images = 'preprocess' in config['images'] and config['images']['preprocess']
if raw_images:
img_dir = {x: config[x]['img_dir'] for x in phases}
else:
img_dir = {x: config[x]['emb_dir'] for x in phases}
datasets = {x: VQADataset(data_dir=config['dir'], qafile=datafiles[x], img_dir=img_dir[x], phase=x, raw_images=raw_images) for x in phases}
batch_samplers = {x: VQABatchSampler(datasets[x], 32) for x in phases}
dataloaders = {x: DataLoader(datasets[x], batch_sampler=batch_samplers[x], num_workers=4) for x in phases}

dataloaders = {x: DataLoader(datasets[x], batch_sampler=batch_samplers[x], num_workers=config['loader']['workers']) for x in phases}
dataset_sizes = {x: len(datasets[x]) for x in phases}
print(dataset_sizes)
return dataloaders
print("ques vocab size: {}".format(len(VQADataset.ques_vocab)))
print("ans vocab size: {}".format(len(VQADataset.ans_vocab)))
return dataloaders, VQADataset.ques_vocab, VQADataset.ans_vocab


if __name__ == '__main__':
global args
args = parser.parse_args()
args.config = os.path.join(os.getcwd(), args.config)
config = yaml.load(open(args.config))
config['use_gpu'] = config['use_gpu'] and torch.cuda.is_available()
torch.manual_seed(config['seed'])
torch.cuda.manual_seed(config['seed'])
def main(config):
phases = ['train', 'val']
dataloaders = load_datasets('datasets', phases, img_emb_dir='/scratch/cse/phd/csz178058/vqadata/')

config['model']['params']['vocab_size'] = 22226 + 1 # +1 to include '<unk>'
config['model']['params']['output_size'] = 1001
dataloaders, ques_vocab, ans_vocab = load_datasets(config, phases)
config['model']['params']['vocab_size'] = len(ques_vocab)
config['model']['params']['output_size'] = len(ans_vocab) # don't want model to predict '<unk>'

if config['model_class'] == 'vqa':
model = vqa.VQAModel(**config['model']['params'])
model = VQAModel(**config['model']['params'])
elif config['model_class'] == 'san':
model = san.SANModel(**config['model']['params'])
model = SANModel(**config['model']['params'])
print(model)
criterion = nn.CrossEntropyLoss()

@@ -70,3 +78,16 @@ def load_datasets(data_dir, phases, img_emb_dir):
print("begin training")
model = train_model(model, dataloaders, criterion, optimizer, exp_lr_scheduler, '/scratch/cse/dual/cs5130298/vqa',
num_epochs=25, use_gpu=config['use_gpu'])


if __name__ == '__main__':
global args
args = parser.parse_args()
args.config = os.path.join(os.getcwd(), args.config)
config = yaml.load(open(args.config))
config['use_gpu'] = config['use_gpu'] and torch.cuda.is_available()

# TODO: seeding still not perfect
torch.manual_seed(config['seed'])
torch.cuda.manual_seed(config['seed'])
main(config)
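One common way to tighten the "seeding still not perfect" TODO above is to seed every RNG in play, not just PyTorch's; a sketch under that assumption, not part of this commit:

    import random
    import numpy as np
    import torch

    def seed_everything(seed):
        random.seed(seed)                 # Python's own RNG
        np.random.seed(seed)              # NumPy (the pickled questions are numpy arrays)
        torch.manual_seed(seed)           # CPU RNG
        torch.cuda.manual_seed_all(seed)  # all visible GPUs, not just the current one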
87 changes: 0 additions & 87 deletions preproc.py

This file was deleted.
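preproc.py's role moved into preprocess.py, which main.py now imports. Reconstructing from the call sites, the new module must look roughly like this: the signature is the one used in load_datasets, the output files are the ones dataset.py reads, and the body itself is not shown in this commit:

    def preprocess(data_dir, train_ques_file, train_ans_file,
                   val_ques_file, val_ans_file):
        """Expected side effects, per the loading code:
          <data_dir>/train.pkl, <data_dir>/val.pkl  -- pickled 5-tuples unpacked as (_, ques, _, imgid, ans)
          <data_dir>/ques_stoi.tsv                  -- question token <TAB> index
          <data_dir>/ans_itos.tsv                   -- answer index <TAB> string
        """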

